diff --git a/.bazelrc b/.bazelrc index 375b0547574278..14482254c4b838 100644 --- a/.bazelrc +++ b/.bazelrc @@ -225,13 +225,16 @@ build:mkl_aarch64_threadpool -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda +# Default CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # CUDA: This config refers to building CUDA op kernels with clang. build:cuda_clang --config=cuda -# Enable TensorRT optimizations https://developer.nvidia.com/tensorrt -build:cuda_clang --config=tensorrt -build:cuda_clang --action_env=TF_CUDA_CLANG="1" build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --copt=-Qunused-arguments # Select supported compute capabilities (supported graphics cards). # This is the same as the official TensorFlow builds. # See https://developer.nvidia.com/cuda-gpus#compute @@ -240,22 +243,22 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Set lld as the linker. +build:cuda_clang --host_linkopt="-fuse-ld=lld" +build:cuda_clang --host_linkopt="-lm" +build:cuda_clang --linkopt="-fuse-ld=lld" +build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" -build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" -build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc @@ -382,6 +385,13 @@ build:windows --features=archive_param_file build:windows --copt=/d2ReducedOptimizeHugeFunctions build:windows --host_copt=/d2ReducedOptimizeHugeFunctions +# Before VS 2017 15.8, the member "type" would non-conformingly have an +# alignment of only alignof(max_align_t). VS 2017 15.8 was fixed to handle this +# correctly, but the fix inherently changes layout and breaks binary +# compatibility (*only* for uses of aligned_storage with extended alignments). +build:windows --copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE +build:windows --host_copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE + # Enable the runfiles symlink tree on Windows. This makes it possible to build # the pip package on Windows without an intermediate data-file archive, as the # build_pip_package script in its current form (as of Aug 2023) uses the @@ -569,10 +579,7 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.17-clang_config_tensorrt" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" -test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" + # ROCm # TODO(rocm) Is this actualy used? @@ -609,6 +616,9 @@ build:rbe_win_clang --compiler=clang-cl build:rbe_win_clang --linkopt=/FORCE:MULTIPLE build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE +# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits. +build:rbe_windows_x86_cpu --config=rbe_win_clang + # END TF REMOTE BUILD EXECUTION OPTIONS # TFLite build configs for generic embedded Linux @@ -671,7 +681,6 @@ build:release_cpu_linux_base --linkopt="-fuse-ld=lld" # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS -test:release_linux_base --test_env=LD_LIBRARY_PATH # Give only the list of failed tests at the end of the log test:release_linux_base --test_summary=short @@ -686,7 +695,6 @@ build:release_cpu_linux --config=release_cpu_linux_base # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. # Note that linux cpu and cuda builds share the same toolchain now. build:release_gpu_linux --config=cuda_clang_official -test:release_gpu_linux --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" # Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think test:release_gpu_linux --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute @@ -717,9 +725,8 @@ build:unsupported_gpu_linux --config=unsupported_cpu_linux build:unsupported_gpu_linux --action_env=TF_CUDA_VERSION="11" build:unsupported_gpu_linux --action_env=TF_CUDNN_VERSION="8" build:unsupported_gpu_linux --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" -build:unsupported_gpu_linux --config=tensorrt build:unsupported_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2" -build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64:/usr/local/tensorrt/lib" +build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64" build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain @@ -829,17 +836,19 @@ test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/ # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. -# CPU PYCPP: +# LINUX CPU PYCPP: test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -# CUDA PYCPP: + +# LINUX CUDA PYCPP: test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/python/integration_testing/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -# ARM64 PYCPP + +# LINUX ARM64 PYCPP # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on # Linux x86 so that we can use RBE. Since tests still need to run on the single # host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. @@ -872,6 +881,13 @@ build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow # CROSS-COMPILE MACOS X86 PYCPP build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test +# WINDOWS X86-64 CPU PYCPP +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" +test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only +test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... + # END TF TEST SUITE OPTIONS # START CROSS-COMPILE CONFIGS diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index 60af559a7c42cd..2013b8c0c9ea43 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.8.1" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.8.2" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml index f457db55292998..8b1a034a26110f 100644 --- a/.github/workflows/pylint-presubmit.yml +++ b/.github/workflows/pylint-presubmit.yml @@ -38,7 +38,7 @@ jobs: run: | echo Changed files: ${{ steps.get_file_changes.outputs.files }} - name: Set up Python 3.9 - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 + uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 with: python-version: "3.9" - name: Install Python dependencies diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index e72eab86787864..ceb213e46415cf 100644 --- a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -46,7 +46,7 @@ jobs: persist-credentials: false - name: "Run analysis" - uses: ossf/scorecard-action@dc50aa9510b46c811795eb24b2f1ba02a914e534 # v2.3.3 + uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0 with: results_file: results.sarif results_format: sarif @@ -55,7 +55,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3 + uses: actions/upload-artifact@0b2256b8c012f0828dc542b3febcab082c67f72b # v4.3.4 with: name: SARIF file path: results.sarif @@ -64,6 +64,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@b611370bb5703a7efb587f9d136a52ea24c5c38c # v3.25.11 + uses: github/codeql-action/upload-sarif@afb54ba388a7dca6ecae48f608c4ff05ff4cc77a # v3.25.15 with: sarif_file: results.sarif diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml index 2c81873f879719..c72cc988629422 100644 --- a/.github/workflows/sigbuild-docker-branch.yml +++ b/.github/workflows/sigbuild-docker-branch.yml @@ -43,16 +43,16 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0 + uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 - name: Login to DockerHub - uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 + uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Login to GCR - uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 + uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: registry: gcr.io username: _json_key @@ -67,7 +67,7 @@ jobs: - name: Build and push id: docker_build - uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0 + uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0 with: push: true context: ./tensorflow/tools/tf_sig_build_dockerfiles diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml index 7de12e154f74c7..e21ddb0e507dca 100644 --- a/.github/workflows/sigbuild-docker-presubmit.yml +++ b/.github/workflows/sigbuild-docker-presubmit.yml @@ -47,15 +47,24 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0 + uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 - name: Login to GCR if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') - uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 + uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: registry: gcr.io username: _json_key password: ${{ secrets.GCP_CREDS }} + - + name: Login to AR + # Once this is verified, change the label's name. For now, we will piggyback on gcr.io actions. + if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') + uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 + with: + registry: us-central1-docker.pkg.dev + username: _json_key + password: ${{ secrets.GCP_CREDS }} - name: Grab the date to do cache busting (assumes same day OK to keep) run: | @@ -64,7 +73,7 @@ jobs: - name: Build containers, and push to GCR only if the 'build and push to gcr.io for staging' label is applied id: docker_build - uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0 + uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0 with: push: ${{ contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') }} context: ./tensorflow/tools/tf_sig_build_dockerfiles @@ -74,6 +83,7 @@ jobs: CACHEBUSTER=${{ steps.date.outputs.DATE }} tags: | gcr.io/tensorflow-sigs/build:${{ github.event.number }}-${{ matrix.python-version }} + us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ github.event.number }}-${{ matrix.python-version }} cache-from: | type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }} type=registry,ref=gcr.io/tensorflow-sigs/build:${{ github.event.number }}-${{ matrix.python-version }} diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml index 062338edda107c..78e7fd75085523 100644 --- a/.github/workflows/sigbuild-docker.yml +++ b/.github/workflows/sigbuild-docker.yml @@ -46,20 +46,28 @@ jobs: uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@d70bba72b1f3fd22344832f00baa16ece964efeb # v3.3.0 + uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 - name: Login to DockerHub - uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 + uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_TOKEN }} - name: Login to GCR - uses: docker/login-action@0d4c9c5ea7693da7b068278f7b52bda2a190a446 # v3.2.0 + uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 with: registry: gcr.io username: _json_key password: ${{ secrets.GCP_CREDS }} + - + name: Login to AR + # Once this is verified, removed gcr.io actions. + uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 + with: + registry: us-central1-docker.pkg.dev + username: _json_key + password: ${{ secrets.GCP_CREDS }} - name: Grab the upcoming TF version to tag this container run: | @@ -74,7 +82,7 @@ jobs: - name: Build and push id: docker_build - uses: docker/build-push-action@15560696de535e4014efeff63c48f16952e52dd1 # v6.2.0 + uses: docker/build-push-action@5176d81f87c23d6fc96624dfdbcd9f3830bbe445 # v6.5.0 with: push: true context: ./tensorflow/tools/tf_sig_build_dockerfiles @@ -87,6 +95,8 @@ jobs: tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} gcr.io/tensorflow-sigs/build:latest-${{ matrix.python-version }} gcr.io/tensorflow-sigs/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} + us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:latest-${{ matrix.python-version }} + us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} cache-from: type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }} cache-to: type=inline - diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 89c61463462745..17b77f808d9c80 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -253,13 +253,21 @@ There are two ways to run TensorFlow unit tests. export flags="--config=opt -k" ``` - If the tests are to be run on the GPU, add CUDA paths to LD_LIBRARY_PATH and - add the `cuda` option flag + If the tests are to be run on the GPU: + * For TensorFlow versions starting from v.2.18.0: + Add the `cuda` option flag. - ```bash - export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" - export flags="--config=opt --config=cuda -k" - ``` + ```bash + export flags="--config=opt --config=cuda -k" + ``` + + * For TensorFlow versions prior v.2.18.0: + Add CUDA paths to LD_LIBRARY_PATH and add the `cuda` option flag. + + ```bash + export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH" + export flags="--config=opt --config=cuda -k" + ``` For example, to run all tests under tensorflow/python, do: diff --git a/RELEASE.md b/RELEASE.md index d863eb2166127a..cfac66a90bb99c 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -11,7 +11,23 @@ * `tf.lite` * C API: - * An optional, fourth parameter was added `TfLiteOperatorCreate` as a step forward towards a cleaner API for `TfLiteOperator`. Function `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, released on 7/11/2024, and we do not expect there will be much code using this function yet. Any code breakages can be easily resolved by passing nullptr as the new, 4th parameter. + * An optional, fourth parameter was added `TfLiteOperatorCreate` as a step + forward towards a cleaner API for `TfLiteOperator`. Function + `TfLiteOperatorCreate` was added recently, in TensorFlow Lite version 2.17.0, + released on 7/11/2024, and we do not expect there will be much code using this + function yet. Any code breakages can be easily resolved by passing nullptr as + the new, 4th parameter. + * SignatureRunner is now supported for models with no signatures. + +* TensorRT support is disabled in CUDA builds for code health improvement. + +* Hermetic CUDA support is added. + + Hermetic CUDA uses a specific downloadable version of CUDA instead of the + user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL + distributions, and then use CUDA libraries and tools as dependencies in + various Bazel targets. This enables more reproducible builds for Google ML + projects and supported CUDA versions. ### Known Caveats @@ -35,6 +51,11 @@ should run synchronously, as opposed to be parallelizable when `options.experimental_optimization.map_parallelization=True`. This saves memory compared to setting `num_parallel_calls=1`. + * Add optional `use_unbounded_threadpool` argument to `map`, to specify that + the `map` should use an unbounded threadpool instead of the default pool + that is based on the number of cores on the machine. This can improve + throughput for map functions which perform IO or otherwise release the + CPU. * `tf.lite` * `Dequantize` op supports `TensorType_INT4`. * This change includes per-channel dequantization. diff --git a/WORKSPACE b/WORKSPACE index f8f467fccf5ce2..32ffd0433108c7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -64,3 +64,50 @@ tf_workspace1() load("@//tensorflow:workspace0.bzl", "tf_workspace0") tf_workspace0() + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "@local_tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "@local_tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/ci/devinfra/docker/windows/Dockerfile b/ci/devinfra/docker/windows/Dockerfile new file mode 100644 index 00000000000000..e1a7f949d5f48b --- /dev/null +++ b/ci/devinfra/docker/windows/Dockerfile @@ -0,0 +1,155 @@ +# This Dockerfile creates an image that has: +# - the correct MTU setting for networking from inside the container to work. +# - Visual Studio 2022 Build Tools +# - MSVC 14.39 +# - LLVM/Clang 18.1.4 +# - MSYS2 + curl, git, patch, vim, unzip, zip +# - Python 3.12.3 +# - Bazelisk 1.19.0 +# - JDK 21 (Azul Zulu) + +FROM mcr.microsoft.com/windows/servercore:ltsc2019 + +SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", \ + "$ErrorActionPreference='Stop'; $ProgressPreference='SilentlyContinue';$VerbosePreference = 'Continue';"] + +# This should only be necessary when running on A GCP VM, on a default +# network, which has the MTU of 1460, +# due to 40 bytes being reserved for GCP's internal usage. +# Note, an invalid sub-interface name will lead to an obscure error, e.g.: +# "The filename, directory name, or volume label syntax is incorrect." +# In such cases, check that the name of the sub-interface is valid: +# `netsh interface show interface` +RUN netsh interface ipv4 set subinterface \"vEthernet (Ethernet)\" mtu=1460 store=persistent + +RUN md C:\TEMP +RUN md C:\TMP + +# Install 7-Zip. +RUN (New-Object Net.WebClient).DownloadFile('https://www.7-zip.org/a/7z2201-x64.msi', '7z.msi'); \ + Start-Process msiexec.exe -ArgumentList \"/i 7z.msi /qn /norestart /log C:\\TEMP\\7z_install_log.txt\" -wait; \ + Remove-Item .\7z.msi; + +# Download the Visual Studio 2022 Installer. +RUN (New-Object Net.WebClient).DownloadFile('https://aka.ms/vs/17/release/vs_community.exe', 'C:\TEMP\vs_community.exe'); +# Install Visual Studio 2022 Build Tools + Compiler +SHELL ["cmd", "/S", "/C"] +# Packages, and component versions, can be found here: +# https://learn.microsoft.com/en-us/visualstudio/install/workload-component-id-vs-build-tools +RUN C:\TEMP\vs_community.exe \ + --quiet --wait --norestart --nocache \ + --add Microsoft.VisualStudio.Component.VC.Tools.x86.x64 \ + --add Microsoft.VisualStudio.Workload.NativeDesktop \ + --add Microsoft.VisualStudio.Component.VC.14.39.17.9.x86.64 \ + --add Microsoft.VisualStudio.Component.Windows11SDK.22621 \ + || IF "%ERRORLEVEL%"=="3010" EXIT 0 + +SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", \ + "$ErrorActionPreference='Stop'; $ProgressPreference='SilentlyContinue'; $VerbosePreference = 'Continue';"] + +# Install Clang. +RUN (New-Object Net.WebClient).DownloadFile( \ + 'https://github.com/llvm/llvm-project/releases/download/llvmorg-18.1.4/LLVM-18.1.4-win64.exe', \ + 'LLVM.exe'); \ + Start-Process -FilePath \"C:\Program Files\7-Zip\7z.exe\" -ArgumentList 'x LLVM.exe -oC:\tools\LLVM' -Wait; \ + $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine') + ';C:\tools\LLVM\bin'; \ + [Environment]::SetEnvironmentVariable('PATH', $env:PATH, 'Machine'); + +# Install MSYS2, and add some extra tools. +RUN (New-Object Net.WebClient).DownloadFile( \ + 'https://repo.msys2.org/distrib/x86_64/msys2-base-x86_64-20240113.tar.xz', \ + 'msys2.tar.xz'); \ + Start-Process -FilePath \"C:\Program Files\7-Zip\7z.exe\" -ArgumentList 'x msys2.tar.xz -oC:\TEMP\msys2.tar' -Wait; \ + Start-Process -FilePath \"C:\Program Files\7-Zip\7z.exe\" -ArgumentList 'x C:\TEMP\msys2.tar -oC:\tools' -Wait; \ + $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine') + ';C:\tools\msys64;C:\tools\msys64\usr\bin\'; \ + [Environment]::SetEnvironmentVariable('PATH', $env:PATH, 'Machine'); + +# Disable signature checking on pacman because we cannot initialize the keyring. +RUN Add-Content -Path C:\tools\msys64\etc\pacman.d\mirrorlist.mingw32 -Value 'SigLevel = Never' +RUN Add-Content -Path C:\tools\msys64\etc\pacman.d\mirrorlist.mingw64 -Value 'SigLevel = Never' +RUN Add-Content -Path C:\tools\msys64\etc\pacman.d\mirrorlist.msys -Value 'SigLevel = Never' + +# Install pacman packages. +RUN C:\tools\msys64\usr\bin\bash.exe -lc \ + 'pacman --noconfirm -Syy curl git patch vim unzip zip' + +# Install Python as a general utility/tool. +ENV PYTHON_VERSION 3.12.3 + +RUN $url = ('https://www.python.org/ftp/python/{0}/python-{0}-amd64.exe' -f $env:PYTHON_VERSION); \ + Write-Host ('Downloading {0} ...' -f $url); \ + [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; \ + (New-Object Net.WebClient).DownloadFile($url, 'C:\tmp\pyinstall.exe'); \ + \ + Write-Host 'Installing...'; \ + Start-Process -FilePath \"C:\tmp\pyinstall.exe\" -ArgumentList '/quiet InstallAllUsers=1 PrependPath=1 TargetDir=C:\Python312' -Wait; \ + \ + Write-Host 'Verifying install ...'; \ + Write-Host ' python --version'; C:\python312\python.exe --version; \ + \ + Write-Host 'Verifying pip install ...'; \ + C:\python312\python.exe -m pip --version; \ + \ + Write-Host 'Removing ...'; \ + Remove-Item C:\tmp\pyinstall.exe -Force; \ + \ + Write-Host 'Complete.'; + +# Install pip packages. +RUN python -m pip install --ignore-installed --force-reinstall --upgrade \ + setuptools packaging + +# Install JDK 21. +RUN \ + Add-Type -AssemblyName \"System.IO.Compression.FileSystem\"; \ + $zulu_pkg = \"zulu21.34.19-ca-jdk21.0.3-win_x64.zip\"; \ + $zulu_url = \"https://cdn.azul.com/zulu/bin/${zulu_pkg}\"; \ + $zulu_zip = \"c:\\temp\\${zulu_pkg}\"; \ + $zulu_extracted_path = \"c:\\temp\\\" + [IO.Path]::GetFileNameWithoutExtension($zulu_zip); \ + $zulu_root = \"c:\\openjdk\"; \ + (New-Object Net.WebClient).DownloadFile($zulu_url, $zulu_zip); \ + [System.IO.Compression.ZipFile]::ExtractToDirectory($zulu_zip, \"c:\\temp\"); \ + Move-Item $zulu_extracted_path -Destination $zulu_root; \ + Remove-Item $zulu_zip; \ + $env:PATH = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\") + \";${zulu_root}\\bin\"; \ + [Environment]::SetEnvironmentVariable(\"PATH\", $env:PATH, \"Machine\"); \ + $env:JAVA_HOME = $zulu_root; \ + [Environment]::SetEnvironmentVariable(\"JAVA_HOME\", $env:JAVA_HOME, \"Machine\") + +# Point to the LLVM installation. +# The Bazel Windows guide claims it can find LLVM automatically, +# but it likely only works if it's installed somewhere inside C:\Program Files. +ENV BAZEL_LLVM "C:\tools\LLVM" + +# These variables may be useful, but so far haven't been. Keeping for posterity. +# ENV CLANG_COMPILER_PATH "C:\tools\llvm\bin\clang.exe" +# ENV CC "C:\tools\llvm\bin\clang.exe" +# ENV BAZEL_COMPILER "C:\tools\llvm\bin\clang.exe" + +ENV BAZEL_SH "C:\tools\msys64\usr\bin\bash.exe" +ENV BAZEL_VS "C:\Program Files\Microsoft Visual Studio\2022\BuildTools" +ENV BAZEL_VC "C:\Program Files\Microsoft Visual Studio\2022\Community\VC" + +# Environment variables to work around MSYS issues. +ENV MSYS_NO_PATHCONV 1 +ENV MSYS2_ARG_CONV_EXCL * + +# This should only be necessary if there are multiple, differently-versioned +# MSVC compilers installed, and a particular one should be used. +# To find exact versions available: +# - Navigate to the relevant folder, e.g. +# C:\Program Files\Microsoft Visual Studio\2022 +# - Search for the `cl.exe` file: `gci -r -fi cl.exe` +# - The version will be part of the found path, e.g. +# 2022\Community\VC\Tools\MSVC\14.39.33519\bin\Hostx64\x64 +# ENV BAZEL_VC_FULL_VERSION 14.39.33519 + +# Install Bazelisk. +RUN md C:\tools\bazel +RUN (New-Object Net.WebClient).DownloadFile( \ + 'https://github.com/bazelbuild/bazelisk/releases/download/v1.19.0/bazelisk-windows-amd64.exe', \ + 'C:\tools\bazel\bazel.exe'); \ + $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine') + ';C:\tools\bazel'; \ + [Environment]::SetEnvironmentVariable('PATH', $env:PATH, 'Machine'); + +SHELL ["cmd.exe", "/s", "/c"] diff --git a/ci/devinfra/docker_windows/Dockerfile b/ci/devinfra/docker_windows/Dockerfile deleted file mode 100644 index 540f82abf5c35f..00000000000000 --- a/ci/devinfra/docker_windows/Dockerfile +++ /dev/null @@ -1,256 +0,0 @@ -FROM mcr.microsoft.com/dotnet/framework/sdk:4.8-windowsservercore-ltsc2019@sha256:c1b2be17aa0c1a5d9493a306395a6f07141aae8d7897f7ba319183f28719c990 - -# Set default powershell policy for this script (ProgressPreference='SilentlyContinue' makes -# downloads with Invoke-WebRequest not show the progress bar and is MUCH faster). -SHELL ["powershell.exe", "-ExecutionPolicy", "Bypass", "-Command", "$ErrorActionPreference='Stop'; $ProgressPreference='SilentlyContinue'; $VerbosePreference = 'Continue';"] - -# Workaround for networking (b/112379377) was closed as won't fix for MTU setting. -# Remaining lines handle making the metadata server on the VM accessible inside docker. -RUN Get-NetAdapter | Where-Object Name -like "*Ethernet*" | ForEach-Object { \ - & netsh interface ipv4 set subinterface $_.InterfaceIndex mtu=1460 store=persistent }; \ - $gateway = (Get-NetRoute | Where { $_.DestinationPrefix -eq \"0.0.0.0/0\" } | Sort-Object RouteMetric \ - | Select NextHop).NextHop; \ - $ifIndex = (Get-NetAdapter -InterfaceDescription \"Hyper-V Virtual Ethernet*\" | Sort-Object \ - | Select ifIndex).ifIndex; \ - New-NetRoute -DestinationPrefix 169.254.169.254/32 -InterfaceIndex $ifIndex -NextHop $gateway - -# Enable Long Paths for Win32 File/Folder APIs. -RUN New-ItemProperty -Path HKLM:\SYSTEM\CurrentControlSet\Control\FileSystem \ - -Name LongPathsEnabled -Value 1 -PropertyType DWORD -Force - -# Install Visual C++ Redistributable for Visual Studio 2015-2022. -RUN New-Item -Path "C:/" -Name "TEMP" -ItemType "directory"; \ - Invoke-WebRequest "https://aka.ms/vs/17/release/vc_redist.x64.exe" \ - -OutFile C:/TEMP/vc_redist.x64.exe -UseBasicParsing; \ - Start-Process -filepath C:/TEMP/vc_redist.x64.exe -ArgumentList '/install', '/passive', '/norestart' -Wait; \ - Remove-Item C:/TEMP/vc_redist.x64.exe - -# Install Visual Studio 2022 Build Tools. Install ManagedDesktopBuildTools separately to ensure all Optional workloads are installed too. -RUN Invoke-WebRequest "https://aka.ms/vs/17/release/vs_buildtools.exe" \ - -OutFile C:/TEMP/vs_buildtools.exe -UseBasicParsing; \ - Start-Process -FilePath C:/TEMP/vs_buildtools.exe -ArgumentList "--installPath", "C:/VS", \ - "--quiet", "--wait", "--nocache", \ - "--add", "Microsoft.VisualStudio.Workload.VCTools", \ - "--add", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64", \ - "--add", "Microsoft.VisualStudio.Component.Windows10SDK.19041" -Wait; \ - Start-Process -FilePath C:/TEMP/vs_buildtools.exe -ArgumentList "--installPath", "C:/VS", \ - "--quiet", "--wait", "--nocache", "--includeOptional", \ - "--add", "Microsoft.VisualStudio.Workload.ManagedDesktopBuildTools" -Wait; \ - Remove-Item C:/TEMP/vs_buildtools.exe; \ - [Environment]::SetEnvironmentVariable(\"BAZEL_VC\", \"C:\VS\VC\", \"Machine\"); \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\VS\VC\Tools\MSVC\14.33.31629\bin\Hostx64\x64;C:\VS\Common7\Tools;C:\VS\MSBuild\Current\Bin\", \"Machine\"); - -# Add signtool.exe to the PATH. Note this path may need to be edited if updates -# are made to the Windows 10 SDK. -RUN $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Program Files (x86)\Windows Kits\10\App Certification Kit\", \"Machine\"); - -# Install WiX toolset (v4) - Necessary for MSI Installer/Signing builds -RUN dotnet tool install --global wix - -# Install msys2, packages and add to path. -RUN [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; \ - Invoke-WebRequest "https://repo.msys2.org/distrib/x86_64/msys2-base-x86_64-20220319.sfx.exe" \ - -OutFile msys2_install.exe -UseBasicParsing; \ - .\msys2_install.exe -y -oC:\; \ - Remove-Item msys2_install.exe; \ - function msys() { C:\msys64\usr\bin\bash.exe @('-lc') + @Args; } \ - msys ' '; \ - msys 'pacman --noconfirm -Syy bsdcpio bsdtar bzip2'; \ - msys 'pacman --noconfirm -Syy coreutils curl dash file filesystem findutils'; \ - msys 'pacman --noconfirm -Syy flex gawk gcc-libs grep gzip inetutils info'; \ - msys 'pacman --noconfirm -Syy less lndir mintty ncurses pactoys-git patch'; \ - msys 'pacman --noconfirm -Syy pax-git pkgfile rebase sed tar tftp-hpa time tzcode util-linux which'; \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\msys64;C:\msys64\usr\bin\", \"Machine\"); - -# Install Go 1.19.1 -RUN Invoke-WebRequest "https://go.dev/dl/go1.19.1.windows-amd64.msi" \ - -OutFile C:/TEMP/go_install.msi -UseBasicParsing; \ - Start-Process C:/TEMP/go_install.msi -ArgumentList "/quiet", "/log", "C:/TEMP/go_install_log.txt", \ - "InstallAllUsers=1", "PrependPath=1" -wait; \ - Remove-Item C:/TEMP/go_install.msi; \ - Remove-Item C:/TEMP/go_install_log.txt - -# Install Python 3. -RUN Invoke-WebRequest "https://www.python.org/ftp/python/3.10.4/python-3.10.4-amd64.exe" \ - -OutFile C:/TEMP/python_install.exe -UseBasicParsing; \ - Start-Process C:/TEMP/python_install.exe -ArgumentList "/quiet", "/log", "C:/TEMP/python_install_log.txt", \ - "InstallAllUsers=1", "PrependPath=1" -wait; \ - Remove-Item C:/TEMP/python_install.exe; \ - Remove-Item C:/TEMP/python_install_log.txt - -# Install JDK 17 -RUN Add-Type -AssemblyName "System.IO.Compression.FileSystem"; \ - $zulu_url = \"https://cdn.azul.com/zulu/bin/zulu17.32.13-ca-jdk17.0.2-win_x64.zip\"; \ - $zulu_zip = \"c:/temp/jdk_install.zip\"; \ - $zulu_extracted_path = \"c:/temp/\" + [IO.Path]::GetFileNameWithoutExtension($zulu_url); \ - $zulu_root = \"c:/openjdk\"; \ - (New-Object Net.WebClient).DownloadFile($zulu_url, $zulu_zip); \ - [System.IO.Compression.ZipFile]::ExtractToDirectory($zulu_zip, \"c:/temp\"); \ - Move-Item $zulu_extracted_path -Destination $zulu_root; \ - Remove-Item $zulu_zip; \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";${zulu_root}\bin\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"JAVA_HOME\", $zulu_root, \"Machine\") - -# Install gcloud (install.bat installs directly into bin folder of extracted zip contents) -# Install needed gcloud components -RUN Add-Type -AssemblyName "System.IO.Compression.FileSystem"; \ - $pkg_url = \"https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-396.0.0-windows-x86_64.zip\"; \ - $pkg_zip = \"c:/temp/gcloud.zip\"; \ - $pkg_extracted_path = \"c:/google-cloud-sdk\"; \ - (New-Object Net.WebClient).DownloadFile($pkg_url, $pkg_zip); \ - [System.IO.Compression.ZipFile]::ExtractToDirectory($pkg_zip, \"c:/\"); \ - Start-Process cmd.exe -ArgumentList "/c", "/s", "$pkg_extracted_path/install.bat", "-q" -wait; \ - Remove-Item $pkg_zip; \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";${pkg_extracted_path}\bin\", \"Machine\"); \ - $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \ - gcloud components install docker-credential-gcr kubectl gsutil; - -# Install cygwin and packages -# Running a seperate ps1 file since when running inside a Dockerfile, it does -# not work. -COPY install/install_cygwin.ps1 c:/ -RUN c:/install_cygwin.ps1; \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Cygwin64\bin\", \"Machine\"); -RUN Remove-Item c:/install_cygwin.ps1 - -# Install Chocolatey and packages -RUN Invoke-Expression ((New-Object Net.WebClient).DownloadString('https://chocolatey.org/install.ps1')); \ - $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \ - choco feature enable -n allowGlobalConfirmation; \ - choco install 7zip; \ - choco install 7zip.install; \ - choco install 7zip.portable; \ - choco install anaconda2 --version 5.0.1; \ - choco install anaconda3 --version 5.0.1; \ - choco install android-sdk --version 25.2.3.1; \ - choco install AndroidStudio --version 3.0.1.0; \ - choco install ant --version 1.10.1; \ - choco install ccleaner; \ - choco install chocolatey; \ - choco install chocolatey-core.extension; \ - choco install chocolatey-visualstudio.extension; \ - choco install chocolatey-windowsupdate.extension; \ - choco install cmake.install; \ - choco install dotnetcore-sdk; \ - choco install git; \ - choco install git.install; \ - choco install GoogleChrome; \ - choco install gradle --version 4.4.1; \ - choco install jdk8; \ - choco install KB2533623; \ - choco install KB2919355; \ - choco install KB2919442; \ - choco install KB2999226; \ - choco install KB3033929; \ - choco install KB3035131; \ - choco install maven; \ - choco install ninja; \ - choco install nodejs --version 9.3.0; \ - choco install nodejs.install --version 9.3.0; \ - choco install nuget.commandline; \ - choco install openjdk11; \ - choco install peazip; \ - choco install peazip.install; \ - choco install peazip.portable; \ - choco install php --version 7.2.0; \ - choco install protoc --version 3.2.0; \ - choco install ruby --version 2.5.0.1; \ - choco install swig --version 3.0.9; \ - choco install sysinternals; \ - choco install unrar; \ - choco install unzip; \ - choco install vcredist140; \ - choco install vcredist2015; \ - choco install vim; \ - choco install winrar; \ - choco install zip; \ - choco install Firefox; \ - choco install iisexpress; - -RUN cmd /c 'mklink /J c:\Anaconda c:\tools\anaconda2'; -RUN cmd /c 'mklink c:\programdata\chocolatey\bin\rar.exe \"c:\program files\winrar\rar.exe\"'; - -# Installing pip packages -RUN pip install --upgrade setuptools; \ - pip install altgraph appdirs cachetools certifi cffi chardet colorama \ - cryptography cycler Cython decorator google-api-python-client \ - google-auth google-auth-httplib2 grpcio httplib2 idna ipython-genutils \ - kiwisolver macholib matplotlib nose numpy packaging pandas pickleshare pip \ - prompt-toolkit protobuf psutil pyasn1 pyasn1-modules pycparser Pygments \ - pyparsing pyreadline python-dateutil pytz pywin32 requests rsa setuptools \ - simplegeneric six Tempita traitlets uritemplate urllib3 virtualenv wcwidth \ - wheel win-unicode-console; - -# Hardcoding Android license since I did not find any solution on accepting it -# through the docker build command. If the licensing agreement changes, this -# will need to be updated as well. -RUN New-Item -ItemType Directory -Path C:\Android\android-sdk\licenses; \ - Set-Content -Path .\Android\android-sdk\licenses\android-sdk-license -Value "`n24333f8a63b6825ea9c5514f83c2829b004d1fee" -NoNewLine; - -# Add sdkmanager to PATH -RUN $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Android\android-sdk\tools\bin\", \"Machine\"); - -# Install android packages -RUN $env:PATH = [Environment]::GetEnvironmentVariable('PATH', 'Machine'); \ - New-Item C:\Users\ContainerAdministrator\.android\repositories.cfg; \ - sdkmanager 'ndk-bundle'; \ - sdkmanager 'platforms;android-33'; \ - sdkmanager 'add-ons;addon-google_apis-google-24'; \ - sdkmanager 'cmake;3.10.2.4988404'; \ - sdkmanager 'cmake;3.18.1'; \ - sdkmanager 'cmake;3.22.1'; \ - sdkmanager 'cmake;3.6.4111459'; \ - sdkmanager 'emulator'; \ - sdkmanager 'system-images;android-27;google_apis;x86'; \ - sdkmanager 'sources;android-27'; \ - sdkmanager 'extras;google;Android_Emulator_Hypervisor_Driver'; \ - sdkmanager 'extras;google;auto'; \ - sdkmanager 'extras;google;google_play_services'; \ - sdkmanager 'extras;google;instantapps'; \ - sdkmanager 'extras;google;m2repository'; \ - sdkmanager 'extras;google;market_apk_expansion'; \ - sdkmanager 'extras;google;market_licensing'; \ - sdkmanager 'extras;google;simulators'; \ - sdkmanager 'extras;google;usb_driver'; \ - sdkmanager 'extras;google;webdriver'; \ - sdkmanager 'extras;android;m2repository'; \ - sdkmanager 'extras;intel;Hardware_Accelerated_Execution_Manager'; \ - sdkmanager 'extras;m2repository;com;android;support;constraint;constraint-layout;1.0.0'; \ - sdkmanager 'extras;m2repository;com;android;support;constraint;constraint-layout-solver;1.0.2'; \ - sdkmanager 'patcher;v4'; \ - sdkmanager 'ndk;25.1.8937393'; \ - sdkmanager 'build-tools;27.0.3'; - -# Install Scoop and packages -RUN iex \"& {$(irm get.scoop.sh)} -RunAsAdmin\"; \ - scoop install perl; \ - scoop install bazel; \ - scoop install cuda; \ - scoop install azure-functions-core-tools; \ - scoop install azure-cli; - -# Setting environment variables -RUN [Environment]::SetEnvironmentVariable('CYGWIN', 'winsymlinks:native', 'Machine'); \ - [Environment]::SetEnvironmentVariable('HOME', 'C:\Users\ContainerAdministrator\', 'Machine'); \ - [Environment]::SetEnvironmentVariable('HOMEDRIVE', 'C:', 'Machine'); \ - [Environment]::SetEnvironmentVariable('HOMEPATH', '\Users\ContainerAdministrator\', 'Machine'); \ - [Environment]::SetEnvironmentVariable('GOROOT', 'C:\Program Files\Go\', 'Machine'); \ - [Environment]::SetEnvironmentVariable('KOKORO_POSIX_ROOT', '/tmpfs', 'Machine'); \ - [Environment]::SetEnvironmentVariable('KOKORO_ROOT', 'T:\', 'Machine'); \ - [Environment]::SetEnvironmentVariable('SHELL', '/bin/bash', 'Machine'); \ - $old_path = [Environment]::GetEnvironmentVariable(\"PATH\", \"Machine\"); \ - [Environment]::SetEnvironmentVariable(\"PATH\", $old_path + \";C:\Program Files\CMake\bin\", \"Machine\"); - - -# Restore default shell for Windows containers. -SHELL ["cmd.exe", "/s", "/c"] - -# Default to PowerShell if no other command specified. -CMD ["powershell.exe", "-NoLogo", "-ExecutionPolicy", "Bypass"] diff --git a/ci/official/containers/linux_arm64/Dockerfile b/ci/official/containers/linux_arm64/Dockerfile index c2161dfe4ad6f3..428347a5b6a847 100644 --- a/ci/official/containers/linux_arm64/Dockerfile +++ b/ci/official/containers/linux_arm64/Dockerfile @@ -62,6 +62,9 @@ COPY devel.usertools /usertools COPY devel.bashrc /root/.bashrc COPY ld.so.conf /dt10/etc/ +# Make sure clang is on the path +RUN ln -s /usr/lib/llvm-18/bin/clang /usr/bin/clang + # Setup JAX Python environment. FROM devel as jax RUN /setup.packages.sh /cuda.packages.txt diff --git a/ci/official/containers/linux_arm64/build.sh b/ci/official/containers/linux_arm64/build.sh index 611d5f48ac0084..ffead7f1c31e74 100755 --- a/ci/official/containers/linux_arm64/build.sh +++ b/ci/official/containers/linux_arm64/build.sh @@ -16,8 +16,8 @@ # Builds the following Docker images for Linux ARM64. See the accompanying # Dockerfile for more details: -# - gcr.io/tensorflow-sigs/build-arm64:jax-latest-multi-python -# - gcr.io/tensorflow-sigs/build-arm64:tf-latest-multi-python +# - us-central1-docker.pkg.dev/tensorflow-sigs/build-arm64:jax-latest-multi-python +# - us-central1-docker.pkg.dev/tensorflow-sigs/build-arm64:tf-latest-multi-python set -exo pipefail @@ -40,16 +40,14 @@ else fi fi -# TODO(b/341050361): When these steps are verified, removed the GCR image code. AR_IMAGE_PATH="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64" # Build for both JAX and TF usage. We do these in one place because they share # almost all of the same cache layers export DOCKER_BUILDKIT=1 for target in jax tf; do - IMAGE="gcr.io/tensorflow-sigs/build-arm64:$target-$TAG" AR_IMAGE="$AR_IMAGE_PATH:$target-$TAG" - docker pull "$IMAGE" || true + docker pull "$AR_IMAGE" || true # Due to some flakiness of resources pulled in the build, allow the docker # command to reattempt build a few times in the case of failure (b/302558736) set +e @@ -58,8 +56,8 @@ for target in jax tf; do docker build \ --build-arg REQUIREMENTS_FILE=jax.requirements.txt \ --target=$target \ - --cache-from "$IMAGE" \ - -t "$IMAGE" -t "$AR_IMAGE" . && break + --cache-from "$AR_IMAGE" \ + -t "$AR_IMAGE" . && break done final=$? if [ $final -ne 0 ]; then @@ -68,8 +66,6 @@ for target in jax tf; do set -e if [[ -n "$KOKORO_BUILD_ID" ]]; then - gcloud auth configure-docker - docker push "$IMAGE" gcloud auth configure-docker us-central1-docker.pkg.dev docker push "$AR_IMAGE" fi diff --git a/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats b/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats index 85cbc7b7058148..cdfc81499af7f0 100644 --- a/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats +++ b/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats @@ -216,6 +216,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ "somepath(//tensorflow/tools/pip_package:build_pip_package, " \ "@local_config_cuda//cuda:cudart + "\ "@local_config_cuda//cuda:cudart + "\ @@ -237,6 +239,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ --define framework_shared_object=false \ "somepath(//tensorflow/tools/pip_package:build_pip_package, " \ "@local_config_cuda//cuda:cudart + "\ diff --git a/ci/official/envs/rbe b/ci/official/envs/rbe index 12cc600b0a76a9..35f817310b2f36 100644 --- a/ci/official/envs/rbe +++ b/ci/official/envs/rbe @@ -33,7 +33,17 @@ EOF fi TFCI_BAZEL_COMMON_ARGS="$TFCI_BAZEL_COMMON_ARGS --config rbe_$TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX" -# These flags share the user's gcloud credentials with the container, so that bazel -# inside the container can authenticate. Note: TF's CI does not have any credential -# stored here. -TFCI_DOCKER_ARGS="$TFCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud" +if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then + # Docker on Windows doesn't support the `host` networking mode, and so + # port-forwarding is required for the container to detect it's running on GCE. + export IP_ADDR=$(powershell -command "(Get-NetIPAddress -AddressFamily IPv4 -InterfaceAlias 'vEthernet (nat)').IPAddress") + netsh interface portproxy add v4tov4 listenaddress=$IP_ADDR listenport=80 connectaddress=169.254.169.254 connectport=80 + # A local firewall rule for the container is added in + # ci/official/utilities/setup_docker.sh. +else + # The volume mapping flag below shares the user's gcloud credentials, if any, + # with the container, in case the user has credentials stored there. + # This would allow Bazel to authenticate for RBE. + # Note: TF's CI does not have any credentials stored there. + TFCI_DOCKER_ARGS="$TFCI_DOCKER_ARGS -v $HOME/.config/gcloud:/root/.config/gcloud" +fi diff --git a/ci/official/envs/windows_x86 b/ci/official/envs/windows_x86 new file mode 100644 index 00000000000000..568a47f76530f8 --- /dev/null +++ b/ci/official/envs/windows_x86 @@ -0,0 +1,20 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +TFCI_DOCKER_ENABLE=1 +TFCI_DOCKER_PULL_ENABLE=1 +TFCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2019-rbe@sha256:1082ef4299a72e44a84388f192ecefc81ec9091c146f507bc36070c089c0edcc" +TFCI_BAZEL_COMMON_ARGS="--repo_env=TF_PYTHON_VERSION=$TFCI_PYTHON_VERSION" +TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=windows_x86_cpu +TFCI_OUTPUT_DIR=build_output diff --git a/ci/official/pycpp.sh b/ci/official/pycpp.sh index cf346007949c1e..f6f20900f0a277 100755 --- a/ci/official/pycpp.sh +++ b/ci/official/pycpp.sh @@ -15,12 +15,19 @@ # ============================================================================== source "${BASH_SOURCE%/*}/utilities/setup.sh" +if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then + PROFILE_JSON_PATH=$(replace_drive_letter_with_c "$TFCI_OUTPUT_DIR") + PROFILE_JSON_PATH="$PROFILE_JSON_PATH/profile.json.gz" +else + PROFILE_JSON_PATH="$TFCI_OUTPUT_DIR/profile.json.gz" +fi + if [[ $TFCI_PYCPP_SWAP_TO_BUILD_ENABLE == 1 ]]; then - tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" + tfrun bazel build $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" else - tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$TFCI_OUTPUT_DIR/profile.json.gz" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" + tfrun bazel test $TFCI_BAZEL_COMMON_ARGS --profile "$PROFILE_JSON_PATH" --config="${TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX}_pycpp_test" fi # Note: the profile can be viewed by visiting chrome://tracing in a Chrome browser. # See https://docs.bazel.build/versions/main/skylark/performance.html#performance-profiling -tfrun bazel analyze-profile "$TFCI_OUTPUT_DIR/profile.json.gz" +tfrun bazel analyze-profile "$PROFILE_JSON_PATH" diff --git a/ci/official/utilities/cleanup_summary.sh b/ci/official/utilities/cleanup_summary.sh index dbe2203fa130af..6b6fdfaa855106 100755 --- a/ci/official/utilities/cleanup_summary.sh +++ b/ci/official/utilities/cleanup_summary.sh @@ -23,8 +23,9 @@ IMPORTANT: For bazel invocations that uploaded to ResultStore (e.g. RBE), you can view more detailed results that are probably easier to read than this log. Try the links below: EOF - # Find any "Streaming build results to" line, then print the last word in it, - # and don't print duplicates + # Find any "Streaming build results to" lines, + # de-duplicate, + # and print the last word from each awk '/Streaming build results to/ {print $NF}' "$TFCI_OUTPUT_DIR/script.log" | uniq } @@ -32,14 +33,15 @@ EOF # Each failed target there will have its own representation, making failures # easier to find and read. function resultstore_extract { - local \ - XML_PATH="$TFCI_OUTPUT_DIR/Bazel_Test_and_Build_Results/sponge_log.xml" + local PYTHON_BIN XML_PATH + PYTHON_BIN=$(which python3 2>/dev/null || which python) + XML_PATH="$TFCI_OUTPUT_DIR/Bazel_Test_and_Build_Results/sponge_log.xml" - python3 \ + "$PYTHON_BIN" \ "$TFCI_GIT_DIR/ci/official/utilities/extract_resultstore_links.py" \ "$TFCI_OUTPUT_DIR/script.log" \ --print \ - --xml-out-path "$XML_PATH" || resultstore_extract_fallback + --xml-out-path "$XML_PATH" } if grep -q "Streaming build results to" "$TFCI_OUTPUT_DIR/script.log"; then diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index 691ec3a3a025ae..ede80f4372bc14 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -216,6 +216,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ "somepath(//tensorflow/tools/pip_package:wheel, " \ "@local_config_cuda//cuda:cudart + "\ "@local_config_cuda//cuda:cudart + "\ @@ -237,6 +239,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ --define framework_shared_object=false \ "somepath(//tensorflow/tools/pip_package:wheel, " \ "@local_config_cuda//cuda:cudart + "\ diff --git a/ci/official/utilities/convert_msys_paths_to_win_paths.py b/ci/official/utilities/convert_msys_paths_to_win_paths.py new file mode 100644 index 00000000000000..ed1dd3b0925246 --- /dev/null +++ b/ci/official/utilities/convert_msys_paths_to_win_paths.py @@ -0,0 +1,76 @@ +#!/usr/bin/python3 +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +"""Converts MSYS Linux-like paths stored in env variables to Windows paths. + +This is necessary on Windows, because some applications do not understand/handle +Linux-like paths MSYS uses, for example, Docker. +""" + +import argparse +import os + + +def should_convert(var_name: str, + blacklist: list[str] | None, + whitelist_prefix: list[str] | None): + """Check the variable name against white/black lists.""" + if blacklist and var_name in blacklist: + return False + if not whitelist_prefix: + return True + + for prefix in whitelist_prefix: + if var_name.startswith(prefix): + return True + return False + + +def main(parsed_args: argparse.Namespace): + converted_vars = {} + + for var, value in os.environ.items(): + if not value or not should_convert(var, + parsed_args.blacklist, + parsed_args.whitelist_prefix): + continue + + # In Python, MSYS, Linux-like paths are automatically read as Windows paths + # with forward slashes, e.g. 'C:/Program Files', instead of + # '/c/Program Files', thus becoming converted simply by virtue of having + # been read. + converted_vars[var] = value + + var_str = '\n'.join(f'{k}="{v}"' + for k, v in converted_vars.items()) + # The string can then be piped into `source`, to re-set the + # 'converted' variables. + print(var_str) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description=( + 'Convert MSYS paths in environment variables to Windows paths.')) + parser.add_argument('--blacklist', + nargs='*', + help='List of variables to ignore') + parser.add_argument('--whitelist-prefix', + nargs='*', + help='Prefix for variables to include') + args = parser.parse_args() + + main(args) diff --git a/ci/official/utilities/setup.sh b/ci/official/utilities/setup.sh index 2277b7551db587..55d0e079379a9b 100755 --- a/ci/official/utilities/setup.sh +++ b/ci/official/utilities/setup.sh @@ -118,6 +118,12 @@ exec > >(tee "$TFCI_OUTPUT_DIR/script.log") 2>&1 # functionality instead. tfrun() { "$@"; } +if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then + source ./ci/official/utilities/windows.sh + echo 'Converting MSYS Linux-like paths to Windows paths (for Docker, Python, etc.)' + source <(python ./ci/official/utilities/convert_msys_paths_to_win_paths.py --whitelist-prefix TFCI_) +fi + # Run all "tfrun" commands under Docker. See setup_docker.sh for details if [[ "$TFCI_DOCKER_ENABLE" == 1 ]]; then source ./ci/official/utilities/setup_docker.sh diff --git a/ci/official/utilities/setup_docker.sh b/ci/official/utilities/setup_docker.sh index 91618c75f3ba51..61db7c2e124d0a 100755 --- a/ci/official/utilities/setup_docker.sh +++ b/ci/official/utilities/setup_docker.sh @@ -37,10 +37,30 @@ if ! docker container inspect tf >/dev/null 2>&1 ; then # Pass all existing TFCI_ variables into the Docker container env_file=$(mktemp) env | grep ^TFCI_ > "$env_file" - docker run $TFCI_DOCKER_ARGS --name tf -w "$TFCI_GIT_DIR" -itd --rm \ - -v "$TFCI_GIT_DIR:$TFCI_GIT_DIR" \ + + WORKING_DIR="$TFCI_GIT_DIR" + if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then + env_file=$(cygpath -m $env_file) + # Host dirs can only be mapped to an existing drive inside the container, so + # T:\ is replaced with C:\. + _TFCI_OUTPUT_DIR_WIN=$(replace_drive_letter_with_c "$TFCI_OUTPUT_DIR") + sed -iE 's|^TFCI_OUTPUT_DIR=.*|TFCI_OUTPUT_DIR='"$_TFCI_OUTPUT_DIR_WIN"'|g' $env_file + WORKING_DIR=$(replace_drive_letter_with_c "$TFCI_GIT_DIR") + echo "GCE_METADATA_HOST=$IP_ADDR" > $env_file + fi + + docker run $TFCI_DOCKER_ARGS --name tf -w "$WORKING_DIR" -itd --rm \ + -v "$TFCI_GIT_DIR:$WORKING_DIR" \ --env-file "$env_file" \ "$TFCI_DOCKER_IMAGE" \ bash + + if [[ `uname -s | grep -P '^MSYS_NT'` ]]; then + # Allow requests from the container. + # Additional setup is contained in ci/official/envs/rbe. + CONTAINER_IP_ADDR=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' tf) + netsh advfirewall firewall add rule name="Allow Metadata Proxy" dir=in action=allow protocol=TCP localport=80 remoteip="$CONTAINER_IP_ADDR" + fi + fi tfrun() { docker exec tf "$@"; } diff --git a/ci/official/utilities/windows.sh b/ci/official/utilities/windows.sh new file mode 100644 index 00000000000000..1ab2d89ef327f6 --- /dev/null +++ b/ci/official/utilities/windows.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Windows-specific utilities. +# + +# Docker on Windows has difficulty using volumes other than C:\, when it comes +# to setting up up volume mappings. +# Thus, the drive letter is replaced with C:\, in case it's +# something else (ex. T:), which is frequently the case inside Kokoro jobs. +function replace_drive_letter_with_c () { + sed -E "s|^[a-zA-Z]:|C:|g" <<< $1 +} diff --git a/configure.py b/configure.py index 592f5c0d2117e1..50ed76e9f23d14 100644 --- a/configure.py +++ b/configure.py @@ -16,7 +16,6 @@ import argparse import errno -import glob import json import os import platform @@ -31,9 +30,6 @@ from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_DEFAULT_CUDA_VERSION = '11' -_DEFAULT_CUDNN_VERSION = '2' -_DEFAULT_TENSORRT_VERSION = '6' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0' _SUPPORTED_ANDROID_NDK_VERSIONS = [19, 20, 21, 25] @@ -128,6 +124,12 @@ def write_action_env_to_bazelrc(var_name, var): write_to_bazelrc('build --action_env {}="{}"'.format(var_name, str(var))) +def write_repo_env_to_bazelrc(config_name, var_name, var): + write_to_bazelrc( + 'build:{} --repo_env {}="{}"'.format(config_name, var_name, str(var)) + ) + + def run_shell(cmd, allow_non_zero=False, stderr=None): if stderr is None: stderr = sys.stdout @@ -239,7 +241,7 @@ def setup_python(environ_cp): write_to_bazelrc('build --python_path=\"{}"'.format(python_bin_path)) environ_cp['PYTHON_BIN_PATH'] = python_bin_path - # If choosen python_lib_path is from a path specified in the PYTHONPATH + # If chosen python_lib_path is from a path specified in the PYTHONPATH # variable, need to tell bazel to include PYTHONPATH if environ_cp.get('PYTHONPATH'): python_paths = environ_cp.get('PYTHONPATH').split(':') @@ -778,11 +780,6 @@ def get_ndk_api_level(environ_cp, android_ndk_home_path): def set_gcc_host_compiler_path(environ_cp): """Set GCC_HOST_COMPILER_PATH.""" default_gcc_host_compiler_path = which('gcc') or '' - cuda_bin_symlink = '%s/bin/gcc' % environ_cp.get('CUDA_TOOLKIT_PATH') - - if os.path.islink(cuda_bin_symlink): - # os.readlink is only available in linux - default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) gcc_host_compiler_path = prompt_loop_or_load_from_env( environ_cp, @@ -947,108 +944,42 @@ def disable_clang_offsetof_extension(clang_version): write_to_bazelrc('build --copt=-Wno-gnu-offsetof-extensions') -def set_tf_cuda_paths(environ_cp): - """Set TF_CUDA_PATHS.""" - ask_cuda_paths = ( - 'Please specify the comma-separated list of base paths to look for CUDA ' - 'libraries and headers. [Leave empty to use the default]: ') - tf_cuda_paths = get_from_env_or_user_or_default(environ_cp, 'TF_CUDA_PATHS', - ask_cuda_paths, '') - if tf_cuda_paths: - environ_cp['TF_CUDA_PATHS'] = tf_cuda_paths - - -def set_tf_cuda_version(environ_cp): - """Set TF_CUDA_VERSION.""" +def set_hermetic_cuda_version(environ_cp): + """Set HERMETIC_CUDA_VERSION.""" ask_cuda_version = ( - 'Please specify the CUDA SDK version you want to use. ' - '[Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - tf_cuda_version = get_from_env_or_user_or_default(environ_cp, - 'TF_CUDA_VERSION', - ask_cuda_version, - _DEFAULT_CUDA_VERSION) - environ_cp['TF_CUDA_VERSION'] = tf_cuda_version + 'Please specify the hermetic CUDA version you want to use ' + 'or leave empty to use the default version. ' + ) + hermetic_cuda_version = get_from_env_or_user_or_default( + environ_cp, 'HERMETIC_CUDA_VERSION', ask_cuda_version, None + ) + if hermetic_cuda_version: + environ_cp['HERMETIC_CUDA_VERSION'] = hermetic_cuda_version + write_repo_env_to_bazelrc( + 'cuda', 'HERMETIC_CUDA_VERSION', hermetic_cuda_version + ) -def set_tf_cudnn_version(environ_cp): - """Set TF_CUDNN_VERSION.""" +def set_hermetic_cudnn_version(environ_cp): + """Set HERMETIC_CUDNN_VERSION.""" ask_cudnn_version = ( - 'Please specify the cuDNN version you want to use. ' - '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION - tf_cudnn_version = get_from_env_or_user_or_default(environ_cp, - 'TF_CUDNN_VERSION', - ask_cudnn_version, - _DEFAULT_CUDNN_VERSION) - environ_cp['TF_CUDNN_VERSION'] = tf_cudnn_version - - -def set_tf_tensorrt_version(environ_cp): - """Set TF_TENSORRT_VERSION.""" - if not (is_linux() or is_windows()): - raise ValueError('Currently TensorRT is only supported on Linux platform.') - - if not int(environ_cp.get('TF_NEED_TENSORRT', False)): - return - - ask_tensorrt_version = ( - 'Please specify the TensorRT version you want to use. ' - '[Leave empty to default to TensorRT %s]: ') % _DEFAULT_TENSORRT_VERSION - tf_tensorrt_version = get_from_env_or_user_or_default( - environ_cp, 'TF_TENSORRT_VERSION', ask_tensorrt_version, - _DEFAULT_TENSORRT_VERSION) - environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version - - -def set_tf_nccl_version(environ_cp): - """Set TF_NCCL_VERSION.""" - if not is_linux(): - raise ValueError('Currently NCCL is only supported on Linux platform.') - - if 'TF_NCCL_VERSION' in environ_cp: - return - - ask_nccl_version = ( - 'Please specify the locally installed NCCL version you want to use. ' - '[Leave empty to use http://github.com/nvidia/nccl]: ') - tf_nccl_version = get_from_env_or_user_or_default(environ_cp, - 'TF_NCCL_VERSION', - ask_nccl_version, '') - environ_cp['TF_NCCL_VERSION'] = tf_nccl_version - - -def get_native_cuda_compute_capabilities(environ_cp): - """Get native cuda compute capabilities. - - Args: - environ_cp: copy of the os.environ. - - Returns: - string of native cuda compute capabilities, separated by comma. - """ - device_query_bin = os.path.join( - environ_cp.get('CUDA_TOOLKIT_PATH'), 'extras/demo_suite/deviceQuery') - if os.path.isfile(device_query_bin) and os.access(device_query_bin, os.X_OK): - try: - output = run_shell(device_query_bin).split('\n') - pattern = re.compile('\d*\\.\d*') - output = [pattern.search(x) for x in output if 'Capability' in x] - output = ','.join(x.group() for x in output if x is not None) - except subprocess.CalledProcessError: - output = '' - else: - output = '' - return output + 'Please specify the hermetic cuDNN version you want to use ' + 'or leave empty to use the default version. ' + ) + hermetic_cudnn_version = get_from_env_or_user_or_default( + environ_cp, 'HERMETIC_CUDNN_VERSION', ask_cudnn_version, None + ) + if hermetic_cudnn_version: + environ_cp['HERMETIC_CUDNN_VERSION'] = hermetic_cudnn_version + write_repo_env_to_bazelrc( + 'cuda', 'HERMETIC_CUDNN_VERSION', hermetic_cudnn_version + ) -def set_tf_cuda_compute_capabilities(environ_cp): - """Set TF_CUDA_COMPUTE_CAPABILITIES.""" +def set_hermetic_cuda_compute_capabilities(environ_cp): + """Set HERMETIC_CUDA_COMPUTE_CAPABILITIES.""" while True: - native_cuda_compute_capabilities = get_native_cuda_compute_capabilities( - environ_cp) - if not native_cuda_compute_capabilities: - default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES - else: - default_cuda_compute_capabilities = native_cuda_compute_capabilities + default_cuda_compute_capabilities = _DEFAULT_CUDA_COMPUTE_CAPABILITIES ask_cuda_compute_capabilities = ( 'Please specify a list of comma-separated CUDA compute capabilities ' @@ -1060,15 +991,20 @@ def set_tf_cuda_compute_capabilities(environ_cp): 'significantly increases your build time and binary size, and that ' 'TensorFlow only supports compute capabilities >= 3.5 [Default is: ' '%s]: ' % default_cuda_compute_capabilities) - tf_cuda_compute_capabilities = get_from_env_or_user_or_default( - environ_cp, 'TF_CUDA_COMPUTE_CAPABILITIES', - ask_cuda_compute_capabilities, default_cuda_compute_capabilities) + hermetic_cuda_compute_capabilities = get_from_env_or_user_or_default( + environ_cp, + 'HERMETIC_CUDA_COMPUTE_CAPABILITIES', + ask_cuda_compute_capabilities, + default_cuda_compute_capabilities, + ) # Check whether all capabilities from the input is valid all_valid = True # Remove all whitespace characters before splitting the string # that users may insert by accident, as this will result in error - tf_cuda_compute_capabilities = ''.join(tf_cuda_compute_capabilities.split()) - for compute_capability in tf_cuda_compute_capabilities.split(','): + hermetic_cuda_compute_capabilities = ''.join( + hermetic_cuda_compute_capabilities.split() + ) + for compute_capability in hermetic_cuda_compute_capabilities.split(','): m = re.match('[0-9]+.[0-9]+', compute_capability) if not m: # We now support sm_35,sm_50,sm_60,compute_70. @@ -1103,15 +1039,32 @@ def set_tf_cuda_compute_capabilities(environ_cp): break # Reset and Retry - environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = '' + environ_cp['HERMETIC_CUDA_COMPUTE_CAPABILITIES'] = '' - # Set TF_CUDA_COMPUTE_CAPABILITIES - environ_cp['TF_CUDA_COMPUTE_CAPABILITIES'] = tf_cuda_compute_capabilities - write_action_env_to_bazelrc( - 'TF_CUDA_COMPUTE_CAPABILITIES', tf_cuda_compute_capabilities + # Set HERMETIC_CUDA_COMPUTE_CAPABILITIES + environ_cp['HERMETIC_CUDA_COMPUTE_CAPABILITIES'] = ( + hermetic_cuda_compute_capabilities + ) + write_repo_env_to_bazelrc( + 'cuda', + 'HERMETIC_CUDA_COMPUTE_CAPABILITIES', + hermetic_cuda_compute_capabilities, ) +def set_cuda_local_path(environ_cp, dist_name, env_var): + ask_path = ( + 'Please specify the local {} path you want to use ' + 'or leave empty to use the default version. ' + ).format(dist_name) + local_path = get_from_env_or_user_or_default( + environ_cp, env_var, ask_path, None + ) + if local_path: + environ_cp[env_var] = local_path + write_repo_env_to_bazelrc('cuda', env_var, local_path) + + def set_other_cuda_vars(environ_cp): """Set other CUDA related variables.""" # If CUDA is enabled, always use GPU during build and test. @@ -1227,73 +1180,6 @@ def configure_ios(environ_cp): symlink_force(filepath, new_filepath) -def validate_cuda_config(environ_cp): - """Run find_cuda_config.py and return cuda_toolkit_path, or None.""" - - def maybe_encode_env(env): - """Encodes unicode in env to str on Windows python 2.x.""" - if not is_windows() or sys.version_info[0] != 2: - return env - for k, v in env.items(): - if isinstance(k, unicode): - k = k.encode('ascii') - if isinstance(v, unicode): - v = v.encode('ascii') - env[k] = v - return env - - cuda_libraries = ['cuda', 'cudnn'] - if is_linux(): - if int(environ_cp.get('TF_NEED_TENSORRT', False)): - cuda_libraries.append('tensorrt') - if environ_cp.get('TF_NCCL_VERSION', None): - cuda_libraries.append('nccl') - if is_windows(): - if int(environ_cp.get('TF_NEED_TENSORRT', False)): - cuda_libraries.append('tensorrt') - print('WARNING: TensorRT support on Windows is experimental\n') - - paths = glob.glob('**/third_party/gpus/find_cuda_config.py', recursive=True) - if not paths: - raise FileNotFoundError( - "Can't find 'find_cuda_config.py' script inside working directory") - proc = subprocess.Popen( - [environ_cp['PYTHON_BIN_PATH'], paths[0]] + cuda_libraries, - stdout=subprocess.PIPE, - env=maybe_encode_env(environ_cp)) - - if proc.wait(): - # Errors from find_cuda_config.py were sent to stderr. - print('Asking for detailed CUDA configuration...\n') - return False - - config = dict( - tuple(line.decode('ascii').rstrip().split(': ')) for line in proc.stdout) - - print('Found CUDA %s in:' % config['cuda_version']) - print(' %s' % config['cuda_library_dir']) - print(' %s' % config['cuda_include_dir']) - - print('Found cuDNN %s in:' % config['cudnn_version']) - print(' %s' % config['cudnn_library_dir']) - print(' %s' % config['cudnn_include_dir']) - - if 'tensorrt_version' in config: - print('Found TensorRT %s in:' % config['tensorrt_version']) - print(' %s' % config['tensorrt_library_dir']) - print(' %s' % config['tensorrt_include_dir']) - - if config.get('nccl_version', None): - print('Found NCCL %s in:' % config['nccl_version']) - print(' %s' % config['nccl_library_dir']) - print(' %s' % config['nccl_include_dir']) - - print('\n') - - environ_cp['CUDA_TOOLKIT_PATH'] = config['cuda_toolkit_path'] - return True - - def get_gcc_compiler(environ_cp): gcc_env = environ_cp.get('CXX') or environ_cp.get('CC') or which('gcc') if gcc_env is not None: @@ -1344,9 +1230,6 @@ def main(): environ_cp['TF_DOWNLOAD_CLANG'] = '0' environ_cp['TF_NEED_MPI'] = '0' - if is_macos(): - environ_cp['TF_NEED_TENSORRT'] = '0' - if is_ppc64le(): # Enable MMA Dynamic Dispatch support if 'gcc' and if linker >= 2.35 gcc_env = get_gcc_compiler(environ_cp) @@ -1395,62 +1278,14 @@ def main(): else: environ_cp['TF_NEED_CUDA'] = str( int(get_var(environ_cp, 'TF_NEED_CUDA', 'CUDA', False))) - if (environ_cp.get('TF_NEED_CUDA') == '1' and - 'TF_CUDA_CONFIG_REPO' not in environ_cp): - - set_action_env_var( - environ_cp, - 'TF_NEED_TENSORRT', - 'TensorRT', - False, - bazel_config_name='tensorrt') - - environ_save = dict(environ_cp) - for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): - - if validate_cuda_config(environ_cp): - cuda_env_names = [ - 'TF_CUDA_VERSION', - 'TF_CUBLAS_VERSION', - 'TF_CUDNN_VERSION', - 'TF_TENSORRT_VERSION', - 'TF_NCCL_VERSION', - 'TF_CUDA_PATHS', - # Items below are for backwards compatibility when not using - # TF_CUDA_PATHS. - 'CUDA_TOOLKIT_PATH', - 'CUDNN_INSTALL_PATH', - 'NCCL_INSTALL_PATH', - 'NCCL_HDR_PATH', - 'TENSORRT_INSTALL_PATH' - ] - # Note: set_action_env_var above already writes to bazelrc. - for name in cuda_env_names: - if name in environ_cp: - write_action_env_to_bazelrc(name, environ_cp[name]) - break - - # Restore settings changed below if CUDA config could not be validated. - environ_cp = dict(environ_save) - - set_tf_cuda_version(environ_cp) - set_tf_cudnn_version(environ_cp) - if is_windows(): - set_tf_tensorrt_version(environ_cp) - if is_linux(): - set_tf_tensorrt_version(environ_cp) - set_tf_nccl_version(environ_cp) - - set_tf_cuda_paths(environ_cp) - - else: - raise UserInputError( - 'Invalid CUDA setting were provided %d ' - 'times in a row. Assuming to be a scripting mistake.' - % _DEFAULT_PROMPT_ASK_ATTEMPTS - ) + if environ_cp.get('TF_NEED_CUDA') == '1': + set_hermetic_cuda_version(environ_cp) + set_hermetic_cudnn_version(environ_cp) + set_hermetic_cuda_compute_capabilities(environ_cp) + set_cuda_local_path(environ_cp, 'CUDA', 'LOCAL_CUDA_PATH') + set_cuda_local_path(environ_cp, 'CUDNN', 'LOCAL_CUDNN_PATH') + set_cuda_local_path(environ_cp, 'NCCL', 'LOCAL_NCCL_PATH') - set_tf_cuda_compute_capabilities(environ_cp) if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( 'LD_LIBRARY_PATH') != '1': write_action_env_to_bazelrc('LD_LIBRARY_PATH', diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 3556df1d1f99e3..5ebf9b1fa20fed 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -1053,7 +1053,7 @@ package_group( "//learning/serving/experimental/remote_predict/...", "//perftools/accelerators/xprof/convert/...", "//perftools/accelerators/xprof/integration_tests/...", - "//smartass/brain/configure/...", + "//smartass/brain/...", "//tensorflow/...", "//tensorflow_decision_forests/...", "//tensorflow_federated/...", @@ -1350,7 +1350,7 @@ tf_cc_shared_library( "//tensorflow/core:tensorflow", "//tensorflow/core/data:standalone", # Exports for pywrap_tensorflow_internal. Many of these are transitive - # depedencies of the above, but must be explicitly listed for + # dependencies of the above, but must be explicitly listed for # cc_shared_library to work. "//tensorflow/c/eager:c_api_experimental", "//tensorflow/c/eager:c_api_internal", @@ -1445,7 +1445,7 @@ tf_cc_shared_library( "//tensorflow/lite/kernels/internal:compatibility", "//tensorflow/lite/kernels:builtin_ops", "//tensorflow/lite/kernels:reference_ops", - "//tensorflow/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/lite/toco/logging:conversion_log_util", "//tensorflow/lite/toco/logging:toco_conversion_log_proto_cc", "//tensorflow/lite/toco:model_flags_proto_cc", diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc index 86b201fb2fb431..1123ccbf33284f 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/c/tf_buffer.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/function.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace parallel_device { diff --git a/tensorflow/c/eager/parallel_device/parallel_device_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_test.cc index ec2ce95665c94c..88ef5c14b14564 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_test.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_internal.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" // NOTE(allenl): These tests currently go through TFE_Execute and so are // integration testing rather than purely testing the parallel device. They diff --git a/tensorflow/c/experimental/grappler/grappler_test.cc b/tensorflow/c/experimental/grappler/grappler_test.cc index 57f1a65ad0c6fe..357432e2c58018 100644 --- a/tensorflow/c/experimental/grappler/grappler_test.cc +++ b/tensorflow/c/experimental/grappler/grappler_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/c/tf_buffer.h" #include "tensorflow/c/tf_buffer_internal.h" #include "tensorflow/c/tf_status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { diff --git a/tensorflow/c/experimental/next_pluggable_device/BUILD b/tensorflow/c/experimental/next_pluggable_device/BUILD index 56586f757f369b..45c55c315f5350 100644 --- a/tensorflow/c/experimental/next_pluggable_device/BUILD +++ b/tensorflow/c/experimental/next_pluggable_device/BUILD @@ -87,7 +87,6 @@ tf_cc_test( "//tensorflow/core/tfrt/common:pjrt_util", "@com_google_absl//absl/log:check", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", @@ -98,5 +97,6 @@ tf_cc_test( "@local_xla//xla/pjrt/c:pjrt_c_api_hdrs", "@local_xla//xla/pjrt/c:pjrt_c_api_wrapper_impl", "@local_xla//xla/pjrt/cpu:cpu_client", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc index 7f45fd91a1baea..1952364d882776 100644 --- a/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc +++ b/tensorflow/c/experimental/next_pluggable_device/tensor_pjrt_buffer_util_test.cc @@ -30,10 +30,10 @@ limitations under the License. #include "xla/pjrt/pjrt_c_api_client.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/common/async_value_tensor.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/casts.h" #include "tsl/platform/status_matchers.h" #include "tsl/protobuf/error_codes.pb.h" diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD index 7284261e862693..76f1db67bf2f09 100644 --- a/tensorflow/c/experimental/ops/BUILD +++ b/tensorflow/c/experimental/ops/BUILD @@ -22,11 +22,11 @@ cc_library( "//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", - "//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/c/eager:tracing_utils", - "//tensorflow/core:framework", - "//tensorflow/core/platform:errors", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", ], ) diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc index db9464db41a5b6..23deef1d637f2d 100644 --- a/tensorflow/c/experimental/ops/array_ops.cc +++ b/tensorflow/c/experimental/ops/array_ops.cc @@ -17,11 +17,14 @@ limitations under the License. #include "tensorflow/c/experimental/ops/array_ops.h" +#include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/abstract_operation.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/tracing_utils.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tsl/platform/errors.h" using tensorflow::tracing::MaybeSetOpName; diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h index f4d170ac98f402..466c36f1dde8ae 100644 --- a/tensorflow/c/experimental/ops/array_ops.h +++ b/tensorflow/c/experimental/ops/array_ops.h @@ -18,8 +18,11 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_ #define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_ +#include "absl/types/span.h" #include "tensorflow/c/eager/abstract_context.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { namespace ops { diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc index 1fc16e093c011d..6392e30ce1fa07 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/cpp_config.cc @@ -24,7 +24,7 @@ namespace cpp { CppConfig::CppConfig(const string &category, const string &name_space) : category(category), - unit(str_util::Lowercase(category)), + unit(absl::AsciiStrToLower(category)), namespaces(absl::StrSplit(name_space, "::")) {} } // namespace cpp diff --git a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc index 41d1dea64b3689..ef54bca31c107b 100644 --- a/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc +++ b/tensorflow/c/experimental/ops/gen/cpp/renderers/renderer.cc @@ -47,7 +47,7 @@ Renderer& Renderer::CodeLines(const string& text) { } Renderer& Renderer::Statement(const string& text) { - if (str_util::EndsWith(text, ";")) { + if (absl::EndsWith(text, ";")) { LOG(WARNING) << "Superfluous terminating ';' in '" << text << "'"; context_.code.AddLineWithIndent(text); } else { diff --git a/tensorflow/c/experimental/ops/gen/model/BUILD b/tensorflow/c/experimental/ops/gen/model/BUILD index 918acaabb6b8cb..89e51ec57df46e 100644 --- a/tensorflow/c/experimental/ops/gen/model/BUILD +++ b/tensorflow/c/experimental/ops/gen/model/BUILD @@ -9,13 +9,10 @@ cc_library( srcs = glob(["*.cc"]), hdrs = glob(["*.h"]), deps = [ - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:str_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/c/experimental/ops/gen/model/arg_spec.cc b/tensorflow/c/experimental/ops/gen/model/arg_spec.cc index 2a9dd4882d92de..43e3b3f0b8bfa9 100644 --- a/tensorflow/c/experimental/ops/gen/model/arg_spec.cc +++ b/tensorflow/c/experimental/ops/gen/model/arg_spec.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/model/arg_spec.h" +#include "tensorflow/c/experimental/ops/gen/model/arg_type.h" +#include "tensorflow/core/framework/op_def.pb.h" + namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/model/arg_type.cc b/tensorflow/c/experimental/ops/gen/model/arg_type.cc index afc05adc16788f..9286e2dd6f09cd 100644 --- a/tensorflow/c/experimental/ops/gen/model/arg_type.cc +++ b/tensorflow/c/experimental/ops/gen/model/arg_type.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/model/arg_type.h" +#include "tensorflow/core/framework/op_def.pb.h" + namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/ops/gen/model/attr_spec.cc b/tensorflow/c/experimental/ops/gen/model/attr_spec.cc index 3aec7acfe9791a..ae27a352694d98 100644 --- a/tensorflow/c/experimental/ops/gen/model/attr_spec.cc +++ b/tensorflow/c/experimental/ops/gen/model/attr_spec.cc @@ -14,7 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/experimental/ops/gen/model/attr_spec.h" -#include "tensorflow/core/lib/strings/str_util.h" +#include "absl/strings/match.h" +#include "tensorflow/core/framework/op_def.pb.h" namespace tensorflow { namespace generator { @@ -28,7 +29,7 @@ AttrSpec::AttrSpec(const OpDef::AttrDef& attr_def) { description_ = attr_def.description(); full_type_ = attr_def.type(); default_value_ = attr_def.default_value(); - if (str_util::StartsWith(full_type_, "list(")) { + if (absl::StartsWith(full_type_, "list(")) { is_list_ = true; // strip surrounding "list(%s)" base_type_ = full_type_.substr(5, full_type_.length() - 6); diff --git a/tensorflow/c/experimental/ops/gen/model/op_spec.cc b/tensorflow/c/experimental/ops/gen/model/op_spec.cc index d590e2dfddc80e..1adc0c45d40291 100644 --- a/tensorflow/c/experimental/ops/gen/model/op_spec.cc +++ b/tensorflow/c/experimental/ops/gen/model/op_spec.cc @@ -17,6 +17,11 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "tensorflow/c/experimental/ops/gen/model/arg_spec.h" +#include "tensorflow/c/experimental/ops/gen/model/attr_spec.h" +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace generator { diff --git a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc index 58c7a22fee787c..630638b6c6cc9f 100644 --- a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc +++ b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler.cc @@ -175,7 +175,7 @@ Status InitPluginProfiler(TFInitProfilerFn init_fn) { return factory.CreatePluggableProfiler(options); }; - tensorflow::profiler::RegisterProfilerFactory(std::move(create_func)); + tsl::profiler::RegisterProfilerFactory(std::move(create_func)); return OkStatus(); } diff --git a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc index 3a6de512637928..463f64c10f4b8f 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/restore_ops_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/test_utils.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/cc/saved_model/constants.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/types.pb.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc index d06608fcebe231..18d7498186cdb9 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/test_utils.h" #include "tensorflow/c/tensor_interface.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc index fe403d8c0bf530..36e4b8e5e66b67 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor.cc +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -22,8 +22,10 @@ limitations under the License. #include "tensorflow/c/experimental/stream_executor/stream_executor.h" #include +#include #include #include +#include #include "absl/functional/any_invocable.h" #include "absl/status/status.h" @@ -303,15 +305,6 @@ class CStreamExecutor : public StreamExecutorCommon { size, c_status.get()); return StatusFromTF_Status(c_status.get()); } - absl::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, - uint64 size) override { - OwnedTFStatus c_status(TF_NewStatus()); - SP_Stream stream_handle = static_cast(stream)->Handle(); - SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location); - stream_executor_->memset(&device_, stream_handle, &device_mem, pattern, - size, c_status.get()); - return StatusFromTF_Status(c_status.get()); - } void DeallocateStream(Stream* stream) override { static_cast(stream)->Destroy(); } @@ -405,8 +398,7 @@ class CStreamExecutor : public StreamExecutorCommon { } absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) override { + std::optional> priority) override { auto stream = std::make_unique(&device_, stream_executor_, this); TF_RETURN_IF_ERROR(stream->Create()); return std::move(stream); @@ -440,7 +432,6 @@ CPlatform::CPlatform(SP_Platform platform, name_(platform.name) {} CPlatform::~CPlatform() { - executor_cache_.DestroyAllExecutors(); platform_fns_.destroy_device_fns(&platform_, &device_fns_); platform_fns_.destroy_stream_executor(&platform_, &stream_executor_); platform_fns_.destroy_timer_fns(&platform_, &timer_fns_); @@ -457,24 +448,21 @@ CPlatform::DescriptionForDevice(int ordinal) const { builder.set_name(name_); return builder.Build(); } -absl::StatusOr CPlatform::ExecutorForDevice(int ordinal) { - stream_executor::StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); +absl::StatusOr CPlatform::FindExisting(int ordinal) { + return executor_cache_.Get(ordinal); } -absl::StatusOr CPlatform::GetExecutor( - const StreamExecutorConfig& config) { +absl::StatusOr CPlatform::ExecutorForDevice(int ordinal) { return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> CPlatform::GetUncachedExecutor( - const StreamExecutorConfig& config) { + int ordinal) { // Fill device creation params SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE}; SP_Device device{SP_DEVICE_STRUCT_SIZE}; device_params.device = &device; device_params.ext = nullptr; - device_params.ordinal = config.ordinal; + device_params.ordinal = ordinal; OwnedTFStatus c_status(TF_NewStatus()); // Create Device diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h index e3e025cf6902d6..769f640d6968d2 100644 --- a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ #include +#include #include #include @@ -97,14 +98,15 @@ class CPlatform : public Platform { absl::StatusOr> DescriptionForDevice( int ordinal) const override; absl::StatusOr ExecutorForDevice(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; - absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config) override; - - void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); } + absl::StatusOr FindExisting(int ordinal) override; private: + // Returns a device constructed with the ordinal without + // looking in or storing to the Platform's executor cache. + // Ownership IS transferred to the caller. + absl::StatusOr> GetUncachedExecutor( + int ordinal); + SP_Platform platform_; void (*destroy_platform_)(SP_Platform*); SP_PlatformFns platform_fns_; diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 0082c653335bab..20aded819a792c 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -794,10 +794,7 @@ int TF_GetDeviceId(TF_OpKernelContext* ctx) { #else const auto* device = reinterpret_cast( device_base->UnderlyingDevice()); - const absl::StatusOr id = tsl::GetDeviceIdFromDeviceParsedName( - device->parsed_name(), tensorflow::DeviceType(device->device_type())); - if (!id.ok()) return -1; - return *id; + return tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name()); #endif // defined(IS_MOBILE_PLATFORM) || defined(IS_SLIM_BUILD) } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index da27e61d380081..02fd6786698cd1 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -588,12 +588,12 @@ tf_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/cc/saved_model/bundle_v2_test.cc b/tensorflow/cc/saved_model/bundle_v2_test.cc index a0bbb82d704214..138028257b968a 100644 --- a/tensorflow/cc/saved_model/bundle_v2_test.cc +++ b/tensorflow/cc/saved_model/bundle_v2_test.cc @@ -28,10 +28,10 @@ limitations under the License. #include "json/reader.h" #include "json/value.h" #include "tensorflow/cc/saved_model/metrics.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/cc/saved_model/fingerprinting_utils_test.cc b/tensorflow/cc/saved_model/fingerprinting_utils_test.cc index 1f6b0e150850e4..3182afcf1803ee 100644 --- a/tensorflow/cc/saved_model/fingerprinting_utils_test.cc +++ b/tensorflow/cc/saved_model/fingerprinting_utils_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc index d74c0cd2d531f7..eb4ef40b8927f6 100644 --- a/tensorflow/cc/tools/freeze_saved_model_test.cc +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/state_ops.h" #include "tensorflow/cc/saved_model/loader.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" namespace tensorflow { diff --git a/tensorflow/cc/training/coordinator_test.cc b/tensorflow/cc/training/coordinator_test.cc index e3f9f78988f01d..211fd1e68011e4 100644 --- a/tensorflow/cc/training/coordinator_test.cc +++ b/tensorflow/cc/training/coordinator_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/cc/training/coordinator.h" #include "absl/status/status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/env.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { diff --git a/tensorflow/cc/training/queue_runner_test.cc b/tensorflow/cc/training/queue_runner_test.cc index 9a5f612b81ff81..f4de69b25a61b2 100644 --- a/tensorflow/cc/training/queue_runner_test.cc +++ b/tensorflow/cc/training/queue_runner_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/ops/random_ops.h" #include "tensorflow/cc/ops/state_ops.h" #include "tensorflow/cc/training/coordinator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" @@ -40,7 +41,6 @@ limitations under the License. #include "tensorflow/core/protobuf/queue_runner.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/protobuf/error_codes.pb.h" diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 4a6e752b984fd3..fc351b2cd829b5 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -167,17 +167,14 @@ cc_library( deps = [ ":tfcompile_lib", "//tensorflow/compiler/tf2xla:tf2xla_proto_cc", - "//tensorflow/compiler/tf2xla:tf2xla_util", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", - "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:status", "@local_xla//xla:debug_options_flags", - "@local_xla//xla/service:compiler", ], ) diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index e2ab2504319e80..b6b70a6f04d0f5 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -18,29 +18,18 @@ limitations under the License. #include #include -#include "absl/strings/match.h" -#include "absl/strings/str_join.h" +#include "absl/status/status.h" #include "absl/strings/string_view.h" -#include "tensorflow/compiler/aot/codegen.h" #include "tensorflow/compiler/aot/compile.h" #include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "xla/debug_options_flags.h" -#include "xla/service/compiler.h" -#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/graph/tensor_id.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/strings/numbers.h" -#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/command_line_flags.h" +#include "tsl/platform/status.h" namespace tensorflow { namespace tfcompile { diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index eddd237908fb16..6efe665f4c9f99 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -518,7 +518,6 @@ cc_library( ":internal", # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp. "//learning/brain/tfrt/tf_tpu:__pkg__", - "//learning/brain/tfrt/tpu_plugin:__pkg__", "//learning/brain/tfrt/tpu_common:__pkg__", "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], @@ -539,9 +538,6 @@ cc_library( ":internal", # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp. "//learning/brain/tfrt/tf_tpu:__pkg__", - "//learning/brain/tfrt/tpu_plugin:__pkg__", - "//learning/brain/tfrt/tpu_common:__pkg__", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", ], deps = [ ":variable_info", @@ -612,8 +608,6 @@ cc_library( # We reuse VariableInfo in TFRT's implementation of TpuExecuteOp. "//learning/brain/tfrt/tf_tpu:__pkg__", "//learning/brain/tfrt/tpu_plugin:__pkg__", - "//learning/brain/tfrt/tpu_common:__pkg__", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", "//tensorflow/core/tfrt/gpu/kernel:__pkg__", ], deps = [ @@ -678,7 +672,6 @@ tf_cc_test( "//tensorflow/core/tfrt/common:pjrt_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt:pjrt_client", @@ -686,6 +679,7 @@ tf_cc_test( "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", "@local_xla//xla/tests:literal_test_util", "@local_xla//xla/tsl/framework:device_id_utils", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -712,11 +706,11 @@ tf_cuda_only_cc_test( "//tensorflow/core/tfrt/common:pjrt_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/tests:literal_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -726,7 +720,6 @@ cc_library( hdrs = ["xla_compile_util.h"], visibility = [ ":internal", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", "//tensorflow/core/tfrt/gpu/kernel:__pkg__", ], deps = [ @@ -770,10 +763,7 @@ cc_library( name = "device_compiler", hdrs = ["device_compiler.h"], copts = tf_copts(), - visibility = [ - ":internal", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", - ], + visibility = [":internal"], deps = [ ":device_compilation_cache", ":device_compilation_cluster_signature", @@ -1118,7 +1108,6 @@ cc_library( ], visibility = [ ":internal", - "//tensorflow/core/tfrt/utils:__pkg__", "//third_party/cloud_tpu/inference_converter:__pkg__", "//waymo/onboard/ml/chauffeur_net:__pkg__", ], @@ -1564,10 +1553,7 @@ cc_library( name = "device_compiler_client", srcs = ["device_compiler_client.cc"], hdrs = ["device_compiler_client.h"], - visibility = [ - ":internal", - "//tensorflow/core/common_runtime/next_pluggable_device:__pkg__", - ], + visibility = [":internal"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core/util:determinism", @@ -1596,6 +1582,7 @@ cc_library( cc_library( name = "device_executable_persistor", + srcs = ["device_executable_persistor.cc"], hdrs = ["device_executable_persistor.h"], deps = [ ":xla_compilation_cache_proto_cc", @@ -1608,6 +1595,8 @@ cc_library( "//tensorflow/core/platform:status", "//tensorflow/core/platform:statusor", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:util", "@local_xla//xla/pjrt:pjrt_client", diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index ef46760f5065b4..5421637e80e5e0 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -734,7 +734,7 @@ static auto const ops_triggering_xla_compilation = "XlaVariadicSort", "XlaWhile"}; -static bool NodeCanTriggerXlaCompilation(const NodeDef& node) { +bool NodeCanTriggerXlaCompilation(const NodeDef& node) { return node.attr().find(kXlaClusterIdAttr) != node.attr().end() || HasBoolAttr(node, kXlaMustCompileAttr) || HasBoolAttr(node, kXlaCompileAttr) || diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 7c38cc92c541b7..18f6e5197b9cae 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -333,6 +333,9 @@ tensorflow::MemoryTypeVector GetOutputMemoryTypes( // Check whether graph can trigger XLA compilation. bool CanTriggerXlaCompilation(const GraphDef& graph); +// Returns true iff the node can trigger XLA compilation. +bool NodeCanTriggerXlaCompilation(const NodeDef& node); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_COMPILABILITY_CHECK_UTIL_H_ diff --git a/tensorflow/compiler/jit/device_context_test.cc b/tensorflow/compiler/jit/device_context_test.cc index be85ff99586f55..d02337d36d7a35 100644 --- a/tensorflow/compiler/jit/device_context_test.cc +++ b/tensorflow/compiler/jit/device_context_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/jit/device_executable_persistor.cc b/tensorflow/compiler/jit/device_executable_persistor.cc new file mode 100644 index 00000000000000..b673af75cbdcd9 --- /dev/null +++ b/tensorflow/compiler/jit/device_executable_persistor.cc @@ -0,0 +1,37 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/device_executable_persistor.h" + +#include + +#include "absl/strings/str_cat.h" + +namespace tensorflow { + +std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key) { + static constexpr char kXlaSerializedCacheKeySeparator[] = "__"; + return absl::StrCat( + key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator, + key.signature_fingerprint(), kXlaSerializedCacheKeySeparator, + key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator, + key.device_type(), + key.compiled_using_pjrt() + ? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt") + : "", + ".pb"); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_executable_persistor.h b/tensorflow/compiler/jit/device_executable_persistor.h index 78d208942ed770..0f546c0f196acc 100644 --- a/tensorflow/compiler/jit/device_executable_persistor.h +++ b/tensorflow/compiler/jit/device_executable_persistor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/log/log.h" +#include "absl/status/status.h" #include "tensorflow/compiler/jit/xla_compilation_cache.pb.h" #include "tensorflow/compiler/jit/xla_device_compiler_client.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" @@ -35,6 +36,9 @@ limitations under the License. namespace tensorflow { +// Returns the persisted compilation cache file name for the given key. +std::string XlaSerializedCacheKeyToFileName(const XlaSerializedCacheKey& key); + // Offers a way to persist and/or load compiled `ExecutableType`s along with the // corresponding HLO (`CompilationResult`) to/from `persistent_cache_directory` // (if one was provided during construction) on disk using `ClientType`. @@ -142,8 +146,6 @@ class DeviceExecutablePersistor { const xla::HloModuleProto& hlo_module, const XlaSerializedCacheEntry& entry) const; - std::string XlaSerializedCacheKeyToString( - const XlaSerializedCacheKey& key) const; std::string GetFilePath(const XlaSerializedCacheKey& key) const; const DeviceType device_type_; @@ -172,25 +174,10 @@ DeviceExecutablePersistor:: persistent_cache_directory_read_only_( config.persistent_cache_directory_read_only) {} -template -std::string DeviceExecutablePersistor:: - XlaSerializedCacheKeyToString(const XlaSerializedCacheKey& key) const { - static constexpr char kXlaSerializedCacheKeySeparator[] = "__"; - return absl::StrCat( - key.prefix(), key.prefix().empty() ? "" : kXlaSerializedCacheKeySeparator, - key.signature_fingerprint(), kXlaSerializedCacheKeySeparator, - key.cluster_fingerprint(), kXlaSerializedCacheKeySeparator, - key.device_type(), - key.compiled_using_pjrt() - ? absl::StrCat(kXlaSerializedCacheKeySeparator, "pjrt") - : ""); -} - template std::string DeviceExecutablePersistor::GetFilePath( const XlaSerializedCacheKey& key) const { - const std::string file_name = - absl::StrCat(XlaSerializedCacheKeyToString(key), ".pb"); + const std::string file_name = XlaSerializedCacheKeyToFileName(key); return io::JoinPath(persistent_cache_directory_, file_name); } @@ -299,9 +286,10 @@ DeviceExecutablePersistor::SaveSerializedEntry( // Write to temp location, then when that completes, atomically move into the // final location. - std::string temp_path = io::JoinPath( - persistent_cache_directory_, XlaSerializedCacheKeyToString(entry.key())); - if (!env->CreateUniqueFileName(&temp_path, ".pb.tmp")) { + std::string temp_path = + io::JoinPath(persistent_cache_directory_, + XlaSerializedCacheKeyToFileName(entry.key())); + if (!env->CreateUniqueFileName(&temp_path, ".tmp")) { return absl::UnavailableError(absl::StrCat( "Could not create a unique file inside ", persistent_cache_directory_)); } diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc index 1ae4f6d4cd9938..0ef7156ef9f593 100644 --- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc +++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc @@ -280,7 +280,6 @@ Status ConvertTensorFlowSliceToStaticShapedSlice( void ReplaceTensorFlowSliceWithStaticShapedSlice(Graph* g, Node* slice, Node* static_shaped_slice) { - absl::InlinedVector edges_to_remove; std::vector slice_out_edges; absl::c_copy(slice->out_edges(), std::back_inserter(slice_out_edges)); for (const Edge* e : slice_out_edges) { diff --git a/tensorflow/compiler/jit/pjrt_device_context.cc b/tensorflow/compiler/jit/pjrt_device_context.cc index 51b6e5770ce592..794f32d3fea9a1 100644 --- a/tensorflow/compiler/jit/pjrt_device_context.cc +++ b/tensorflow/compiler/jit/pjrt_device_context.cc @@ -52,12 +52,10 @@ absl::StatusOr> HostTensorToPjRtBuffer( cpu_tensor->shape(), cpu_tensor->dtype(), /*fast_mem=*/false, layout_preference)); const xla::Layout* device_layout = &(shape.layout()); - // The device id should match the local_hardware_id in + // The device id should match the local_device_id in // tensorflow/compiler/xla/pjrt/pjrt_client.h. - TF_ASSIGN_OR_RETURN( - const int pjrt_device_id, - tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name(), - DeviceType(device->device_type()))); + const int pjrt_device_id = + tsl::GetDeviceIdFromDeviceParsedName(device->parsed_name()); TF_ASSIGN_OR_RETURN(xla::PjRtDevice * pjrt_device, pjrt_client->LookupAddressableDevice( xla::PjRtLocalDeviceId(pjrt_device_id))); @@ -260,12 +258,10 @@ void PjRtDeviceToDeviceCopy(DeviceContext* send_dev_context, xla::PjRtBuffer* src_device_buffer = tensorflow::AsyncValueTensor::FromTensor(input)->GetBuffer().get(); - // The device id should match the local_hardware_id in + // The device id should match the local_device_id in // tensorflow/compiler/xla/pjrt/pjrt_client.h. const int pjrt_dst_device_id = - tsl::GetDeviceIdFromDeviceParsedName(dst->parsed_name(), - DeviceType(dst->device_type())) - .value(); + tsl::GetDeviceIdFromDeviceParsedName(dst->parsed_name()); xla::PjRtDevice* pjrt_dst_device = (*pjrt_dst_client) ->LookupAddressableDevice(xla::PjRtLocalDeviceId(pjrt_dst_device_id)) diff --git a/tensorflow/compiler/jit/shape_inference_test.cc b/tensorflow/compiler/jit/shape_inference_test.cc index 3f96101da79c00..eaabf18c79603c 100644 --- a/tensorflow/compiler/jit/shape_inference_test.cc +++ b/tensorflow/compiler/jit/shape_inference_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -36,7 +37,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc index 62da04c3e7510f..bec124f1866689 100644 --- a/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc +++ b/tensorflow/compiler/jit/xla_host_send_recv_device_context_test.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_testutil.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index ebeeaef7fe483e..27a8f16b5f1323 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -853,9 +853,8 @@ Status RunPjRtExecutable( ->use_pjrt_tensor_buffer; const DeviceType& device_type = GetDeviceType(ctx); - TF_ASSIGN_OR_RETURN(const int pjrt_device_id, - tsl::GetDeviceIdFromDeviceParsedName( - ctx->device()->parsed_name(), device_type)); + const int pjrt_device_id = + tsl::GetDeviceIdFromDeviceParsedName(ctx->device()->parsed_name()); TF_ASSIGN_OR_RETURN(xla::PjRtDevice * device, pjrt_client->LookupAddressableDevice( xla::PjRtLocalDeviceId(pjrt_device_id))); diff --git a/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc b/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc index 0ba66c2e2617bc..563e75c5d61b28 100644 --- a/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_gpu_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/tests/literal_test_util.h" #include "xla/tsl/framework/allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/fake_input.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/tfrt/common/create_pjrt_client_util.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/jit/xla_launch_util_test.cc b/tensorflow/compiler/jit/xla_launch_util_test.cc index d19e4fc2548bb1..443fdf3d1999a7 100644 --- a/tensorflow/compiler/jit/xla_launch_util_test.cc +++ b/tensorflow/compiler/jit/xla_launch_util_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tsl/framework/allocator.h" #include "xla/tsl/framework/device_id_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/fake_input.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/tfrt/common/create_pjrt_client_util.h" #include "tensorflow/core/tfrt/common/pjrt_util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -674,9 +674,8 @@ TEST_F(PjRtExecutionUtilTest, RunPjRtExecutableWithoutCtx) { ->tensorflow_accelerator_device_info() ->use_pjrt_tensor_buffer; const DeviceType& device_type = GetDeviceType(context_.get()); - TF_ASSERT_OK_AND_ASSIGN(const int pjrt_device_id, - tsl::GetDeviceIdFromDeviceParsedName( - context_->device()->parsed_name(), device_type)); + const int pjrt_device_id = + tsl::GetDeviceIdFromDeviceParsedName(context_->device()->parsed_name()); TF_ASSERT_OK_AND_ASSIGN(xla::PjRtDevice * pjrt_device, pjrt_client_->LookupAddressableDevice( xla::PjRtLocalDeviceId(pjrt_device_id))); diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index e689b4c0b3191c..c87dc83bdde956 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -7,6 +7,10 @@ """ load("@bazel_skylib//lib:paths.bzl", "paths") +load( + "@local_xla//xla:lit.bzl", + "lit_script_with_xla_gpu_cuda_data_dir", +) # Default values used by the test runner. _default_test_file_exts = ["mlir", ".pbtxt", ".td"] @@ -76,7 +80,8 @@ def glob_lit_tests( tags_override = {}, driver = _default_driver, features = [], - exec_properties = {}): + exec_properties = {}, + hermetic_cuda_data_dir = None): """Creates all plausible Lit tests (and their inputs) under this directory. Args: @@ -94,6 +99,8 @@ def glob_lit_tests( and specifying a default driver will abort the tests. features: [str], list of extra features to enable. exec_properties: a dictionary of properties to pass on. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. """ # Ignore some patterns by default for tests and input data. @@ -108,12 +115,24 @@ def glob_lit_tests( # failure. all_tests = [] for curr_test in tests: - all_tests.append(curr_test + ".test") + final_test_name = curr_test + if hermetic_cuda_data_dir: + output_file = "with_xla_gpu_cuda_data_dir_{}".format(curr_test) + rule_name = "script_{}".format(output_file) + lit_script_with_xla_gpu_cuda_data_dir( + rule_name, + curr_test, + output_file, + hermetic_cuda_data_dir, + ) + final_test_name = output_file + all_tests.append(final_test_name + ".test") # Instantiate this test with updated parameters. _run_lit_test( - name = curr_test + ".test", - data = data + [curr_test] + per_test_extra_data.get(curr_test, []), + name = final_test_name + ".test", + data = data + [final_test_name] + + per_test_extra_data.get(curr_test, []), size = size_override.get(curr_test, default_size), tags = default_tags + tags_override.get(curr_test, []), driver = driver, diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 78e0d07a29e907..78699e8418cf5e 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -31,6 +31,16 @@ package_group( ], ) +filegroup( + name = "tflite_internal_cc_3p_api_deps_src", + srcs = [ + "allocation.cc", + "allocation.h", + "mmap_allocation.cc", + ], + visibility = ["//tensorflow/lite:__pkg__"], +) + td_library( name = "tensorflow_lite_ops_td_files", srcs = [ @@ -81,7 +91,7 @@ gentbl_cc_library( ( [ "-gen-pass-decls", - "-name=TensorFlowLite", + "-name=TensorFlowLiteTd", ], "transforms/passes.h.inc", ), @@ -318,6 +328,13 @@ cc_library( ], ) +cc_library( + name = "stateful_error_reporter", + hdrs = ["stateful_error_reporter.h"], + compatible_with = get_compatible_with_portable(), + deps = ["//tensorflow/compiler/mlir/lite/core/api:error_reporter"], +) + gentbl_cc_library( name = "tensorflow_lite_canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), @@ -333,9 +350,29 @@ gentbl_cc_library( ) cc_library( - name = "tensorflow_lite", + name = "utils", + hdrs = ["utils/utils.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "attribute_utils", + srcs = ["utils/attribute_utils.cc"], + hdrs = ["utils/attribute_utils.h"], + deps = [ + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tensorflow_lite_ops", srcs = [ - "ir/tfl_canonicalize.inc", "ir/tfl_ops.cc", "ir/tfl_ops.cc.inc", "ir/tfl_ops.h.inc", @@ -347,22 +384,75 @@ cc_library( "ir/tfl_ops_interface.cc.inc", "ir/tfl_ops_interface.h.inc", "runtime_verifiers.inc", - "utils/attribute_utils.cc", ], hdrs = [ "ir/tfl_ops.h", + ], + deps = [ + ":converter_inc", + ":cost_estimators", + ":size_utils", + ":tensorflow_lite_canonicalize_inc_gen", + ":tensorflow_lite_op_enums_inc_gen", + ":tensorflow_lite_op_interfaces_inc_gen", + ":tensorflow_lite_ops_inc_gen", + ":utils", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:framework", + "//tensorflow/core/platform:status", + "//tensorflow/lite/schema:schema_fbs", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ControlFlowInterfaces", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:LoopLikeInterface", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +cc_library( + name = "tensorflow_lite", + srcs = [ + "ir/tfl_canonicalize.inc", + ], + hdrs = [ + "ir/tfl_ops.h", + "transforms/optimize.h", "transforms/passes.h", "utils/attribute_utils.h", "utils/utils.h", ], deps = [ + ":attribute_utils", ":converter_inc", ":cost_estimators", ":size_utils", ":tensorflow_lite_canonicalize_inc_gen", ":tensorflow_lite_op_enums_inc_gen", ":tensorflow_lite_op_interfaces_inc_gen", + ":tensorflow_lite_ops", ":tensorflow_lite_ops_inc_gen", + ":tensorflow_lite_optimize", ":tensorflow_lite_passes_inc_gen", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", @@ -389,6 +479,7 @@ cc_library( "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -404,9 +495,9 @@ cc_library( "utils/variables_utils.h", ], deps = [ - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", ], @@ -464,7 +555,9 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:status", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BytecodeOpInterface", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -513,11 +606,12 @@ cc_library( ], deps = [ ":tensorflow_lite", - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", - "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:Support", ], ) @@ -532,11 +626,8 @@ cc_library( ], deps = [ ":tensorflow_lite", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", - "//tensorflow/core:framework", "@flatbuffers", - "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -553,9 +644,10 @@ cc_library( ], deps = [ ":tensorflow_lite", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:framework", + "//tensorflow/core/ir/types:Dialect", "@flatbuffers", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", @@ -573,6 +665,8 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:status", ], ) @@ -586,8 +680,9 @@ cc_library( ], deps = [ ":tensorflow_lite", - "//tensorflow/compiler/mlir/lite/core/c:common", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "@flatbuffers//:runtime_cc", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -604,9 +699,9 @@ cc_library( ], deps = [ ":tensorflow_lite", - "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -616,6 +711,7 @@ tf_cc_test( srcs = ["utils/lstm_utils_test.cc"], deps = [ ":lstm_utils", + ":tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -635,14 +731,13 @@ tf_cc_test( deps = [ ":perception_ops_utils", ":tensorflow_lite", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -768,13 +863,16 @@ cc_library( "transforms/optimize.cc", ], hdrs = [ - "transforms/passes.h", + "transforms/optimize.h", ], deps = [ + ":attribute_utils", ":constant_utils", ":convert_type", - ":tensorflow_lite", + ":tensorflow_lite_ops", + ":tensorflow_lite_optimize_inc_gen", ":tensorflow_lite_passes_inc_gen", + ":utils", ":validators", "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", @@ -942,8 +1040,8 @@ cc_library( deps = [ ":tensorflow_lite", ":tensorflow_lite_passes_inc_gen", + "//tensorflow/compiler/mlir/lite/kernels/internal/utils:sparsity_format_converter", "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", - "//tensorflow/lite/kernels/internal/utils:sparsity_format_converter", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@eigen_archive//:eigen3", @@ -1061,7 +1159,7 @@ cc_library( ":convert_type", ":converter_inc", ":tensorflow_lite", - "//tensorflow/compiler/mlir/lite/core/c:private_common", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_utils", @@ -1090,8 +1188,8 @@ tf_native_cc_binary( name = "flatbuffer_to_string", srcs = ["flatbuffer_to_string.cc"], deps = [ + "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_reflection", - "//tensorflow/lite/core:model_builder", "@flatbuffers", ], ) @@ -1125,7 +1223,7 @@ cc_library( "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite:control_edges", "//tensorflow/compiler/mlir/lite/core:macros", - "//tensorflow/compiler/mlir/lite/core/c:private_common", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", "//tensorflow/compiler/mlir/lite/delegates/flex:allowlisted_flex_ops_lib", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:converter_error_data_proto_cc", @@ -1141,7 +1239,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:portable_gif_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/core:framework", "//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/tools/versioning", "//tensorflow/lite/tools/versioning:gpu_compatibility", @@ -1184,8 +1281,10 @@ cc_library( ":size_utils", ":tensorflow_lite", "//tensorflow/compiler/mlir/lite:control_edges", + "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", @@ -1199,8 +1298,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", - "//tensorflow/core/platform:status", - "//tensorflow/lite:model_builder", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", @@ -1232,12 +1329,13 @@ cc_library( "utils/convert_type.h", ], deps = [ - ":tensorflow_lite", + ":tensorflow_lite_ops", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", ], @@ -1318,6 +1416,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/lite/toco:toco_flags_proto_cc", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], @@ -1351,7 +1450,6 @@ tf_cc_binary( "//tensorflow/core:core_cpu_base", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", - "//tensorflow/lite:framework", "//tensorflow/lite/toco:toco_flags_proto_cc", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -1385,10 +1483,12 @@ cc_library( ":tensorflow_lite_quantize", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", + "//tensorflow/compiler/mlir/lite/stablehlo:build_stablehlo_composite", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", "//tensorflow/compiler/mlir/lite/stablehlo:composite_lowering", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", + "//tensorflow/compiler/mlir/lite/stablehlo:lift_callsite_loc_caller", "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", # buildcleaner: keep @@ -1426,8 +1526,10 @@ cc_library( "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/lite/debug", + "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:converter_error_data_proto_cc", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantize_weights", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", @@ -1455,7 +1557,6 @@ cc_library( "//tensorflow/core/ir/types:Dialect", "//tensorflow/core/platform:status", "//tensorflow/lite/toco:toco_flags_proto_cc", - "//tensorflow/lite/tools/optimize:quantize_weights", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1473,6 +1574,7 @@ cc_library( "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", "@stablehlo//:stablehlo_ops", @@ -1480,22 +1582,6 @@ cc_library( ], ) -cc_library( - name = "empty_passes", - hdrs = ["transforms/passes.h"], - visibility = [ - "//configs/devtools/hawkeye/tflite:__subpackages__", - "//learning/brain/models/app_benchmarks:__subpackages__", - ], - deps = [ - ":tensorflow_lite_passes_inc_gen", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", - "@com_google_absl//absl/container:flat_hash_set", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Pass", - ], -) - cc_library( name = "offset_buffer", hdrs = ["offset_buffer.h"], @@ -1535,6 +1621,32 @@ cc_library( visibility = ["//tensorflow/lite:__pkg__"], ) +exports_files(srcs = ["allocation.h"]) + +cc_library( + name = "allocation", + srcs = [ + "allocation.cc", + ] + select({ + ":tflite_mmap_disabled": [ + "mmap_allocation_disabled.cc", + ], + "//conditions:default": [ + "mmap_allocation.cc", + ], + }), + hdrs = [ + "allocation.h", + ], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts_warnings(), + visibility = [ + "//tensorflow/compiler/mlir/lite/core:__pkg__", + "//tensorflow/lite:__pkg__", + ], + deps = ["//tensorflow/compiler/mlir/lite/core/api:error_reporter"], +) + exports_files(srcs = ["utils/control_edges.h"]) cc_library( diff --git a/tensorflow/lite/allocation.cc b/tensorflow/compiler/mlir/lite/allocation.cc similarity index 77% rename from tensorflow/lite/allocation.cc rename to tensorflow/compiler/mlir/lite/allocation.cc index b187ef093b5c77..3cad6908c889ad 100644 --- a/tensorflow/lite/allocation.cc +++ b/tensorflow/compiler/mlir/lite/allocation.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" #include #include @@ -21,9 +21,11 @@ limitations under the License. #include #include +#include +#include #include -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace tflite { @@ -100,11 +102,37 @@ MemoryAllocation::MemoryAllocation(const void* ptr, size_t num_bytes, } #endif // __arm__ +// `android_local_test` doesn't support zipalign b/356640509 so we need this +// workaround to keep our clients working. +// TODO: b/356413060 - Remove the workaround once b/356640509 is fixed. +#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + if ((reinterpret_cast(ptr) & 0x3) != 0) { + aligned_ptr_ = ::aligned_alloc(4, num_bytes); + if (aligned_ptr_ == nullptr) { + TF_LITE_REPORT_ERROR(error_reporter, "Failed to allocate aligned buffer"); + buffer_ = nullptr; + buffer_size_bytes_ = 0; + return; + } + memcpy(aligned_ptr_, ptr, num_bytes); + buffer_ = aligned_ptr_; + } else { + buffer_ = ptr; + } +#else // defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) buffer_ = ptr; +#endif // defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + buffer_size_bytes_ = num_bytes; } -MemoryAllocation::~MemoryAllocation() {} +MemoryAllocation::~MemoryAllocation() { +#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + if (aligned_ptr_) { + free(aligned_ptr_); + } +#endif +} const void* MemoryAllocation::base() const { return buffer_; } diff --git a/tensorflow/compiler/mlir/lite/allocation.h b/tensorflow/compiler/mlir/lite/allocation.h new file mode 100644 index 00000000000000..9ee9f4e846b71e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/allocation.h @@ -0,0 +1,156 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Memory management for TF Lite. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ + +#include + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite { + +/// A memory allocation handle. This could be a mmap or shared memory. +class Allocation { + public: + virtual ~Allocation() {} + + enum class Type { + kMMap, + kFileCopy, + kMemory, + }; + + /// Base pointer of this allocation + virtual const void* base() const = 0; + /// Size in bytes of the allocation + virtual size_t bytes() const = 0; + /// Whether the allocation is valid + virtual bool valid() const = 0; + /// Return the type of the Allocation. + Type type() const { return type_; } + + protected: + Allocation(ErrorReporter* error_reporter, Type type) + : error_reporter_(error_reporter), type_(type) {} + ErrorReporter* error_reporter_; + + private: + const Type type_; +}; + +/// Note that not all platforms support MMAP-based allocation. +/// Use `IsSupported()` to check. +class MMAPAllocation : public Allocation { + public: + /// Loads and maps the provided file to a memory region. + MMAPAllocation(const char* filename, ErrorReporter* error_reporter); + + /// Maps the provided file descriptor to a memory region. + /// Note: The provided file descriptor will be dup'ed for usage; the caller + /// retains ownership of the provided descriptor and should close accordingly. + MMAPAllocation(int fd, ErrorReporter* error_reporter); + + /// Maps the provided file descriptor, with the given offset and length (both + /// in bytes), to a memory region. + /// Note: The provided file descriptor will be dup'ed for usage; the caller + /// retains ownership of the provided descriptor and should close accordingly. + MMAPAllocation(int fd, size_t offset, size_t length, + ErrorReporter* error_reporter); + + ~MMAPAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + int fd() const { return mmap_fd_; } + + // The start address of the mmapped buffer. + // This will be base() rounded down to the nearest page boundary. + const void* mmapped_buffer() const { return mmapped_buffer_; } + + // The size of the mmapped buffer. + size_t mmapped_buffer_size() const { return bytes() + offset_in_buffer_; } + + // Offset of mmapped_buffer() in the file referenced by the file descriptor. + size_t mmapped_buffer_offset_in_file() const { + return offset_of_buffer_in_file_; + } + + static bool IsSupported(); + + protected: + // Data required for mmap. + int mmap_fd_ = -1; // mmap file descriptor + const void* mmapped_buffer_; + size_t buffer_size_bytes_ = 0; + // Used when the address to mmap is not page-aligned. + size_t offset_in_buffer_ = 0; + size_t offset_of_buffer_in_file_ = 0; + + private: + // Assumes ownership of the provided `owned_fd` instance. + MMAPAllocation(ErrorReporter* error_reporter, int owned_fd); + + // Assumes ownership of the provided `owned_fd` instance, and uses the given + // offset and length (both in bytes) for memory mapping. + MMAPAllocation(ErrorReporter* error_reporter, int owned_fd, size_t offset, + size_t length); +}; + +class FileCopyAllocation : public Allocation { + public: + /// Loads the provided file into a heap memory region. + FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); + ~FileCopyAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + std::unique_ptr copied_buffer_; + size_t buffer_size_bytes_ = 0; +}; + +class MemoryAllocation : public Allocation { + public: + /// Provides a (read-only) view of the provided buffer region as an + /// allocation. + /// Note: The caller retains ownership of `ptr`, and must ensure it remains + /// valid for the lifetime of the class instance. + MemoryAllocation(const void* ptr, size_t num_bytes, + ErrorReporter* error_reporter); + ~MemoryAllocation() override; + const void* base() const override; + size_t bytes() const override; + bool valid() const override; + + private: + const void* buffer_; +#if defined(__x86_64__) && defined(UNDEFINED_BEHAVIOR_SANITIZER) + void* aligned_ptr_ = nullptr; +#endif + size_t buffer_size_bytes_ = 0; +}; + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_ALLOCATION_H_ diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 1149d7841b38fd..cdf20cc0913e38 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/lite/toco/toco_flags.pb.h" namespace mlir { namespace TFL { @@ -57,10 +58,6 @@ struct PassConfig { // Whether to enable TFLite variables or not, this will allow // mutable variables and produce ReadVariable/AssignVariable ops in TFLite. bool enable_tflite_variables = false; - // Whether to disable the variable freezing pass or not. - // By default we freeze all variables and disallow mutable variables. When - // 'enable_tflite_variables' is true then we allow mutable variable only. - bool disable_variable_freezing = false; // Whether to unfold large splat constant tensors and replace them with // fill operation. bool unfold_large_splat_constant = false; @@ -102,6 +99,10 @@ struct PassConfig { // Enables the attempt to directly lower composites into tflite ops. bool enable_composite_direct_lowering = true; + + // Specifies the framework of the original model. + toco::TocoFlags::ModelOriginFramework model_origin_framework = + toco::TocoFlags::UNSET; }; inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, @@ -118,8 +119,6 @@ inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, << "\nruntime_verification: " << pass_config.runtime_verification << "\nenable_tflite_variables: " << pass_config.enable_tflite_variables - << "\ndisable_variable_freezing: " - << pass_config.disable_variable_freezing << "\nunfold_large_splat_constant: " << pass_config.unfold_large_splat_constant << "\nguarantee_all_funcs_one_use: " @@ -132,7 +131,11 @@ inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os, << pass_config.legalize_custom_tensor_list_ops << "\nreduce_type_precision: " << pass_config.reduce_type_precision << "\nconvert_qdq_format: " - << GetQDQQuantModeString(pass_config.qdq_conversion_mode) << "\n"; + << GetQDQQuantModeString(pass_config.qdq_conversion_mode) + << "\nmodel_origin_framework: " + << toco::TocoFlags::ModelOriginFramework_Name( + pass_config.model_origin_framework) + << "\n"; } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/core/BUILD b/tensorflow/compiler/mlir/lite/core/BUILD index 184210da130ec0..d76299aa723d51 100644 --- a/tensorflow/compiler/mlir/lite/core/BUILD +++ b/tensorflow/compiler/mlir/lite/core/BUILD @@ -32,13 +32,28 @@ cc_library( ], deps = [ ":macros", + "//tensorflow/compiler/mlir/lite:allocation", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + "//tensorflow/compiler/mlir/lite/core/api:verifier", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/lite:allocation", - "//tensorflow/lite:string", - "//tensorflow/lite/core/api:error_reporter", - "//tensorflow/lite/core/api:verifier", "@com_google_absl//absl/strings", "@flatbuffers", ], alwayslink = 1, ) + +cc_library( + name = "absl_error_model_builder", + srcs = ["absl_error_model_builder.cc"], + hdrs = ["absl_error_model_builder.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts_warnings(), + visibility = [ + "//tensorflow/compiler/mlir/lite:__subpackages__", + ], + deps = [ + ":model_builder_base", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + "@com_google_absl//absl/log", + ], +) diff --git a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc new file mode 100644 index 00000000000000..269d81efc0e73e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.cc @@ -0,0 +1,40 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" + +#include +#include + +#include "absl/log/log.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace mlir::TFL { + +int AbslErrorReporter::Report(const char* format, va_list args) { + char buffer[1024]; +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" + vsprintf(buffer, format, args); +#pragma clang diagnostic pop + LOG(ERROR) << buffer; + return 0; +} + +tflite::ErrorReporter* GetAbslErrorReporter() { + static AbslErrorReporter* error_reporter = new AbslErrorReporter; + return error_reporter; +} + +} // namespace mlir::TFL diff --git a/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h new file mode 100644 index 00000000000000..c3d76e2b03f820 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_ + +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/model_builder_base.h" + +namespace mlir::TFL { + +// An error reporter that uses absl logging. +class AbslErrorReporter : public tflite::ErrorReporter { + int Report(const char* format, va_list args) override; +}; + +tflite::ErrorReporter* GetAbslErrorReporter(); + +class FlatBufferModelAbslError + : public tflite::impl::FlatBufferModelBase { + public: + // Use stderr_reporter as the default error reporter. + static tflite::ErrorReporter* GetDefaultErrorReporter() { + return GetAbslErrorReporter(); + } + + // Inherit all constructors from FlatBufferModelBase since inherited factory + // methods refer to them. + using FlatBufferModelBase::FlatBufferModelBase; +}; + +} // namespace mlir::TFL + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_ABSL_ERROR_MODEL_BUILDER_H_ diff --git a/tensorflow/compiler/mlir/lite/core/api/BUILD b/tensorflow/compiler/mlir/lite/core/api/BUILD new file mode 100644 index 00000000000000..0aaca3928420d6 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/BUILD @@ -0,0 +1,54 @@ +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/lite:__subpackages__", + "//tensorflow/lite:__subpackages__", + ], + licenses = ["notice"], +) + +exports_files(["error_reporter.h"]) + +filegroup( + name = "tflite_internal_cc_3p_api_deps_src", + srcs = [ + "error_reporter.cc", + "error_reporter.h", + "verifier.h", + ], + visibility = ["//tensorflow/lite:__pkg__"], +) + +cc_library( + name = "error_reporter", + srcs = ["error_reporter.cc"], + hdrs = ["error_reporter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), + deps = [], +) + +exports_files(["verifier.h"]) + +cc_library( + name = "verifier", + hdrs = ["verifier.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), + visibility = ["//visibility:public"], + deps = [":error_reporter"], +) + +tf_cc_test( + name = "error_reporter_test", + size = "small", + srcs = ["error_reporter_test.cc"], + deps = [ + ":error_reporter", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/tensorflow/lite/core/api/error_reporter.cc b/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc similarity index 94% rename from tensorflow/lite/core/api/error_reporter.cc rename to tensorflow/compiler/mlir/lite/core/api/error_reporter.cc index 7070eaa57c589a..96f7561d1440f3 100644 --- a/tensorflow/lite/core/api/error_reporter.cc +++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + #include namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/core/api/error_reporter.h b/tensorflow/compiler/mlir/lite/core/api/error_reporter.h new file mode 100644 index 00000000000000..79c9fc9365e44a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter.h @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ + +#include + +namespace tflite { + +/// A functor that reports error to supporting system. Invoked similar to +/// printf. +/// +/// Usage: +/// ErrorReporter foo; +/// foo.Report("test %d", 5); +/// or +/// va_list args; +/// foo.Report("test %d", args); // where args is va_list +/// +/// Subclass ErrorReporter to provide another reporting destination. +/// For example, if you have a GUI program, you might redirect to a buffer +/// that drives a GUI error log box. +class ErrorReporter { + public: + virtual ~ErrorReporter() = default; + /// Converts `args` to character equivalents according to `format` string, + /// constructs the error string and report it. + /// Returns number of characters written or zero on success, and negative + /// number on error. + virtual int Report(const char* format, va_list args) = 0; + + /// Converts arguments to character equivalents according to `format` string, + /// constructs the error string and report it. + /// Returns number of characters written or zero on success, and negative + /// number on error. + int Report(const char* format, ...); + + /// Equivalent to `Report` above. The additional `void*` parameter is unused. + /// This method is for compatibility with macros that takes `TfLiteContext`, + /// like TF_LITE_ENSURE and related macros. + int ReportError(void*, const char* format, ...); +}; + +} // namespace tflite + +// You should not make bare calls to the error reporter, instead use the +// TF_LITE_REPORT_ERROR macro, since this allows message strings to be +// stripped when the binary size has to be optimized. If you are looking to +// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and +// every call will be stubbed out, taking no memory. +#ifndef TF_LITE_STRIP_ERROR_STRINGS +#define TF_LITE_REPORT_ERROR(reporter, ...) \ + do { \ + static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \ + } while (false) +#else // TF_LITE_STRIP_ERROR_STRINGS +#define TF_LITE_REPORT_ERROR(reporter, ...) +#endif // TF_LITE_STRIP_ERROR_STRINGS + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/core/api/error_reporter_test.cc b/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc similarity index 96% rename from tensorflow/lite/core/api/error_reporter_test.cc rename to tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc index 03d6da734eae7d..ca7c4a2bb82ff8 100644 --- a/tensorflow/lite/core/api/error_reporter_test.cc +++ b/tensorflow/compiler/mlir/lite/core/api/error_reporter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include diff --git a/tensorflow/compiler/mlir/lite/core/api/verifier.h b/tensorflow/compiler/mlir/lite/core/api/verifier.h new file mode 100644 index 00000000000000..2e24347dd626e4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/api/verifier.h @@ -0,0 +1,39 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/// \file +/// +/// Abstract interface for verifying a model. +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite { + +/// Abstract interface that verifies whether a given model is legit. +/// It facilitates the use-case to verify and build a model without loading it +/// twice. +/// (See also "tensorflow/lite/tools/verifier.h".) +class TfLiteVerifier { + public: + /// Returns true if the model is legit. + virtual bool Verify(const char* data, int length, + ErrorReporter* reporter) = 0; + virtual ~TfLiteVerifier() {} +}; + +} // namespace tflite + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_API_VERIFIER_H_ diff --git a/tensorflow/compiler/mlir/lite/core/c/BUILD b/tensorflow/compiler/mlir/lite/core/c/BUILD index 8368a273de4141..3338e5b8940fca 100644 --- a/tensorflow/compiler/mlir/lite/core/c/BUILD +++ b/tensorflow/compiler/mlir/lite/core/c/BUILD @@ -11,9 +11,12 @@ package( # LINT.IfChange(common) cc_library( - name = "common", + name = "tflite_common", srcs = [], - hdrs = ["builtin_op_data.h"], + hdrs = [ + "builtin_op_data.h", + "dimension_type.h", + ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = [ @@ -23,17 +26,3 @@ cc_library( alwayslink = 1, # Why?? TODO(b/161243354): eliminate this. ) # LINT.ThenChange(//tensorflow/lite/core/c:common) - -# LINT.IfChange(private_common) -# This is a private target, its visibility is set to public only to be -# used by "tflite_custom_c_library" and "tflite_flex_cc_library". -# Do not use this target directly and don't consider it as a part of the public API. -alias( - name = "private_common", - actual = ":common", - tags = ["avoid_dep"], - visibility = [ - "//visibility:public", - ], -) -# LINT.ThenChange(//tensorflow/lite/core/c:private_common) diff --git a/tensorflow/compiler/mlir/lite/core/c/dimension_type.h b/tensorflow/compiler/mlir/lite/core/c/dimension_type.h new file mode 100644 index 00000000000000..fd2c6122897065 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/core/c/dimension_type.h @@ -0,0 +1,38 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ + +// LINT.IfChange + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + + +/// Storage format of each dimension in a sparse tensor. +typedef enum TfLiteDimensionType { + kTfLiteDimDense = 0, + kTfLiteDimSparseCSR, +} TfLiteDimensionType; + +#ifdef __cplusplus +} // extern "C" + +#endif // __cplusplus +#endif // TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_DIMENSION_TYPE_H_ + +// LINT.ThenChange(//tensorflow/lite/core/c/common.h) diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.cc b/tensorflow/compiler/mlir/lite/core/model_builder_base.cc index 28306ca8684e49..2ad2b93329be16 100644 --- a/tensorflow/compiler/mlir/lite/core/model_builder_base.cc +++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.cc @@ -20,8 +20,8 @@ limitations under the License. #include #include "absl/strings/str_cat.h" -#include "tensorflow/lite/allocation.h" -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/core/model_builder_base.h b/tensorflow/compiler/mlir/lite/core/model_builder_base.h index 0e394eaf0bd0ab..aabd4f959a992d 100644 --- a/tensorflow/compiler/mlir/lite/core/model_builder_base.h +++ b/tensorflow/compiler/mlir/lite/core/model_builder_base.h @@ -40,12 +40,11 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers #include "flatbuffers/verifier.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/verifier.h" #include "tensorflow/compiler/mlir/lite/core/macros.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" -#include "tensorflow/lite/allocation.h" -#include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/api/verifier.h" -#include "tensorflow/lite/string_type.h" namespace tflite { @@ -468,7 +467,8 @@ class FlatBufferModelBase { // '\0's in the buffer. for (int len = 0; len < array->size(); ++len) { if (array->data()[len] == '\0') { - return string(reinterpret_cast(array->data()), len); + return std::string(reinterpret_cast(array->data()), + len); } } // If there is no '\0' in the buffer, this indicates that the flatbuffer @@ -503,8 +503,8 @@ class FlatBufferModelBase { if (!buffer || !buffer->data()) continue; const flatbuffers::Vector* array = buffer->data(); if (!array) continue; - std::string val = - string(reinterpret_cast(array->data()), array->size()); + std::string val = std::string( + reinterpret_cast(array->data()), array->size()); // Skip if key or value of metadata is empty. if (!metadata->name() || val.empty()) continue; keys_values[metadata->name()->str()] = val; diff --git a/tensorflow/compiler/mlir/lite/debug/BUILD b/tensorflow/compiler/mlir/lite/debug/BUILD index 05ae1368125e59..dee91a9ffee68e 100644 --- a/tensorflow/compiler/mlir/lite/debug/BUILD +++ b/tensorflow/compiler/mlir/lite/debug/BUILD @@ -60,8 +60,8 @@ tf_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/compiler/mlir/lite/debug/debug.cc b/tensorflow/compiler/mlir/lite/debug/debug.cc index 2892fbf4a8bc80..d0b85019cfe200 100644 --- a/tensorflow/compiler/mlir/lite/debug/debug.cc +++ b/tensorflow/compiler/mlir/lite/debug/debug.cc @@ -83,7 +83,7 @@ struct WritableFileRawStream : public llvm::raw_ostream { void write_impl(const char* ptr, size_t size) override { // Write the file if it is still valid. If the write fails, null out the // file to avoid encountering another error. - if (file && !file->Append(tsl::StringPiece(ptr, size)).ok()) { + if (file && !file->Append(absl::string_view(ptr, size)).ok()) { file = nullptr; } } diff --git a/tensorflow/compiler/mlir/lite/debug/debug_test.cc b/tensorflow/compiler/mlir/lite/debug/debug_test.cc index 371e3185fdc115..5d1ed84b36d074 100644 --- a/tensorflow/compiler/mlir/lite/debug/debug_test.cc +++ b/tensorflow/compiler/mlir/lite/debug/debug_test.cc @@ -46,8 +46,8 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" diff --git a/tensorflow/compiler/mlir/lite/delegates/flex/BUILD b/tensorflow/compiler/mlir/lite/delegates/flex/BUILD index 4ad7b874da82b8..2b3d198112d393 100644 --- a/tensorflow/compiler/mlir/lite/delegates/flex/BUILD +++ b/tensorflow/compiler/mlir/lite/delegates/flex/BUILD @@ -2,11 +2,9 @@ load( "//tensorflow:tensorflow.bzl", "if_mobile", "if_not_mobile", - "tf_cc_test", "tf_features_nolayering_check_if_ios", ) load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") -load("//tensorflow/compiler/mlir/lite/delegates/flex:build_def.bzl", "tflite_flex_cc_library") load("//tensorflow/lite:special_rules.bzl", "internal_visibility_allowlist") default_visibility = [ @@ -24,18 +22,6 @@ package( licenses = ["notice"], ) -exports_files([ - "delegate.h", - "exported_symbols.lds", - "version_script.lds", -]) - -tflite_flex_cc_library( - name = "delegate", - compatible_with = get_compatible_with_portable(), - visibility = ["//visibility:public"], -) - cc_library( name = "allowlisted_flex_ops_lib", srcs = [ @@ -54,21 +40,3 @@ cc_library( "//tensorflow/core:framework", ]), ) - -tf_cc_test( - name = "allowlisted_flex_ops_test", - size = "small", - srcs = [ - "allowlisted_flex_ops_test.cc", - ], - features = tf_features_nolayering_check_if_ios(), - deps = [ - ":allowlisted_flex_ops_lib", - ":delegate", - "@com_google_googletest//:gtest_main", - ] + if_mobile([ - "//tensorflow/core:portable_tensorflow_lib_lite", - ]) + if_not_mobile([ - "//tensorflow/core:framework", - ]), -) diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc index 8313bf2c10e269..4f12a705cc27f7 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tac_module.cc @@ -62,8 +62,7 @@ void TacModule::AddTACPass(mlir::OpPassManager* pass_manager, mlir::createCanonicalizerPass()); pass_manager->addPass( mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true)); - pass_manager->addPass( - mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true)); + pass_manager->addPass(mlir::TFL::CreateOptimizePass()); } pass_manager->addPass(mlir::TFL::tac::CreateComputeCostPass()); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 61642225a8c498..ecad51df76be37 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -113,7 +113,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/tstring.h" -#include "tensorflow/lite/core/interpreter.h" #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/tools/versioning/gpu_compatibility.h" #include "tensorflow/lite/tools/versioning/op_version.h" @@ -159,6 +158,11 @@ namespace tfl = mlir::TFL; ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex"; +// LINT.IfChange(optional_tensor) +// Taken from third_party/tensorflow/lite/core/c/common.h +constexpr int kTfLiteMigrationOptionalTensor = -1; +// LINT.ThenChange(//tensorflow/lite/core/c/common.h:optional_tensor) + // Use initial buffer size in flatbuffer builder to be same as the initial size // used by the TOCO export. (It does not explain rationale for this choice.) constexpr size_t kInitialBufferSize = 10240; @@ -631,6 +635,11 @@ class Translator { mlir::TFL::WhileOp op, const std::vector& operands, const std::vector& results); + // Build while operator where then & else are regions. + std::optional> BuildIfOperator( + mlir::TFL::IfOp op, const std::vector& operands, + const std::vector& results); + // Build call once operator. BufferOffset BuildCallOnceOperator( mlir::TFL::CallOnceOp op, const std::vector& operands, @@ -1331,6 +1340,54 @@ std::optional> Translator::BuildWhileOperator( builtin_options); } +std::optional> Translator::BuildIfOperator( + mlir::TFL::IfOp op, const std::vector& operands, + const std::vector& results) { + auto opcode_index = GetOpcodeIndex("if", tflite::BuiltinOperator_IF); + auto get_call_op = [&](mlir::Block& b) -> std::optional { + if (b.getOperations().size() != 2) return std::nullopt; + if (auto call_op = dyn_cast(b.front())) return call_op; + return std::nullopt; + }; + auto then_call_op = get_call_op(op.getThenRegion().front()); + auto else_call_op = get_call_op(op.getElseRegion().front()); + if (!then_call_op || !else_call_op) + return op.emitOpError("only single call then/else while export supported"), + std::nullopt; + auto then_subgraph_index = + subgraph_index_map_.at(then_call_op.value().getCallee().str()); + auto else_subgraph_index = + subgraph_index_map_.at(else_call_op.value().getCallee().str()); + auto builtin_options = tflite::CreateIfOptions(builder_, then_subgraph_index, + else_subgraph_index) + .Union(); + + // Get the subgraph index of IF op. + auto subgraph_func = op->getParentOfType(); + auto subgraph_idx = subgraph_index_map_[subgraph_func.getSymName().str()]; + auto new_operands = operands; + + // Then/Else region shares the same operands, only adding once as the new + // operands for the IF op. + if (then_call_op.value().getOperands() != + else_call_op.value().getOperands()) { + return op.emitOpError("Then/Else region does not contain same operands."), + std::nullopt; + } + + for (auto call_arg : then_call_op.value().getOperands()) { + auto name_of_call_arg = name_mapper_.GetUniqueName(call_arg); + const auto call_arg_tensor_id = + tensor_index_map_[subgraph_idx][name_of_call_arg]; + new_operands.push_back(call_arg_tensor_id); + } + auto inputs = builder_.CreateVector(new_operands); + auto outputs = builder_.CreateVector(results); + return tflite::CreateOperator(builder_, opcode_index, inputs, outputs, + tflite::BuiltinOptions_IfOptions, + builtin_options); +} + BufferOffset Translator::BuildNumericVerifyOperator( mlir::TFL::NumericVerifyOp op, const std::vector& operands, const std::vector& results) { @@ -2098,6 +2155,9 @@ std::optional> Translator::BuildOperator( } return BuildWhileOperator(whileOp, operands, results); } + if (auto ifOp = dyn_cast(inst)) { + return BuildIfOperator(ifOp, operands, results); + } inst->emitOpError("is not a supported TFLite op"); return std::nullopt; @@ -3024,7 +3084,7 @@ std::optional> Translator::BuildSubGraph( operands.reserve(real_inst->getNumOperands()); for (auto operand : real_inst->getOperands()) { if (mlir::isa(operand.getType())) - operands.push_back(kTfLiteOptionalTensor); + operands.push_back(kTfLiteMigrationOptionalTensor); else if (auto stats_op = llvm::dyn_cast_or_null( operand.getDefiningOp())) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index a03e988fde32fa..a289126d26b6ca 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -44,6 +44,7 @@ limitations under the License. #include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project @@ -70,16 +71,17 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" #include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/offset_buffer.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h" @@ -97,8 +99,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/lite/model_builder.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -625,7 +625,7 @@ StatusOr ConvertOp( const std::vector& func_names, const std::vector>& tensors, Location loc, OpBuilder builder, - const std::unique_ptr& model_ptr) { + const std::unique_ptr& model_ptr) { llvm::SmallVector operands; llvm::SmallVector outputTypes; @@ -1116,7 +1116,7 @@ StatusOr ConvertSubgraph( bool experimental_prune_unreachable_nodes_unconditionally, const tflite::SignatureDefT* signature, const tflite::ControlEdges& control_edges, - const std::unique_ptr& model_ptr, + const std::unique_ptr& model_ptr, bool use_stablehlo_constant) { // Populate from metadata. ControlNodes control_nodes; @@ -1518,8 +1518,8 @@ OwningOpRef tflite::FlatBufferToMlir( mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect, mlir::stablehlo::StablehloDialect, mlir::vhlo::VhloDialect>(); - auto model_ptr = - FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length()); + auto model_ptr = tfl::FlatBufferModelAbslError::VerifyAndBuildFromBuffer( + buffer.data(), buffer.length()); if (nullptr == model_ptr) { return emitError(base_loc, "couldn't parse flatbuffer"), nullptr; } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc index df28f501ef7656..b393a885170c05 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_to_string.cc @@ -27,7 +27,7 @@ limitations under the License. #include "flatbuffers/minireflect.h" // from @flatbuffers #include "tensorflow/compiler/mlir/lite/schema/reflection/schema_generated.h" #if FLATBUFFERS_LITTLEENDIAN == 0 -#include "tensorflow/lite/core/model_builder.h" +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" #endif namespace tflite { @@ -144,7 +144,8 @@ int main(int argc, char** argv) { // If the flatbuffer model comes from stdin, convert its tensor content from // BE to LE to ensure the output text string is the same as on LE platforms. if (std::string(argv[1]) == "-") - tflite::FlatBufferModel::ByteSwapSerializedModel(&serialized_model, true); + mlir::TFL::FlatBufferModelAbslError::ByteSwapSerializedModel( + &serialized_model, true); #endif tflite::ToString(serialized_model); return 0; diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 1633820bb5bd5e..c9aa62843d743c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -33,16 +34,21 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/Threading.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -54,6 +60,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/FoldUtils.h" // from @llvm-project @@ -61,6 +68,7 @@ limitations under the License. #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h" #include "tensorflow/compiler/mlir/lite/utils/size_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" @@ -616,47 +624,50 @@ void IncrementIndex(ArrayRef result_shape, /// attributes `operand1` and `operand2` and returns the result if possible. /// This function assumes the both operands are verified to have value /// attributes of broadcastable types. -template > -Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs, +template > +Attribute ConstFoldBinaryOpDenseDense(ShapedType result_type, + DenseElementsAttr lhs, DenseElementsAttr rhs, const CalculationT& calculate) { - auto type = OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()) - .dyn_cast_or_null(); + auto type = llvm::dyn_cast_or_null( + OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType())); if (!type) { return {}; } + type = type.clone(result_type.getElementType()); + const bool rhs_is_splat = rhs.isSplat(); const bool lhs_is_splat = lhs.isSplat(); + auto lhs_values = lhs.try_value_begin(); + auto rhs_values = rhs.try_value_begin(); + if (failed(lhs_values) || failed(rhs_values)) { + return {}; + } + // If both of them are splat, compute and return. if (lhs_is_splat && rhs_is_splat) { - auto element_result = AttrElementT::get( - type.getElementType(), calculate(lhs.getSplatValue(), - rhs.getSplatValue())); - if (!element_result) return {}; - - return DenseElementsAttr::get(type, element_result); + return DenseElementsT::get( + type, calculate(*lhs_values.value(), *rhs_values.value())); } auto num_elements = type.getNumElements(); - SmallVector new_values; + SmallVector new_values; new_values.reserve(num_elements); const auto result_shape = type.getShape(); std::vector current_index(type.getRank(), 0); + // Create the new shape with ones padded to the left. - const std::vector lhs_new_shape = + const auto lhs_new_shape = GetPaddedShape(lhs.getType().getShape(), type.getRank()); - const std::vector rhs_new_shape = + const auto rhs_new_shape = GetPaddedShape(rhs.getType().getShape(), type.getRank()); - auto lhs_old_values = lhs.getValues(); - auto rhs_old_values = rhs.getValues(); - // Add each pair of the corresponding values in the dense elements // attributes. for (int64_t i = 0; i < num_elements; ++i) { @@ -669,26 +680,27 @@ Attribute ConstFoldBinaryOpDenseDense(Type result_type, DenseElementsAttr lhs, const int64_t rhs_index = rhs_is_splat ? 0 : GetElementIndex(rhs_new_shape, current_index); - new_values.push_back(calculate(*(lhs_old_values.begin() + lhs_index), - *(rhs_old_values.begin() + rhs_index))); + new_values.push_back(calculate(*(lhs_values.value() + lhs_index), + *(rhs_values.value() + rhs_index))); IncrementIndex(result_shape, ¤t_index); } - return DenseElementsAttr::get(type, ArrayRef(new_values)); + return DenseElementsT::get(type, new_values); } /// Performs const folding `calculate` with broadcast behavior on the two /// attributes `operand1` and `operand2` and returns the result if possible. /// This function assumes the two operands are verified to have value /// attributes of broadcastable types. -template > -Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, +template > +Attribute ConstFoldBinaryOp(ShapedType result_type, Attribute operand1, Attribute operand2, const CalculationT& calculate) { if (operand1.dyn_cast_or_null() && operand2.dyn_cast_or_null()) { - return ConstFoldBinaryOpDenseDense( + return ConstFoldBinaryOpDenseDense( result_type, operand1.cast(), operand2.cast(), calculate); } @@ -703,23 +715,18 @@ Attribute ConstFoldBinaryOp(Type result_type, Attribute operand1, /// Depending on the given `resultType`, either `floatCalculate` or /// `intCalculate` is chosen to conduct the calculate. Attribute ConstFoldBinaryOp( - Type result_type, ArrayRef operands, + ShapedType type, ArrayRef operands, llvm::function_ref float_calculate, llvm::function_ref int_calculate) { - // Note: All types are wrapped in tensor types in TFlite. E.g., f32 is - // represented as tensor. So we are only handling tensor types here. - auto type = result_type.dyn_cast(); - if (!type) return {}; - auto elemType = type.getElementType(); if (elemType.isa()) - return ConstFoldBinaryOp(result_type, operands[0], operands[1], - float_calculate); + return ConstFoldBinaryOp( + type, operands[0], operands[1], float_calculate); if (elemType.isSignlessInteger()) - return ConstFoldBinaryOp(result_type, operands[0], operands[1], - int_calculate); + return ConstFoldBinaryOp( + type, operands[0], operands[1], int_calculate); return {}; } @@ -809,6 +816,73 @@ int64_t AddOp::GetArithmeticCount(Operation* op) { return -1; } +//===----------------------------------------------------------------------===// +// FloorOp +//===----------------------------------------------------------------------===// + +OpFoldResult FloorOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + auto result_type = getType(); + if (!IsF32ShapedType(result_type)) return {}; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::floor(f); + return APFloat(result); + }; + + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// BitwiseXorOp +//===----------------------------------------------------------------------===// + +OpFoldResult BitwiseXorOp::fold(FoldAdaptor adaptor) { + auto compute = [](APInt lhs, APInt rhs) -> APInt { + lhs ^= rhs; + return lhs; + }; + + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), compute); +} + +//===----------------------------------------------------------------------===// +// ExpOp +//===----------------------------------------------------------------------===// + +OpFoldResult ExpOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + auto result_type = getType(); + if (!IsF32ShapedType(result_type)) return {}; + + auto compute = [](APFloat value) -> APFloat { + float f = value.convertToFloat(); + float result = std::exp(f); + return APFloat(result); + }; + + return ConstFoldUnaryOp(result_type, operands[0], compute); +} + +//===----------------------------------------------------------------------===// +// LogicalNotOp +//===----------------------------------------------------------------------===// + +OpFoldResult LogicalNotOp::fold(FoldAdaptor adaptor) { + auto data = llvm::dyn_cast_or_null(adaptor.getLhs()); + if (!data) { + return {}; + } + + auto compute = [](bool value) { return !value; }; + + return DenseIntElementsAttr::get( + data.getType(), + llvm::to_vector(llvm::map_range(data.getValues(), compute))); +} + //===----------------------------------------------------------------------===// // ConcatenationOp //===----------------------------------------------------------------------===// @@ -1681,6 +1755,38 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { // TODO(b/142478136): Handle fused ops. if (getFusedActivationFunction() != "NONE") return {}; + auto is_zero = [](Attribute a) { + return matchPattern(a, m_Zero()) || matchPattern(a, m_AnyZeroFloat()); + }; + auto is_one = [](Attribute a) { + return matchPattern(a, m_One()) || matchPattern(a, m_OneFloat()); + }; + + // Quantized folding not supported. + const bool is_quantized = + llvm::isa(getType().getElementType()); + + auto lhs = llvm::dyn_cast_or_null(adaptor.getLhs()); + auto rhs = llvm::dyn_cast_or_null(adaptor.getRhs()); + + if (lhs && !is_quantized) { + if (is_zero(lhs) && lhs.getType() == getType()) { + return lhs; + } + if (is_one(lhs) && getRhs().getType() == getType()) { + return getRhs(); + } + } + + if (rhs && !is_quantized) { + if (is_zero(rhs) && rhs.getType() == getType()) { + return rhs; + } + if (is_one(rhs) && getLhs().getType() == getType()) { + return getLhs(); + } + } + // This function is performance critical for op fusion patterns, e.g. // FuseBinaryOpToPrecedingAffine and FuseMulOrDivWithConv2dOrDepthwiseConv2d. // So a few specializations are provided to evaluate the math operation @@ -1688,14 +1794,15 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { // Specialization for f32 type. if (getType().cast().getElementType().isF32()) { - return ConstFoldBinaryOp( + return ConstFoldBinaryOp( getType(), operands[0], operands[1], [](float a, float b) { return a * b; }); } // Specialization for bf16 type. if (getType().cast().getElementType().isBF16()) { - return ConstFoldBinaryOp( + return ConstFoldBinaryOp( getType(), operands[0], operands[1], [](Eigen::bfloat16 a, Eigen::bfloat16 b) { return a * b; }); } @@ -1713,6 +1820,24 @@ int64_t MulOp::GetArithmeticCount(Operation* op) { return -1; } +//===----------------------------------------------------------------------===// +// PowOp +//===----------------------------------------------------------------------===// + +OpFoldResult PowOp::fold(FoldAdaptor adaptor) { + if (getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return std::pow(lhs, rhs); }); + } + if (getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return std::pow(lhs, rhs); }); + } + return {}; +} + //===----------------------------------------------------------------------===// // DivOp //===----------------------------------------------------------------------===// @@ -1721,9 +1846,30 @@ OpFoldResult DivOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); // TODO(b/142478136): Handle fused ops. if (getFusedActivationFunction() != "NONE") return {}; - return ConstFoldBinaryOp( - getType(), operands, [](APFloat a, APFloat b) { return a / b; }, - [](APInt a, APInt b) { return a.sdiv(b); }); + + auto rhs = llvm::dyn_cast_or_null(adaptor.getRhs()); + auto lhs = llvm::dyn_cast_or_null(adaptor.getLhs()); + + if (rhs && lhs) { + return ConstFoldBinaryOp( + getType(), operands, [](APFloat a, APFloat b) { return a / b; }, + [](APInt a, APInt b) { return a.sdiv(b); }); + } + + if (llvm::isa(getType().getElementType())) { + // Quantized folding not supported for the following. + return {}; + } + + auto is_one = [](Attribute a) { + return matchPattern(a, m_One()) || matchPattern(a, m_OneFloat()); + }; + + if (rhs && is_one(rhs) && getLhs().getType() == getType()) { + return getLhs(); + } + + return {}; } int64_t DivOp::GetArithmeticCount(Operation* op) { @@ -3080,12 +3226,12 @@ OpFoldResult MaximumOp::fold(FoldAdaptor adaptor) { if (lhs && lhs.isSplat()) { APFloat lhs_value = lhs.getSplatValue(); lhs_value.changeSign(); - if (lhs_value.isLargest()) return getRhs(); + if (lhs_value.isLargest() || lhs_value.isInfinity()) return getRhs(); } if (rhs && rhs.isSplat()) { APFloat rhs_value = rhs.getSplatValue(); rhs_value.changeSign(); - if (rhs_value.isLargest()) return getLhs(); + if (rhs_value.isLargest() || rhs_value.isInfinity()) return getLhs(); } return nullptr; } @@ -3102,13 +3248,184 @@ OpFoldResult MinimumOp::fold(FoldAdaptor adaptor) { auto lhs = adaptor.getLhs().dyn_cast_or_null(); auto rhs = adaptor.getRhs().dyn_cast_or_null(); - if (lhs && lhs.isSplat() && lhs.getSplatValue().isLargest()) - return getRhs(); - if (rhs && rhs.isSplat() && rhs.getSplatValue().isLargest()) - return getLhs(); + if (lhs && lhs.isSplat()) { + auto splat = lhs.getSplatValue(); + if (splat.isLargest() || splat.isInfinity()) return getRhs(); + } + if (rhs && rhs.isSplat()) { + auto splat = rhs.getSplatValue(); + if (splat.isLargest() || splat.isInfinity()) return getLhs(); + } return nullptr; } +//===----------------------------------------------------------------------===// +// Comparison and Logical Ops +//===----------------------------------------------------------------------===// + +OpFoldResult LessOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs < rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs < rhs; }); + } + return {}; +} + +OpFoldResult LessEqualOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs <= rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs <= rhs; }); + } + return {}; +} + +OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs > rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs > rhs; }); + } + return {}; +} + +OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs >= rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs >= rhs; }); + } + return {}; +} + +OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { + if (getX().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getX(), adaptor.getY(), + [](int32_t lhs, int32_t rhs) { return lhs == rhs; }); + } + if (getX().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getX(), adaptor.getY(), + [](float lhs, float rhs) { return lhs == rhs; }); + } + return {}; +} + +OpFoldResult NotEqualOp::fold(FoldAdaptor adaptor) { + if (getLhs().getType().getElementType().isInteger(32)) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](int32_t lhs, int32_t rhs) { return lhs != rhs; }); + } + if (getLhs().getType().getElementType().isF32()) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](float lhs, float rhs) { return lhs != rhs; }); + } + return {}; +} + +OpFoldResult LogicalAndOp::fold(FoldAdaptor adaptor) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](bool lhs, bool rhs) { return lhs && rhs; }); +} + +OpFoldResult LogicalOrOp::fold(FoldAdaptor adaptor) { + return ConstFoldBinaryOp( + getType(), adaptor.getLhs(), adaptor.getRhs(), + [](bool lhs, bool rhs) { return lhs || rhs; }); +} + +//===----------------------------------------------------------------------===// +// SelectOp +//===----------------------------------------------------------------------===// + +// TODO: b/359275356 - Expand this to handle the broadcast case similar +// to `ConstFoldBinaryOpDense`. +OpFoldResult SelectOp::fold(FoldAdaptor adaptor) { + auto lhs_type = getX().getType(); + auto rhs_type = getY().getType(); + auto condition_type = getCondition().getType(); + auto out_type = getType(); + + if (lhs_type != rhs_type) { + return {}; + } + + if (lhs_type.getShape() != condition_type.getShape()) { + // "broadcasted" condition not yet supported. + return {}; + } + + auto condition_vals = + llvm::dyn_cast_or_null(adaptor.getCondition()); + if (!condition_vals || !condition_vals.getElementType().isInteger(1)) { + return {}; + } + + if (condition_vals.isSplat()) { + const bool val = condition_vals.getSplatValue(); + return val ? adaptor.getX() : adaptor.getY(); + } + + auto lhs_vals = llvm::dyn_cast_or_null(adaptor.getX()); + auto rhs_vals = llvm::dyn_cast_or_null(adaptor.getY()); + if (!lhs_vals || !rhs_vals) { + return {}; + } + + llvm::SmallVector results; + results.reserve(condition_type.getNumElements()); + + auto lhs_it = lhs_vals.getValues().begin(); + auto lhs_end = lhs_vals.getValues().end(); + auto rhs_it = rhs_vals.getValues().begin(); + auto rhs_end = rhs_vals.getValues().end(); + + auto condition_it = condition_vals.getValues().begin(); + auto condition_end = condition_vals.getValues().end(); + + while (condition_it < condition_end && lhs_it < lhs_end && rhs_it < rhs_end) { + if (*condition_it++) { + results.push_back(*lhs_it); + } else { + results.push_back(*rhs_it); + } + + if (!lhs_vals.isSplat()) { + lhs_it++; + } + if (!rhs_vals.isSplat()) { + rhs_it++; + } + } + + return DenseElementsAttr::get(out_type, results); +} + //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// @@ -3191,35 +3508,12 @@ void ConstOp::getCanonicalizationPatterns(RewritePatternSet& results, // CastOp //===----------------------------------------------------------------------===// -OpFoldResult CastOp::fold(FoldAdaptor adaptor) { - auto operands = adaptor.getOperands(); - assert(operands.size() == 1); - if (getInput().getType() == getType()) { - return getInput(); - } - - // For now, only supports cast between integer types. - auto elements_attr = operands[0].dyn_cast_or_null(); - if (!elements_attr) { - return nullptr; - } - - auto result_element_type = - getType().cast().getElementType().dyn_cast(); - auto operand_element_type = getInput() - .getType() - .cast() - .getElementType() - .dyn_cast(); - // Returns nullptr if either result/operand element type is not integer. - if (!result_element_type || !operand_element_type) { - return nullptr; - } - - const bool is_unsigned = operand_element_type.isUnsigned(); - const bool involves_bool = operand_element_type.getWidth() == 1 || - result_element_type.getWidth() == 1; - const int output_bitwidth = result_element_type.getWidth(); +OpFoldResult CastIntToInt(DenseIntElementsAttr data, IntegerType in_type, + IntegerType out_type) { + const bool is_unsigned = in_type.isUnsigned(); + const bool involves_bool = + in_type.getWidth() == 1 || out_type.getWidth() == 1; + const int output_bitwidth = out_type.getWidth(); // The integer cast op is the same as C integer cast. Depends on the operand // type's signedness, we will determine whether or not sign extension is // needed. @@ -3230,13 +3524,114 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { // true input should always be cast to 1 and not -1 as the sign extension // would do for signed outputs. Similarly, non-zero inputs should be cast // to true. Truncating even numbers to one bit will result in `false`. - return APInt(result_element_type.getWidth(), value != 0); + return APInt(out_type.getWidth(), value != 0); } return is_unsigned ? value.zextOrTrunc(output_bitwidth) : value.sextOrTrunc(output_bitwidth); }; - return elements_attr.mapValues(result_element_type, cast); + return data.mapValues(out_type, cast); +} + +OpFoldResult CastFloatToInt(DenseFPElementsAttr data, FloatType in_type, + IntegerType out_type) { + const bool from_f32 = in_type.isF32(); + const bool to_i32 = out_type.isSignlessInteger(32); + if (!from_f32 || !to_i32) { + return {}; + } + + auto cast = [&](APFloat value) -> APInt { + APSInt result(32, false); + bool is_exact; + value.convertToInteger(result, llvm::RoundingMode::TowardZero, &is_exact); + return result; + }; + + return data.mapValues(out_type, cast); +} + +template +llvm::SmallVector MapStaticCast(DenseElementsAttr data) { + return llvm::map_to_vector(data.getValues(), + [](InType v) { return static_cast(v); }); +} + +OpFoldResult CastIntToFloat(DenseIntElementsAttr data, IntegerType in_type, + FloatType out_type) { + auto result_type = data.getType().clone(out_type); + if (!out_type.isF32()) { + return {}; + } + + if (in_type.isSignlessInteger(32)) { + return DenseFPElementsAttr::get(result_type, + MapStaticCast(data)); + } + if (in_type.isSignlessInteger(1)) { + return DenseFPElementsAttr::get(result_type, + MapStaticCast(data)); + } + + return {}; +} + +OpFoldResult CastFloatToFloat(DenseFPElementsAttr data, FloatType in_type, + FloatType out_type) { + auto result_type = data.getType().clone(out_type); + if (in_type.isF32() && out_type.isF64()) { + return DenseFPElementsAttr::get(result_type, + MapStaticCast(data)); + } + + if (in_type.isF64() && out_type.isF32()) { + return DenseFPElementsAttr::get(result_type, + MapStaticCast(data)); + } + return {}; +} + +OpFoldResult CastOp::fold(FoldAdaptor adaptor) { + auto operands = adaptor.getOperands(); + if (operands.size() != 1) { + return {}; + } + if (getInput().getType() == getType()) { + return getInput(); + } + + auto input = operands[0]; + + auto in_type = getInput().getType().getElementType(); + auto out_type = getType().getElementType(); + + if (auto int_in_type = llvm::dyn_cast_or_null(in_type)) { + auto in_data = llvm::dyn_cast_or_null(input); + if (!in_data) { + return {}; + } + if (auto float_out_type = llvm::dyn_cast_or_null(out_type)) { + return CastIntToFloat(in_data, int_in_type, float_out_type); + } + if (auto int_out_type = llvm::dyn_cast_or_null(out_type)) { + return CastIntToInt(in_data, int_in_type, int_out_type); + } + } + + if (auto float_in_type = llvm::dyn_cast_or_null(in_type)) { + auto in_data = llvm::dyn_cast_or_null(input); + if (!in_data) { + return {}; + } + if (auto float_out_type = llvm::dyn_cast_or_null(out_type)) { + return CastFloatToFloat(in_data, float_in_type, float_out_type); + } + if (auto int_out_type = llvm::dyn_cast_or_null(out_type)) { + return CastFloatToInt(in_data, float_in_type, int_out_type); + } + } + + return {}; } //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 33f920ebb02d5e..5eda0d01c31b61 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -1281,6 +1281,8 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -1357,6 +1359,8 @@ def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -1554,6 +1558,8 @@ def TFL_NotEqualOp : TFL_Op<"not_equal", [ let results = (outs TFL_BoolTensor:$output); + let hasFolder = 1; + let builders = [ OpBuilder<(ins "Value":$lhs, "Value":$rhs), @@ -1681,6 +1687,8 @@ def TFL_EqualOp: TFL_Op<"equal", [ let results = (outs TFL_BoolTensor:$output); let builders = [TFL_ComparisonBinaryBuilder]; + + let hasFolder = 1; } def TFL_ExpOp: TFL_Op<"exp", [ @@ -1697,6 +1705,8 @@ def TFL_ExpOp: TFL_Op<"exp", [ let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y); + let hasFolder = 1; + // This builder doesn't work with quantized type, so it can only be used by // non-quantization tablegen patterns. Currently, it is used by the // elementwise-move reordering pattern in the optimize_patterns.td @@ -1840,6 +1850,8 @@ def TFL_FloorOp: TFL_Op<"floor", [ let results = (outs TFL_FpTensor:$y); + let hasFolder = 1; + let extraClassDeclaration = [{ // Returns whether the return types are compatible. static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { @@ -1925,6 +1937,8 @@ def TFL_GreaterOp : TFL_Op<"greater", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -2036,6 +2050,8 @@ def TFL_LessOp : TFL_Op<"less", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -2046,7 +2062,7 @@ def TFL_LessOp : TFL_Op<"less", [ }]; } -def TFL_LogicalAndOp : TFL_Op<"logical_and", [Pure]> { +def TFL_LogicalAndOp : TFL_Op<"logical_and", [ResultsBroadcastableShape, Pure]> { let summary = "Logical AND operator"; let description = [{ @@ -2061,6 +2077,8 @@ def TFL_LogicalAndOp : TFL_Op<"logical_and", [Pure]> { let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -2083,9 +2101,11 @@ def TFL_LogicalNotOp : TFL_Op<"logical_not", [ let arguments = (ins TFL_BoolTensor:$lhs); let results = (outs TFL_BoolTensor:$output); + + let hasFolder = 1; } -def TFL_LogicalOrOp : TFL_Op<"logical_or", [Pure]> { +def TFL_LogicalOrOp : TFL_Op<"logical_or", [ResultsBroadcastableShape, Pure]> { let summary = "Logical OR operator"; let description = [{ @@ -2100,6 +2120,8 @@ def TFL_LogicalOrOp : TFL_Op<"logical_or", [Pure]> { let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -2803,6 +2825,8 @@ def TFL_PowOp : TFL_Op<"pow", [ let hasCustomAssemblyFormat = 1; + let hasFolder = 1; + let extraClassDefinition = [{ ParseResult $cppClass::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); @@ -3157,6 +3181,8 @@ def TFL_SelectOp : TFL_Op<"select", [ let results = (outs TFL_TensorOf<[F32, I1, I8, I16, I32, I64, UI32, QI8, QUI8, QI16, TFL_Quint8]>:$output); + let hasFolder = 1; + // TODO(jpienaar): autogenerate this. let builders = [ OpBuilder<(ins "Value":$condition, "Value":$x, "Value":$y), @@ -4080,6 +4106,7 @@ def TFL_BitcastOp : TFL_Op<"bitcast", [Pure]> { } def TFL_BitwiseXorOp : TFL_Op<"bitwise_xor", [ + ResultsBroadcastableShape, Commutative, SameOperandsAndResultElementType, Pure]> { @@ -4097,6 +4124,8 @@ def TFL_BitwiseXorOp : TFL_Op<"bitwise_xor", [ let results = (outs TFL_TensorOf<[I8, UI8, I16, UI16, I32, UI32]>:$output ); + + let hasFolder = 1; } def TFL_RightShiftOp : TFL_Op<"right_shift", [ @@ -4121,6 +4150,7 @@ def TFL_RightShiftOp : TFL_Op<"right_shift", [ //===----------------------------------------------------------------------===// // Quantization ops. //===----------------------------------------------------------------------===// + def TFL_DequantizeOp: TFL_Op<"dequantize", [NoMemoryEffect]> { let summary = "Dequantize operator"; diff --git a/tensorflow/compiler/mlir/lite/kernels/BUILD b/tensorflow/compiler/mlir/lite/kernels/BUILD index ccc0433bc66c76..8e6046fd2b4199 100644 --- a/tensorflow/compiler/mlir/lite/kernels/BUILD +++ b/tensorflow/compiler/mlir/lite/kernels/BUILD @@ -19,7 +19,7 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow/utils:__pkg__", ], deps = [ - "//tensorflow/compiler/mlir/lite/core/c:common", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", ], ) # LINT.ThenChange(//tensorflow/lite/kernels/BUILD) diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/BUILD b/tensorflow/compiler/mlir/lite/kernels/internal/BUILD index b47bebfe991dad..74910218b1d128 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/BUILD +++ b/tensorflow/compiler/mlir/lite/kernels/internal/BUILD @@ -22,6 +22,7 @@ cc_library( name = "runtime_shape", srcs = ["runtime_shape.cc"], hdrs = ["runtime_shape.h"], + compatible_with = get_compatible_with_portable(), deps = [":compatibility_macros"], ) diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h b/tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h index d107e55b1b3db6..b38391c39793ec 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h +++ b/tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -15,7 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ +#include #include +#include namespace tflite_migration { @@ -94,8 +96,71 @@ void PreprocessSoftmaxScaling(double beta, double input_scale, int CalculateInputRadius(int input_integer_bits, int input_left_shift, int total_signed_bits = 31); -} // namespace tflite_migration - +// Converts a floating-point number to an integer. For all inputs x where +// static_cast(x) is legal according to the C++ standard, the result +// is identical to that cast (i.e. the result is x with its fractional part +// truncated whenever that is representable as IntOut). +// +// static_cast would cause undefined behavior for the following cases, which +// have well-defined behavior for this function: +// +// 1. If x is NaN, the result is zero. +// +// 2. If the truncated form of x is above the representable range of IntOut, +// the result is std::numeric_limits::max(). +// +// 3. If the truncated form of x is below the representable range of IntOut, +// the result is std::numeric_limits::min(). +// +// Note that cases #2 and #3 cover infinities as well as finite numbers. +// +// The range of FloatIn must include the range of IntOut, otherwise +// the results are undefined. +// TODO(sfeuz): Replace by absl::SafeCast once available. +template +IntOut SafeCast(FloatIn x) { + static_assert(!std::numeric_limits::is_integer, + "FloatIn is integer"); + static_assert(std::numeric_limits::is_integer, + "IntOut is not integer"); + static_assert(std::numeric_limits::radix == 2, "IntOut is base 2"); + + // Special case NaN, for which the logic below doesn't work. + if (std::isnan(x)) { + return 0; + } + + // Negative values all clip to zero for unsigned results. + if (!std::numeric_limits::is_signed && x < 0) { + return 0; + } + + // Handle infinities. + if (std::isinf(x)) { + return x < 0 ? std::numeric_limits::min() + : std::numeric_limits::max(); + } + + // Set exp such that x == f * 2^exp for some f with |f| in [0.5, 1.0), + // unless x is zero in which case exp == 0. Note that this implies that the + // magnitude of x is strictly less than 2^exp. + int exp = 0; + std::frexp(x, &exp); + + // Let N be the number of non-sign bits in the representation of IntOut. If + // the magnitude of x is strictly less than 2^N, the truncated version of x + // is representable as IntOut. The only representable integer for which this + // is not the case is kMin for signed types (i.e. -2^N), but that is covered + // by the fall-through below. + if (exp <= std::numeric_limits::digits) { + return x; + } + + // Handle numbers with magnitude >= 2^N. + return x < 0 ? std::numeric_limits::min() + : std::numeric_limits::max(); +} // LINT.ThenChange(//tensorflow/lite/kernels/internal/quantization_util.h) +} // namespace tflite_migration #endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_QUANTIZATION_UTIL_H_ diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/BUILD b/tensorflow/compiler/mlir/lite/kernels/internal/utils/BUILD new file mode 100644 index 00000000000000..29677fd91a5eea --- /dev/null +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/BUILD @@ -0,0 +1,26 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/compiler/mlir/lite:build_def.bzl", "tflite_copts") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +# LINT.IfChange(sparsity_format_converter) + +cc_library( + name = "sparsity_format_converter", + srcs = ["sparsity_format_converter.cc"], + hdrs = ["sparsity_format_converter.h"], + compatible_with = get_compatible_with_portable(), + copts = tflite_copts(), + deps = [ + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + "@eigen_archive//:eigen3", + ], +) + +# LINT.ThenChange(//tensorflow/lite/kernels/internal/utils/BUILD) diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc new file mode 100644 index 00000000000000..4a28c1474e9be8 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc @@ -0,0 +1,205 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h" + +#include +#include +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/compiler/mlir/lite/core/c/dimension_type.h" + +namespace tflite_migration { +namespace internal { +namespace sparsity { + +// LINT.IfChange + +template +FormatConverter::FormatConverter( + const std::vector& shape, const std::vector& traversal_order, + const std::vector& format, + const std::vector& block_size, const std::vector& block_map) + : dense_shape_(shape), + traversal_order_(traversal_order), + block_size_(block_size), + block_map_(block_map) { + dense_size_ = 1; + int block_dim = 0; + blocked_shape_.resize(shape.size()); + format_.resize(shape.size() + block_map.size()); + for (int i = 0; i < shape.size(); i++) { + format_[i] = format[traversal_order[i]]; + dense_size_ *= shape[i]; + if (block_dim < block_map.size() && block_map[block_dim] == i) { + blocked_shape_[i] = shape[i] / block_size[block_dim]; + block_dim++; + } else { + blocked_shape_[i] = shape[i]; + } + } + + // Only dense blocks are supported. + for (int i = 0; i < block_map.size(); i++) { + format_[i + shape.size()] = kTfLiteDimDense; + } +} + +template +void FormatConverter::DenseToSparse(const T* src_data) { + int num_original_dims = dense_shape_.size(); + int num_block_dims = block_map_.size(); + int num_expanded_dims = num_original_dims + num_block_dims; + std::vector expanded_shape(num_expanded_dims); + for (int i = 0; i < num_expanded_dims; i++) { + if (i < num_original_dims) { + expanded_shape[i] = blocked_shape_[i]; + } else { + expanded_shape[i] = block_size_[i - num_original_dims]; + } + } + + std::vector shape_offset(num_original_dims); + shape_offset[shape_offset.size() - 1] = 1; + for (int i = num_original_dims - 1; i > 0; --i) { + shape_offset[i - 1] = shape_offset[i] * dense_shape_[i]; + } + + std::vector expanded_shape_offset(num_expanded_dims); + for (int i = 0; i < num_original_dims; ++i) { + expanded_shape_offset[i] = shape_offset[i]; + } + for (int i = 0; i < num_block_dims; ++i) { + int mapped_dim = block_map_[i]; + expanded_shape_offset[num_original_dims + i] = shape_offset[mapped_dim]; + expanded_shape_offset[mapped_dim] *= block_size_[i]; + } + + std::vector dst_ordered_offset(num_expanded_dims); + for (int i = 0; i < num_expanded_dims; ++i) { + dst_ordered_offset[i] = expanded_shape_offset[traversal_order_[i]]; + } + + std::vector dst_dim_has_nonzeroes(num_expanded_dims); + std::fill(dst_dim_has_nonzeroes.begin(), dst_dim_has_nonzeroes.end(), false); + std::vector inner_compressed_dim(num_expanded_dims); + int most_recent_compressed_dim = -1; + std::vector num_segments_of_next_compressed_dim(num_expanded_dims); + int segment_count = 1; + for (int i = num_expanded_dims - 1; i >= 0; --i) { + inner_compressed_dim[i] = most_recent_compressed_dim; + if (format_[i] == kTfLiteDimSparseCSR) { + most_recent_compressed_dim = i; + num_segments_of_next_compressed_dim[i] = segment_count; + segment_count = 1; + } else { + num_segments_of_next_compressed_dim[i] = -1; + segment_count *= expanded_shape[traversal_order_[i]]; + } + } + + dim_metadata_.resize(num_expanded_dims * 2); + std::vector dst_sparse_dims; + dst_sparse_dims.reserve(num_expanded_dims); + for (int i = 0; i < num_expanded_dims; ++i) { + dim_metadata_[i * 2].clear(); + dim_metadata_[i * 2 + 1].clear(); + if (format_[i] == kTfLiteDimDense) { + // If dimension is dense, just store the shape. + dim_metadata_[i * 2].push_back(expanded_shape[traversal_order_[i]]); + } else { + dim_metadata_[i * 2].push_back(0); // Segment array always begins with 0. + dst_sparse_dims.push_back(i); // Add dimension to the sparse list. + } + } + + // This algorithm assumes that the block size is small enough for all the + // elements to fit in cache, so the strided accesses from different traversal + // order and the write-first-erase-later strategy shouldn't be too slow + int dst_dim_idx = num_expanded_dims; + std::vector coordinate(num_expanded_dims, 0); + int dense_tensor_idx = 0; + while (dst_dim_idx >= 0) { + if (dst_dim_idx == num_expanded_dims) { + // We have a complete coordinate. Add the element to the value array if it + // is not zero, or if the last dimension is dense. + if (!IsZero(src_data[dense_tensor_idx])) { + data_.push_back(src_data[dense_tensor_idx]); + // Mark all sparse dimensions that their current indices have nonzeroes. + for (auto dst_dim : dst_sparse_dims) { + if (!dst_dim_has_nonzeroes[dst_dim]) { + // Only add the index to the indices array if the current nonzero + // is the first nonzero of the block. + dim_metadata_[2 * dst_dim + 1].push_back(coordinate[dst_dim]); + dst_dim_has_nonzeroes[dst_dim] = true; + } + } + } else if (format_[num_expanded_dims - 1] == kTfLiteDimDense) { + data_.push_back(src_data[dense_tensor_idx]); + } + --dst_dim_idx; + } else { + int original_dim_idx = traversal_order_[dst_dim_idx]; + int dim_size = expanded_shape[original_dim_idx]; + if (dst_dim_has_nonzeroes[dst_dim_idx]) { + // If the previous block has nonzeroes, reset the flag to false since + // we have just moved to a new block. + dst_dim_has_nonzeroes[dst_dim_idx] = false; + } else if (format_[dst_dim_idx] == kTfLiteDimSparseCSR) { + // This block is empty. Delete unnecessary values if compressed. + int next_compressed_dim = inner_compressed_dim[dst_dim_idx]; + int erase_offset = dim_metadata_[2 * dst_dim_idx + 1].size() * + num_segments_of_next_compressed_dim[dst_dim_idx]; + if (next_compressed_dim >= 0) { + auto& segments = dim_metadata_[2 * inner_compressed_dim[dst_dim_idx]]; + segments.erase(segments.begin() + 1 + erase_offset, segments.end()); + } else { + data_.erase(data_.begin() + erase_offset, data_.end()); + } + } + if (++coordinate[dst_dim_idx] < dim_size) { + // The current dst_dim_idx is valid (not out of bound). + dense_tensor_idx += dst_ordered_offset[dst_dim_idx]; + ++dst_dim_idx; + } else { + // dst_dim_idx has reached its dim size. Update segment array and go + // back to incrementing the previous dimension (dst_dim_idx - 1). + if (format_[dst_dim_idx] == kTfLiteDimSparseCSR) { + dim_metadata_[2 * dst_dim_idx].push_back( + dim_metadata_[2 * dst_dim_idx + 1].size()); + } + coordinate[dst_dim_idx] = -1; + dense_tensor_idx -= dst_ordered_offset[dst_dim_idx] * dim_size; + --dst_dim_idx; + } + } + } +} + +template +bool FormatConverter::IsZero(const T val) { + return (val == static_cast(0)); +} + +template class FormatConverter; +template class FormatConverter; +template class FormatConverter; + +// LINT.ThenChange(//tensorflow/lite/kernels/internal/utils/sparsity_format_converter.cc) + +} // namespace sparsity +} // namespace internal +} // namespace tflite_migration diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h new file mode 100644 index 00000000000000..12b54502b46369 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h @@ -0,0 +1,102 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_UTILS_SPARSITY_FORMAT_CONVERTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_UTILS_SPARSITY_FORMAT_CONVERTER_H_ + +#include + +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/compiler/mlir/lite/core/c/dimension_type.h" + +namespace tflite_migration { +namespace internal { +namespace sparsity { + +// LINT.IfChange + +// A converter that keeps an internal representation of sparse tensor parameters +// and converts tensors between dense and sparse formats. +template +class FormatConverter { + public: + /* + * Creates a dense to sparse converter. + * @param shape Shape of the dense tensor. + * @param traversal_order In what order to traverse all dimensions, + * including block dimensions. + * @param format Whether each dimension in the dense tensor is + * dense or sparse (not in the traversal order). + * @param block_size Size of each block dimension. + * @param block_map Map from block dimension to original tensor + * dimension. + */ + FormatConverter(const std::vector& shape, + const std::vector& traversal_order, + const std::vector& format, + const std::vector& block_size = {}, + const std::vector& block_map = {}); + + const std::vector& GetData() { return data_; } + + const std::vector>& GetDimMetadata() { + return dim_metadata_; + } + + // Method for dense to sparse conversion. Need to call GetData() method to get + // the compressed data. + + void DenseToSparse(const T* src_data); + + // Check if val is equal to zero. + bool IsZero(const T val); + + // Shape of the conceptual dense tensor. + std::vector dense_shape_; + // Shape of the dense tensor with inner blocks reduced. For example, a (4, 4) + // tensor with (2, 2) block has blocked_shape (2, 2). + std::vector blocked_shape_; + // Total number of elements in the dense tensor. + size_t dense_size_; + // Has n(original dimension)+k(block_dimension) elements. + std::vector traversal_order_; + // Format of each dimension in the traversal order. + std::vector format_; + // Size of each block dimension, in the same order as block map. + std::vector block_size_; + // Map from block dimension to the original tensor dimension. + std::vector block_map_; + // Metadata of each dimension in the traversal order. + // Each dimension needs two vectors. For dense dimensions, the first vector + // stores the size of that dimension, and the second vector is empty. For + // sparse dimensions, the first vector stores the segments and the second one + // stores the indices. + std::vector> dim_metadata_; + // Actual buffer holding data after conversion. Could be sparse buffer or + // dense buffer. + std::vector data_; +}; + +extern template class FormatConverter; +extern template class FormatConverter; +extern template class FormatConverter; +extern template class FormatConverter; + +// LINT.ThenChange(//tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h) + +} // namespace sparsity +} // namespace internal +} // namespace tflite_migration + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_KERNELS_INTERNAL_UTILS_SPARSITY_FORMAT_CONVERTER_H_ diff --git a/tensorflow/lite/mmap_allocation.cc b/tensorflow/compiler/mlir/lite/mmap_allocation.cc similarity index 97% rename from tensorflow/lite/mmap_allocation.cc rename to tensorflow/compiler/mlir/lite/mmap_allocation.cc index 3d1a7f03e713e1..eb106899228fba 100644 --- a/tensorflow/lite/mmap_allocation.cc +++ b/tensorflow/compiler/mlir/lite/mmap_allocation.cc @@ -21,8 +21,8 @@ limitations under the License. #include -#include "tensorflow/lite/allocation.h" -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace tflite { namespace { diff --git a/tensorflow/lite/mmap_allocation_disabled.cc b/tensorflow/compiler/mlir/lite/mmap_allocation_disabled.cc similarity index 96% rename from tensorflow/lite/mmap_allocation_disabled.cc rename to tensorflow/compiler/mlir/lite/mmap_allocation_disabled.cc index 95c34446797d7c..4e89594285473a 100644 --- a/tensorflow/lite/mmap_allocation_disabled.cc +++ b/tensorflow/compiler/mlir/lite/mmap_allocation_disabled.cc @@ -15,7 +15,7 @@ limitations under the License. #include -#include "tensorflow/lite/allocation.h" +#include "tensorflow/compiler/mlir/lite/allocation.h" namespace tflite { diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 6d2de495e6673d..299bb9e2f2bc06 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -195,8 +195,10 @@ cc_library( ":saved_model_to_tfl_flatbuffer", "//tensorflow/c:kernels", "//tensorflow/c:tf_status_headers", + "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/debug:debug_options_proto_cc", "//tensorflow/compiler/mlir/lite/metrics:error_collector", + "//tensorflow/compiler/mlir/lite/python/interpreter_wrapper:python_error_reporter", "//tensorflow/compiler/mlir/lite/python/interpreter_wrapper:python_utils", "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", @@ -205,8 +207,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite:model_builder", - "//tensorflow/lite/python/interpreter_wrapper:python_error_reporter", "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:model_flags_proto_cc", "//tensorflow/lite/toco:toco_convert", diff --git a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc index 31ec151442dfae..881c30019b903a 100644 --- a/tensorflow/compiler/mlir/lite/python/converter_python_api.cc +++ b/tensorflow/compiler/mlir/lite/python/converter_python_api.cc @@ -31,10 +31,12 @@ limitations under the License. #include "google/protobuf/text_format.h" #include "tensorflow/c/kernels.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" #include "tensorflow/compiler/mlir/lite/debug/debug_options.pb.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector.h" #include "tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.h" #include "tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.h" +#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_utils.h" #include "tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.h" @@ -42,12 +44,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/sparsity/sparsify_model.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/lite/model_builder.h" -#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h" #include "tensorflow/lite/toco/logging/conversion_log_util.h" #include "tensorflow/lite/toco/logging/toco_conversion_log.pb.h" #include "tensorflow/lite/toco/model.h" @@ -56,7 +58,6 @@ limitations under the License. #include "tensorflow/lite/toco/toco_flags.pb.h" #include "tensorflow/lite/toco/toco_graphviz_dump_options.h" #include "tensorflow/lite/toco/toco_tooling.h" -#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/lite/toco/types.pb.h" @@ -309,7 +310,7 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, bool enable_variable_quantization, bool disable_per_channel_for_dense_layers, PyObject* debug_options_proto_txt_raw) { - using tflite::interpreter_wrapper::PythonErrorReporter; + using tflite_migration::interpreter_wrapper::PythonErrorReporter; char* buf = nullptr; Py_ssize_t length; std::unique_ptr error_reporter(new PythonErrorReporter); @@ -362,9 +363,9 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, return nullptr; } - std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buf, length, - error_reporter.get()); + std::unique_ptr model = + mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer( + buf, length, error_reporter.get()); if (!model) { PyErr_Format(PyExc_ValueError, "Invalid model"); return nullptr; @@ -399,7 +400,7 @@ PyObject* MlirQuantizeModel(PyObject* data, bool disable_per_channel, } PyObject* MlirSparsifyModel(PyObject* data) { - using tflite::interpreter_wrapper::PythonErrorReporter; + using tflite_migration::interpreter_wrapper::PythonErrorReporter; char* buf = nullptr; Py_ssize_t length; std::unique_ptr error_reporter(new PythonErrorReporter); @@ -408,9 +409,9 @@ PyObject* MlirSparsifyModel(PyObject* data) { PyErr_Format(PyExc_ValueError, "Failed to convert input PyObject"); return nullptr; } - std::unique_ptr model = - tflite::FlatBufferModel::BuildFromBuffer(buf, length, - error_reporter.get()); + std::unique_ptr model = + mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer( + buf, length, error_reporter.get()); if (!model) { PyErr_Format(PyExc_ValueError, "Invalid model"); return nullptr; diff --git a/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc b/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc index 6591251d9e915b..b880df7f74a3ca 100644 --- a/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc +++ b/tensorflow/compiler/mlir/lite/python/flatbuffer_to_mlir.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" @@ -30,7 +31,6 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD index 8d2cb7a65e4b8d..9268de7ec1de54 100644 --- a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD @@ -15,3 +15,14 @@ cc_library( "//third_party/python_runtime:headers", # buildcleaner: keep ], ) + +cc_library( + name = "python_error_reporter", + srcs = ["python_error_reporter.cc"], + hdrs = ["python_error_reporter.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite:stateful_error_reporter", + "//third_party/python_runtime:headers", # buildcleaner: keep + ], +) diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc new file mode 100644 index 00000000000000..75f9222d7c22d2 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h" + +#include +#include +#include + +namespace tflite_migration { +namespace interpreter_wrapper { + +// Report an error message +int PythonErrorReporter::Report(const char* format, va_list args) { + char buf[1024]; + int formatted = vsnprintf(buf, sizeof(buf), format, args); + buffer_ << buf; + return formatted; +} + +// Set's a Python runtime exception with the last error. +PyObject* PythonErrorReporter::exception() { + std::string last_message = message(); + PyErr_SetString(PyExc_RuntimeError, last_message.c_str()); + return nullptr; +} + +// Gets the last error message and clears the buffer. +std::string PythonErrorReporter::message() { + std::string value = buffer_.str(); + buffer_.clear(); + return value; +} +} // namespace interpreter_wrapper +} // namespace tflite_migration diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h new file mode 100644 index 00000000000000..f98a35227388bb --- /dev/null +++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ + +#include + +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/stateful_error_reporter.h" + +namespace tflite_migration { +namespace interpreter_wrapper { + +class PythonErrorReporter : public tflite_migration::StatefulErrorReporter { + public: + PythonErrorReporter() = default; + + // Report an error message + int Report(const char* format, va_list args) override; + + // Sets a Python runtime exception with the last error and + // clears the error message buffer. + PyObject* exception(); + + // Gets the last error message and clears the buffer. + std::string message() override; + + private: + std::stringstream buffer_; +}; + +} // namespace interpreter_wrapper +} // namespace tflite_migration +#endif // TENSORFLOW_COMPILER_MLIR_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_ERROR_REPORTER_H_ diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 7a4567d7bd1c93..0aaedeae200a6e 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -211,6 +211,7 @@ Status ConvertSavedModelToTFLiteFlatBuffer( pass_config.enable_stablehlo_quantizer = toco_flags.has_quantization_config(); pass_config.enable_composite_direct_lowering = toco_flags.enable_composite_direct_lowering(); + pass_config.model_origin_framework = toco_flags.model_origin_framework(); if (toco_flags.qdq_conversion_mode() == "STATIC") { pass_config.quant_specs.qdq_conversion_mode = diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index d6f999eabb3c16..be09317cb52310 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index e57c2b30808d82..c269b41b596ab5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -168,16 +168,16 @@ tf_cc_test( deps = [ ":quantize_model", ":test_util", + "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/lite:framework", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -205,11 +205,11 @@ tf_cc_test( deps = [ ":quantize_weights", ":test_util", + "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", - "//tensorflow/lite:framework", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@flatbuffers", @@ -223,9 +223,7 @@ cc_library( srcs = ["test_util.cc"], hdrs = ["test_util.h"], deps = [ - "//tensorflow/lite:framework", - "//tensorflow/lite/core/api", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", "@com_google_googletest//:gtest", - "@flatbuffers", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc index 1e7cdcdea07d33..371f45210190a8 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model_test.cc @@ -30,15 +30,15 @@ limitations under the License. #include "absl/status/status.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/lite/model_builder.h" -#include "tsl/lib/core/status_test_util.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_model_test.cc @@ -50,6 +50,7 @@ namespace tflite { namespace optimize { namespace { +using mlir::TFL::FlatBufferModelAbslError; using testing::Eq; using testing::FloatEq; using testing::FloatNear; @@ -100,7 +101,7 @@ absl::Status QuantizeModel( return status; } - auto flatbuffer_model = FlatBufferModel::BuildFromBuffer( + auto flatbuffer_model = FlatBufferModelAbslError::BuildFromBuffer( output_buffer.data(), output_buffer.size()); *model = UnPackFlatBufferModel(*flatbuffer_model->GetModel()); return absl::OkStatus(); @@ -157,9 +158,10 @@ absl::Status QuantizeModelAllOperators( disable_per_channel_for_dense_layers); } -std::unique_ptr ReadModel(const std::string& model_name) { +std::unique_ptr ReadModel( + const std::string& model_name) { auto model_path = tensorflow::io::JoinPath(*g_test_model_dir, model_name); - return FlatBufferModel::BuildFromFile(model_path.c_str()); + return FlatBufferModelAbslError::BuildFromFile(model_path.c_str()); } template @@ -198,7 +200,7 @@ class QuantizeModelTest : public testing::Test { model_ = UnPackFlatBufferModel(*readonly_model_); } - std::unique_ptr input_model_; + std::unique_ptr input_model_; const Model* readonly_model_; tflite::ModelT model_; std::string output_buffer_; // Raw buffer for quantized output model. diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc index 7a42e74c2619af..db124c8dd19813 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights_test.cc @@ -27,6 +27,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/command_line_flags.h" -#include "tensorflow/lite/model_builder.h" #include "tsl/platform/logging.h" // Note: branched from tensorflow/lite/tools/optimize/quantize_weights_test.cc @@ -50,6 +50,7 @@ namespace { using mlir::lite::BufferType; using mlir::lite::CustomOpMap; using mlir::lite::QuantizeWeights; +using mlir::TFL::FlatBufferModelAbslError; constexpr bool kUseUpdatedHybridSchemeDefault = true; std::unique_ptr CreateMutableModelFromFile(const Model* input_model) { @@ -58,28 +59,28 @@ std::unique_ptr CreateMutableModelFromFile(const Model* input_model) { return copied_model; } -std::unique_ptr ReadTestModel() { +std::unique_ptr ReadTestModel() { auto model_path = tensorflow::io::JoinPath( *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); - return FlatBufferModel::BuildFromFile(model_path.c_str()); + return FlatBufferModelAbslError::BuildFromFile(model_path.c_str()); } -std::unique_ptr ReadSharedWeightsTestModel() { +std::unique_ptr ReadSharedWeightsTestModel() { auto model_path = tensorflow::io::JoinPath( *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); - return FlatBufferModel::BuildFromFile(model_path.c_str()); + return FlatBufferModelAbslError::BuildFromFile(model_path.c_str()); } -std::unique_ptr ReadGatherTestModel() { +std::unique_ptr ReadGatherTestModel() { auto model_path = tensorflow::io::JoinPath( *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); - return FlatBufferModel::BuildFromFile(model_path.c_str()); + return FlatBufferModelAbslError::BuildFromFile(model_path.c_str()); } -std::unique_ptr ReadCustomOpTestModel() { +std::unique_ptr ReadCustomOpTestModel() { auto model_path = tensorflow::io::JoinPath( *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); - return FlatBufferModel::BuildFromFile(model_path.c_str()); + return FlatBufferModelAbslError::BuildFromFile(model_path.c_str()); } template @@ -111,7 +112,7 @@ class QuantizeWeightsTest : public testing::Test { model_ = input_model_->GetModel(); } - std::unique_ptr input_model_; + std::unique_ptr input_model_; const Model* model_; bool IsModelInputOrOutput(const Model* model, uint32_t tensor_idx) { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc index e096868eec8807..66c1adef98bc22 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" +#include +#include + #include namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h index b4e317c131888e..8953a384766963 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h +++ b/tensorflow/compiler/mlir/lite/quantization/lite/test_util.h @@ -15,7 +15,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TEST_UTIL_H_ -#include "tensorflow/lite/core/api/error_reporter.h" +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace mlir { namespace lite { diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD index e664137c2c136b..a275f4ab2fbd66 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD @@ -1,10 +1,11 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//visibility:public", + "//visibility:private", ], licenses = ["notice"], ) @@ -13,6 +14,9 @@ cc_library( name = "portable_tensor_utils", srcs = ["portable_tensor_utils.cc"], hdrs = ["portable_tensor_utils.h"], + visibility = [ + "//tensorflow/compiler/mlir/quantization/common/quantization_lib:__pkg__", + ], ) cc_library( @@ -36,3 +40,103 @@ tf_cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "quantization_utils", + srcs = ["quantization_utils.cc"], + hdrs = ["quantization_utils.h"], + deps = [ + ":model_utils", + ":portable_tensor_utils", + "//tensorflow/compiler/mlir/lite/kernels/internal:runtime_shape", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@eigen_archive//:eigen3", + ], +) + +tf_cc_test( + name = "quantization_utils_test", + srcs = ["quantization_utils_test.cc"], + args = [ + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", + ], + data = [ + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + ], + deps = [ + ":quantization_utils", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", + "//tensorflow/lite/core:framework", # to remove when mlir version is ready. + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:platform_port", + "@local_xla//xla/tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/util:command_line_flags", + ], +) + +cc_library( + name = "quantize_weights", + srcs = select({ + "//tensorflow:ios": ["quantize_weights_portable.cc"], + "//tensorflow:android": ["quantize_weights_portable.cc"], + "//conditions:default": ["quantize_weights.cc"], + }), + hdrs = ["quantize_weights.h"], + compatible_with = get_compatible_with_portable(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:model_utils", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:portable_tensor_utils", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantization_utils", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", + "//tensorflow/core/platform:logging", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@flatbuffers//:runtime_cc", + ] + select({ + "//tensorflow:android": [], + "//tensorflow:ios": [], + "//conditions:default": [ + "//tensorflow/compiler/mlir/lite/quantization/lite:quantize_weights", + ], + }), +) + +tf_cc_test( + name = "quantize_weights_test", + srcs = ["quantize_weights_test.cc"], + args = [ + "--test_model_file=$(location //tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin)", + ], + data = [ + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/custom_op.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/quantized_with_gather.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/single_conv_weights_min_0_max_plus_10.bin", + "//tensorflow/compiler/mlir/lite/quantization/lite:testdata/weight_shared_between_convs.bin", + ], + tags = [ + "tflite_not_portable_android", + "tflite_not_portable_ios", + ], + deps = [ + ":quantize_weights", + "//tensorflow/compiler/mlir/lite/quantization/lite:test_util", + "//tensorflow/compiler/mlir/lite/schema:schema_fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", + "//tensorflow/core:framework_internal", + "//tensorflow/lite/core:framework", # to remove when mlir version is ready. + "@com_google_googletest//:gtest", + "@flatbuffers", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:platform_port", + "@local_xla//xla/tsl/util:command_line_flags", + ], +) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc new file mode 100644 index 00000000000000..10b25368e37dda --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc @@ -0,0 +1,371 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is the MLIR copy of part of +// third_party/tensorflow/lite/tools/optimize/quantization_utils.cc as part of +// the effort to decouple TFLite from MLIR. + +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "Eigen/Core" // from @eigen_archive +#include "tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { + +namespace { + +// LINT.IfChange(QuantizationUtilsConstants) +const int8_t kMinQuantizedValue8bit = -127; +const int8_t kMaxQuantizedValue8bit = 127; + +const int8_t kMinQuantizedValue4bit = -7; +const int8_t kMaxQuantizedValue4bit = 7; + +// The maximum number of dimensions supported in per-channel quantization. +constexpr int kPerChannelMaxDim = 4; +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:QuantizationUtilsConstants) +} // namespace + +using absl::InternalError; +using mlir::RuntimeShape; +using tflite::BufferT; +using tflite::QuantizationParametersT; +using tflite::TensorT; +using tflite::TensorType; +using tflite::TensorType_INT8; + +// LINT.IfChange(NumElements) +absl::Status NumElements(const TensorT& tensor, uint64_t* num_elements) { + *num_elements = 1; + for (const int64_t dim : tensor.shape) { + if (dim <= 0 || *num_elements > UINT64_MAX / static_cast(dim)) { + return InternalError("Invalid tensor shape."); + } + *num_elements *= dim; + } + return absl::OkStatus(); +} +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:NumElements) + +// LINT.IfChange(FillPerChannelMinMax) +absl::Status FillPerChannelMinMax( + const float* const input, const std::vector& dimension, + int32_t channel_dim_index, QuantizationParametersT* quantization_params) { + if (!quantization_params->min.empty() || !quantization_params->max.empty()) { + return absl::InvalidArgumentError( + "Min or max already present in tensor quantization params."); + } + + if (dimension.size() > kPerChannelMaxDim) { + return absl::InvalidArgumentError(absl::StrCat( + "Expected tensor with less than %d dimensions, but got %d.", + kPerChannelMaxDim + 1, dimension.size())); + } + if (channel_dim_index >= dimension.size()) { + return absl::InvalidArgumentError(absl::StrCat( + "Expected channel_dim_index to be less than %d, but got %d.", + dimension.size(), channel_dim_index)); + } + + const int32_t channel_dim_size = dimension[channel_dim_index]; + quantization_params->quantized_dimension = channel_dim_index; + quantization_params->min = std::vector(channel_dim_size); + quantization_params->max = std::vector(channel_dim_size); + std::vector has_min_max_value(channel_dim_size, false); + int indices[kPerChannelMaxDim]; + RuntimeShape unextended_tensor_dims(dimension.size(), dimension.data()); + RuntimeShape tensor_dims = + RuntimeShape::ExtendedShape(kPerChannelMaxDim, unextended_tensor_dims); + channel_dim_index += + kPerChannelMaxDim - unextended_tensor_dims.DimensionsCount(); + + // Compute min max ranges per channel + for (indices[0] = 0; indices[0] < tensor_dims.Dims(0); indices[0]++) { + for (indices[1] = 0; indices[1] < tensor_dims.Dims(1); indices[1]++) { + for (indices[2] = 0; indices[2] < tensor_dims.Dims(2); indices[2]++) { + for (indices[3] = 0; indices[3] < tensor_dims.Dims(3); indices[3]++) { + int channel_idx = indices[channel_dim_index]; + const float val = input[Offset(tensor_dims, indices)]; + if (has_min_max_value[channel_idx]) { + if (quantization_params->min[channel_idx] > val) { + quantization_params->min[channel_idx] = val; + } else if (quantization_params->max[channel_idx] < val) { + quantization_params->max[channel_idx] = val; + } + } else { + quantization_params->min[channel_idx] = val; + quantization_params->max[channel_idx] = val; + has_min_max_value[channel_idx] = true; + } + } + } + } + } + return absl::OkStatus(); +} +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:FillPerChannelMinMax) + +// LINT.IfChange(SymmetricPerChannelQuantization) +// Per-channel quantize a tensor at the given index and fills both scales and +// quantized values. +absl::Status SymmetricPerChannelQuantization( + TensorT* tensor, const float* const input, int32_t channel_dim_index, + std::vector* output_scales, std::vector* output_value) { + if (tensor == nullptr) { + return absl::InvalidArgumentError("Cannot quantize. Tensor is null."); + } + const int32_t channel_dim_size = tensor->shape[channel_dim_index]; + // Fill per channel max and min values if needed + if (tensor->quantization == nullptr) { + tensor->quantization = std::make_unique(); + } + if (!HasMinMax(tensor)) { + absl::Status status = FillPerChannelMinMax( + input, tensor->shape, channel_dim_index, tensor->quantization.get()); + if (!status.ok()) { + return status; + } + } + + // Calculate scales per channel using max and min values from tensor. + std::vector scale_invs(channel_dim_size); + const float half_scale = kMaxQuantizedValue8bit; + for (int channel_idx = 0; channel_idx < channel_dim_size; channel_idx++) { + const float half_range = + std::max(std::abs(tensor->quantization->min[channel_idx]), + std::abs(tensor->quantization->max[channel_idx])); + output_scales->at(channel_idx) = half_range / half_scale; + if (half_range == 0) { + scale_invs[channel_idx] = 0; + } else { + scale_invs[channel_idx] = half_scale / half_range; + } + } + + // Quantize the input values. + SymmetricPerChannelQuantizeValues(input, scale_invs, tensor->shape, + channel_dim_index, output_value); + return absl::OkStatus(); +} +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:SymmetricPerChannelQuantization) + +// LINT.IfChange(SymmetricPerChannelQuantizeValues) +void SymmetricPerChannelQuantizeValues(const float* const input, + const std::vector& scales_inv, + const std::vector& dimension, + int32_t channel_dim_index, + std::vector* output_value) { + // Quantize the values. + int indices[kPerChannelMaxDim]; + RuntimeShape unextended_tensor_dims(dimension.size(), dimension.data()); + RuntimeShape tensor_dims = + RuntimeShape::ExtendedShape(kPerChannelMaxDim, unextended_tensor_dims); + channel_dim_index += + kPerChannelMaxDim - unextended_tensor_dims.DimensionsCount(); + for (indices[0] = 0; indices[0] < tensor_dims.Dims(0); indices[0]++) { + for (indices[1] = 0; indices[1] < tensor_dims.Dims(1); indices[1]++) { + for (indices[2] = 0; indices[2] < tensor_dims.Dims(2); indices[2]++) { + for (indices[3] = 0; indices[3] < tensor_dims.Dims(3); indices[3]++) { + int channel_idx = indices[channel_dim_index]; + int index = Offset(tensor_dims, indices); + const float val = input[index]; + const int32_t quantized_value = + static_cast(round(val * scales_inv[channel_idx])); + output_value->at(index) = std::min( + kMaxQuantizedValue8bit, + std::max(kMinQuantizedValue8bit, quantized_value)); + } + } + } + } +} +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:SymmetricPerChannelQuantizeValues) + +// LINT.IfChange(SymmetricQuantizeTensor) +absl::Status SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) { + if (model == nullptr || tensor == nullptr) { + return absl::InvalidArgumentError("No tensor to quantize."); + } + + BufferT* buffer = model->buffers[tensor->buffer].get(); + if (buffer == nullptr) { + return absl::InvalidArgumentError("Missing buffer."); + } + const float* float_data = reinterpret_cast(buffer->data.data()); + uint64_t num_elements; + absl::Status status = NumElements(*tensor, &num_elements); + if (!status.ok()) { + return status; + } + + std::vector quantized_buffer; + quantized_buffer.resize(num_elements); + + float min_value, max_value, scaling_factor; + mlir::lite::toco_legacy::PortableSymmetricQuantizeFloats( + float_data, num_elements, quantized_buffer.data(), &min_value, &max_value, + &scaling_factor); + + if (tensor->quantization == nullptr) { + tensor->quantization = std::make_unique(); + } + tensor->quantization->scale = std::vector(1, scaling_factor); + tensor->quantization->zero_point = std::vector(1, 0); + + uint8_t* uint8_buffer = reinterpret_cast(quantized_buffer.data()); + model->buffers[tensor->buffer]->data.assign(uint8_buffer, + uint8_buffer + num_elements); + + // Update the tensor type. + tensor->type = TensorType_INT8; + + return absl::OkStatus(); +} +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:SymmetricQuantizeTensor) + +// LINT.IfChange(QuantizeTensorFloat16) +absl::Status QuantizeTensorFloat16(ModelT* model, TensorT* tensor) { + if (model == nullptr || tensor == nullptr) { + return absl::InvalidArgumentError("No tensor to quantize."); + } + + BufferT* buffer = model->buffers[tensor->buffer].get(); + if (buffer == nullptr) { + return absl::InvalidArgumentError("Missing buffer."); + } + + uint64_t num_elements; + absl::Status status = NumElements(*tensor, &num_elements); + if (!status.ok()) { + return status; + } + + // Copy single byte buffer data to float vector to guard against misalignment. + std::vector float_vector(num_elements); + uint8_t* first = buffer->data.data(); + std::copy(first, first + buffer->data.size(), + reinterpret_cast(float_vector.data())); + + // Transform float data to float16. + std::vector quantized_buffer; + quantized_buffer.resize(num_elements); + constexpr float kMaxFloat16Value = 65504.f; + constexpr float kMinFloat16Value = -65504.f; + std::transform(float_vector.begin(), float_vector.end(), + quantized_buffer.begin(), [=](float a) { + float clamped = std::min(std::max(a, kMinFloat16Value), + kMaxFloat16Value); + return static_cast(clamped); + }); + + char* half_buffer = reinterpret_cast(quantized_buffer.data()); + model->buffers[tensor->buffer]->data.assign( + half_buffer, half_buffer + sizeof(Eigen::half) * num_elements); + + // Update the tensor type. + tensor->type = tflite::TensorType_FLOAT16; + + return absl::OkStatus(); +} +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:QuantizeTensorFloat16) + +// LINT.IfChange(AddQuantizationParams) +absl::Status AddQuantizationParams(const std::vector& scales, + const std::vector& zero_point, + int quantized_dimension, + const uint8_t* buffer_data, + size_t buffer_size, TensorType output_type, + ModelT* model, TensorT* tensor) { + if (tensor->quantization == nullptr) { + tensor->quantization = std::make_unique(); + } + tensor->quantization->scale.assign(scales.begin(), scales.end()); + if (zero_point.size() != scales.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Received zero_point of size %d and scales of size %d. " + "These sizes should match.", + zero_point.size(), scales.size())); + } + tensor->quantization->zero_point.assign(zero_point.begin(), zero_point.end()); + tensor->quantization->quantized_dimension = quantized_dimension; + model->buffers[tensor->buffer]->data.assign(buffer_data, + buffer_data + buffer_size); + // Update the tensor type. + tensor->type = output_type; + return absl::OkStatus(); +} +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:AddQuantizationParams) + +// LINT.IfChange(SymmetricQuantizeTensorPerChannel) +absl::Status SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor, + int32_t channel_dim_index) { + if (tensor->shape.size() > kPerChannelMaxDim) { + return absl::InvalidArgumentError(absl::StrCat( + "SymmetricQuantizeTensorPerChannel requires tensor with less than %d " + "dimensions, but got %d dimension(s).", + kPerChannelMaxDim + 1, tensor->shape.size())); + } + + // Get dimensions. + uint64_t num_elements; + absl::Status status = NumElements(*tensor, &num_elements); + if (!status.ok()) { + return status; + } + const int32_t channel_dim_size = tensor->shape[channel_dim_index]; + + // Get input float data. + const BufferT* buffer = model->buffers[tensor->buffer].get(); + const float* float_input_data = + reinterpret_cast(buffer->data.data()); + + // Create container for output scale and output data. + std::vector scales(channel_dim_size); + std::vector final_buffer(num_elements); + + // Quantize the input data with respect to channel_dim_index. + status = SymmetricPerChannelQuantization( + tensor, float_input_data, channel_dim_index, &scales, &final_buffer); + if (!status.ok()) { + return status; + } + + // Set the buffers and output type. + uint8_t* uint8_buffer = reinterpret_cast(final_buffer.data()); + const size_t buffer_size = num_elements * sizeof(int8_t); + std::vector zero_point(scales.size(), 0); + return AddQuantizationParams(scales, zero_point, channel_dim_index, + uint8_buffer, buffer_size, TensorType_INT8, + model, tensor); +} +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.cc:SymmetricQuantizeTensorPerChannel) + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h new file mode 100644 index 00000000000000..bd68ed1ccb473d --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is the MLIR copy of part of +// third_party/tensorflow/lite/tools/optimize/quantization_utils.h as part of +// the effort to decouple TFLite from MLIR. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZATION_UTILS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { + +using tflite::ModelT; +using tflite::QuantizationParametersT; +using tflite::TensorT; +using tflite::TensorType; + +// LINT.IfChange(num_elements) +// Returns the number of elements in the given tensor. +absl::Status NumElements(const TensorT& tensor, uint64_t* num_elements); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:num_elements) + +// LINT.IfChange(fill_per_channel_min_max) +// Populates the max and min values for per channel quantization. +absl::Status FillPerChannelMinMax(const float* const input, + const std::vector& dimension, + int32_t channel_dim_index, + QuantizationParametersT* quantization_params); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:fill_per_channel_min_max) + +// LINT.IfChange(symmetric_per_channel_quantization) +// Per-channel quantize a tensor at the given index and returns both scales and +// quantized values. +// Parameters: +// - tensor is the tensor to be quantized, needed to access associated +// quantization parameters +// - input is the float input data to be quantized. +// - channel_dim_index is the channel index within "dimension". +// dimension[channel_dim_index] gives the number of channels. +// - output_scale is the output scale, the size of which equals the number of +// channels. +// - output_value is the output data, the size of which equals the number of +// inputs. +absl::Status SymmetricPerChannelQuantization(TensorT* tensor, + const float* const input, + int32_t channel_dim_index, + std::vector* output_scales, + std::vector* output_value); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:symmetric_per_channel_quantization) + +// LINT.IfChange(symmetric_per_channel_quantize_values) +// Quantize the values given an array of scales. +void SymmetricPerChannelQuantizeValues(const float* const input, + const std::vector& scales_inv, + const std::vector& dimension, + int32_t channel_dim_index, + std::vector* output_value); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:symmetric_per_channel_quantize_values) + +// LINT.IfChange(symmetric_quantize_tensor) +// Quantizes tensor using symmetric quantization with the min and max elements +// of the tensor. +absl::Status SymmetricQuantizeTensor(ModelT* model, TensorT* tensor); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:symmetric_quantize_tensor) + +// LINT.IfChange(symmetric_quantize_tensor_per_channel) +// Quantizes tensor with per channel. +absl::Status SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor, + int32_t channel_dim_index); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:symmetric_quantize_tensor_per_channel) + +// LINT.IfChange(quantize_tensor_float16) +// Quantizes tensor to float16. +absl::Status QuantizeTensorFloat16(ModelT* model, TensorT* tensor); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:quantize_tensor_float16) + +// LINT.IfChange(add_quantization_params) +absl::Status AddQuantizationParams(const std::vector& scales, + const std::vector& zero_point, + int quantized_dimension, + const uint8_t* buffer_data, + size_t buffer_size, TensorType output_type, + ModelT* model, TensorT* tensor); +// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantization_utils.h:add_quantization_params) + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZATION_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc new file mode 100644 index 00000000000000..0a7bcd0df79597 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils_test.cc @@ -0,0 +1,495 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file is the MLIR copy of part of +// third_party/tensorflow/lite/tools/optimize/quantization_utils_test.cc as part +// of the effort to decouple TFLite from MLIR. + +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "absl/status/status.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/util/command_line_flags.h" +#include "tensorflow/lite/core/model_builder.h" +#include "tsl/platform/init_main.h" +#include "tsl/platform/path.h" + +namespace { +std::string* g_test_model_dir = nullptr; +} // namespace + +namespace mlir { +namespace lite { +namespace toco_legacy { +namespace { + +using tflite::BuiltinOperator_CONV_2D; +using tflite::FlatBufferModel; // to remove when mlir version is ready, from + // model.h +using tflite::QuantizationParametersT; +using tflite::SubGraphT; +using tflite::TensorT; +using tflite::TensorType_FLOAT16; +using tflite::TensorType_FLOAT32; +using tflite::TensorType_INT8; + +std::unique_ptr ReadModel(const char* model) { + auto model_path = tsl::io::JoinPath(*g_test_model_dir, model); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +std::unique_ptr ReadConvModel() { + return ReadModel(mlir::lite::internal::kConvModelWith0Plus10Weights); +} + +using ::testing::ElementsAreArray; + +class QuantizationUtilsTest : public testing::Test {}; + +TEST_F(QuantizationUtilsTest, NumElements) { + TensorT tensor; + tensor.shape = {1, 2, 3, 4}; + uint64_t num_elements; + TF_EXPECT_OK(NumElements(tensor, &num_elements)); + EXPECT_EQ(num_elements, 1 * 2 * 3 * 4); + + tensor.shape = {5}; + TF_EXPECT_OK(NumElements(tensor, &num_elements)); + EXPECT_EQ(num_elements, 5); + + tensor.shape = {}; + TF_EXPECT_OK(NumElements(tensor, &num_elements)); + // Scalars with empty shape have 1 element. + EXPECT_EQ(num_elements, 1); + + tensor.shape = {1, 2, 3, -1}; + EXPECT_EQ(NumElements(tensor, &num_elements).code(), + absl::StatusCode::kInternal); +} + +TEST_F(QuantizationUtilsTest, SymmetricPerChannelQuantizationWithNullQParams) { + // Set up an input with [3, 2, 2, 2] size and 0 is the channel index. + const std::vector input = { + 3.0, 2.0, 5.0, -2.0, 3.0, 2.0, 5.0, -2.0, // Channel 1. + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // Channel 2. + 1.0, 0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, // Channel 3. + }; + const int channel_index = 0; + + // Create holder for output scale and data. + std::vector output_scales(3); + std::vector output_data(3 * 2 * 2 * 2); + + // Call SymmetricPerChannelQuantization with quant_params as a null pointer + // and verify the result. + TensorT tensor = TensorT(); + tensor.quantization = nullptr; + tensor.shape = {3, 2, 2, 2}; + TF_EXPECT_OK(mlir::lite::toco_legacy::SymmetricPerChannelQuantization( + &tensor, input.data(), channel_index, &output_scales, &output_data)); + const std::vector expected_output_scales = {0.0393700786, 0.0629921257, + 0.0472440943}; + const std::vector expected_output_data = { + 76, 51, 127, -51, 76, 51, 127, -51, // Channel 1. + 16, 32, 48, 64, 79, 95, 111, 127, // Channel 2. + 21, 0, -21, -42, -64, -85, -106, -127, // Channel 3. + }; + EXPECT_THAT(output_scales, ElementsAreArray(expected_output_scales)); + EXPECT_THAT(output_data, ElementsAreArray(expected_output_data)); +} + +TEST_F(QuantizationUtilsTest, SymmetricPerChannelQuantization) { + // Set up an input with [3, 2, 2, 2] size and 0 is the channel index. + const std::vector input = { + 3.0, 2.0, 5.0, -2.0, 3.0, 2.0, 5.0, -2.0, // Channel 1. + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // Channel 2. + 1.0, 0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, // Channel 3. + }; + const int32_t channel_index = 0; + + // Create holder for output scale and data. + std::vector output_scales(3); + std::vector output_data(3 * 2 * 2 * 2); + + // Initialize pointer to quantization parameters + TensorT tensor = TensorT(); + tensor.quantization = std::make_unique(); + tensor.shape = {3, 2, 2, 2}; + TF_EXPECT_OK(mlir::lite::toco_legacy::FillPerChannelMinMax( + input.data(), tensor.shape, channel_index, tensor.quantization.get())); + + // Test that FillPerChanneMinMax worked + const std::vector expected_mins = {-2.0, 1.0, -6.0}; + const std::vector expected_maxs = {5.0, 8.0, 1.0}; + EXPECT_THAT(tensor.quantization->min, ElementsAreArray(expected_mins)); + EXPECT_THAT(tensor.quantization->max, ElementsAreArray(expected_maxs)); + + // Call SymmetricPerChannelQuantization with quant_params as a null pointer + // and verify the result. + TF_EXPECT_OK(mlir::lite::toco_legacy::SymmetricPerChannelQuantization( + &tensor, input.data(), channel_index, &output_scales, &output_data)); + const std::vector expected_output_scales = {0.0393700786, 0.0629921257, + 0.0472440943}; + const std::vector expected_output_data = { + 76, 51, 127, -51, 76, 51, 127, -51, // Channel 1. + 16, 32, 48, 64, 79, 95, 111, 127, // Channel 2. + 21, 0, -21, -42, -64, -85, -106, -127, // Channel 3. + }; + EXPECT_THAT(output_scales, ElementsAreArray(expected_output_scales)); + EXPECT_THAT(output_data, ElementsAreArray(expected_output_data)); +} + +TEST_F(QuantizationUtilsTest, SymmetricPerChannelQuantization2DTensor) { + // Set up an input with [3, 8] size and 0 is the channel index. + const std::vector input = { + 3.0, 2.0, 5.0, -2.0, 3.0, 2.0, 5.0, -2.0, // Batch 1. + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // Batch 2. + 1.0, 0.0, -1.0, -2.0, -3.0, -4.0, -5.0, -6.0, // Batch 3. + }; + const int32_t channel_index = 1; + + // Create holder for output scale and data. + std::vector output_scales(8); + std::vector output_data(3 * 8); + + // Initialize pointer to quantization parameters + TensorT tensor = TensorT(); + tensor.quantization = std::make_unique(); + tensor.shape = {3, 8}; + TF_EXPECT_OK(mlir::lite::toco_legacy::FillPerChannelMinMax( + input.data(), tensor.shape, channel_index, tensor.quantization.get())); + + // Test that FillPerChanneMinMax worked + const std::vector expected_mins = {1.0, 0.0, -1.0, -2.0, + -3.0, -4.0, -5.0, -6.0}; + const std::vector expected_maxs = {3.0, 2.0, 5.0, 4.0, + 5.0, 6.0, 7.0, 8.0}; + EXPECT_THAT(tensor.quantization->min, ElementsAreArray(expected_mins)); + EXPECT_THAT(tensor.quantization->max, ElementsAreArray(expected_maxs)); + + // Call SymmetricPerChannelQuantization with quant_params as a null pointer + // and verify the result. + TF_EXPECT_OK(mlir::lite::toco_legacy::SymmetricPerChannelQuantization( + &tensor, input.data(), channel_index, &output_scales, &output_data)); + const std::vector expected_output_scales = { + 0.02362204724, 0.01574803149, 0.03937007874, 0.03149606299, + 0.03937007874, 0.04724409448, 0.05511811023, 0.06299212598}; + const std::vector expected_output_data = { + 127, 127, 127, -64, 76, 42, 91, -32, // Batch 1. + 42, 127, 76, 127, 127, 127, 127, 127, // Batch 2. + 42, 0, -25, -64, -76, -85, -91, -95, // Batch 3. + }; + EXPECT_THAT(output_scales, ElementsAreArray(expected_output_scales)); + EXPECT_THAT(output_data, ElementsAreArray(expected_output_data)); +} + +TEST_F(QuantizationUtilsTest, SymmetricPerChannelQuantizeValues) { + // Set up an input with [3, 1, 1, 2] size and 0 is the channel index. + const std::vector input = { + 13.0, 21.0, // Channel 1. + 21.0, 22.0, // Channel 2. + 31.0, 40.0, // Channel 3. + }; + const std::vector scales_inv = {2, 0.5, 3}; + const std::vector dimension = {3, 1, 1, 2}; + const int channel_index = 0; + + // Create holder for output data. + std::vector output_data(3 * 1 * 1 * 2); + + // Call SymmetricPerChannelQuantizeValues and verify the result. + SymmetricPerChannelQuantizeValues(input.data(), scales_inv, dimension, + channel_index, &output_data); + const std::vector expected_output_data = { + 26, 42, // Channel 1. + 11, 11, // Channel 2. + 93, 120, // Channel 3. + }; + EXPECT_THAT(output_data, ElementsAreArray(expected_output_data)); +} + +TEST_F(QuantizationUtilsTest, FillPerChannelMinMax) { + // Set up an input with [3, 1, 1, 2] size. + const std::vector input = { + 13.0, 21.0, // Channel 1. + 21.0, 22.0, // Channel 2. + 31.0, 40.0, // Channel 3. + }; + + // Initialize pointer to quantization parameters. + QuantizationParametersT quantization_params = QuantizationParametersT(); + std::vector dimension = {3, 1, 1, 2}; + int32_t channel_dim_idx = 0; + const std::vector expected_mins = {13.0, 21.0, 31.0}; + const std::vector expected_maxs = {21.0, 22.0, 40.0}; + + TF_EXPECT_OK(mlir::lite::toco_legacy::FillPerChannelMinMax( + input.data(), dimension, channel_dim_idx, &quantization_params)); + + EXPECT_EQ(quantization_params.min, expected_mins); + EXPECT_EQ(quantization_params.max, expected_maxs); + EXPECT_EQ(quantization_params.quantized_dimension, channel_dim_idx); +} + +TEST_F(QuantizationUtilsTest, FillPerChannelMinMaxFillDim3) { + // Set up an input with [3, 1, 1, 2] size. + const std::vector input = { + // Channel 1, Channel 2 + 13.0, 21.0, 21.0, 22.0, 31.0, 40.0, + }; + + // Initialize pointer to quantization parameters. + QuantizationParametersT quantization_params = QuantizationParametersT(); + std::vector dimension = {3, 1, 1, 2}; + int32_t channel_dim_idx = 3; + const std::vector expected_mins = {13.0, 21.0}; + const std::vector expected_maxs = {31.0, 40.0}; + + TF_EXPECT_OK(mlir::lite::toco_legacy::FillPerChannelMinMax( + input.data(), dimension, channel_dim_idx, &quantization_params)); + + EXPECT_EQ(quantization_params.min, expected_mins); + EXPECT_EQ(quantization_params.max, expected_maxs); + EXPECT_EQ(quantization_params.quantized_dimension, channel_dim_idx); +} + +TEST_F(QuantizationUtilsTest, FillPerChannelMinMax2DTensor) { + // Set up an input with [3, 2] size. + const std::vector input = { + // Channel 1, Channel 2 + 13.0, 21.0, 21.0, 22.0, 31.0, 40.0, + }; + + // Initialize pointer to quantization parameters. + QuantizationParametersT quantization_params = QuantizationParametersT(); + std::vector dimension = {3, 2}; + int32_t channel_dim_idx = 1; + const std::vector expected_mins = {13.0, 21.0}; + const std::vector expected_maxs = {31.0, 40.0}; + + TF_EXPECT_OK(mlir::lite::toco_legacy::FillPerChannelMinMax( + input.data(), dimension, channel_dim_idx, &quantization_params)); + + EXPECT_EQ(quantization_params.min, expected_mins); + EXPECT_EQ(quantization_params.max, expected_maxs); + EXPECT_EQ(quantization_params.quantized_dimension, channel_dim_idx); +} + +TEST_F(QuantizationUtilsTest, SymmetricQuantizeTensorNullInputs) { + EXPECT_EQ(SymmetricQuantizeTensor(nullptr, nullptr).code(), + absl::StatusCode::kInvalidArgument); +} + +TEST_F(QuantizationUtilsTest, SymmetricQuantizeTensorNullQuantParams) { + // Conv model has weights between 0 and 10. + // Quantize the weights tensor. + ASSERT_TRUE(g_test_model_dir); + ASSERT_FALSE(g_test_model_dir->empty()); + auto test_model = ReadConvModel(); + ASSERT_TRUE(test_model); + auto readonly_model = test_model->GetModel(); + ASSERT_TRUE(readonly_model); + ASSERT_TRUE(readonly_model->subgraphs()); + ASSERT_GE(readonly_model->subgraphs()->size(), 1); + tflite::ModelT model; + readonly_model->UnPackTo(&model); + auto subgraph = model.subgraphs[0].get(); + auto conv_op = subgraph->operators.at(0).get(); + ASSERT_EQ( + GetBuiltinCode(model.operator_codes.at(conv_op->opcode_index).get()), + BuiltinOperator_CONV_2D); + int32_t weights_tensor_idx = conv_op->inputs[1]; + TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get(); + // Empty quantization parameters. + weights_tensor->quantization = std::make_unique(); + + EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32); + size_t float_buffer_size = + model.buffers.at(weights_tensor->buffer)->data.size(); + + TF_EXPECT_OK(SymmetricQuantizeTensor(&model, weights_tensor)); + + size_t quant_buffer_size = + model.buffers.at(weights_tensor->buffer)->data.size(); + EXPECT_EQ(weights_tensor->type, TensorType_INT8); + EXPECT_EQ(quant_buffer_size * 4, float_buffer_size); +} + +TEST_F(QuantizationUtilsTest, SymmetricQuantizeTensor) { + // Conv model has weights between 0 and 10. + // Quantize the weights tensor. + ASSERT_TRUE(g_test_model_dir); + ASSERT_FALSE(g_test_model_dir->empty()); + auto test_model = ReadConvModel(); + ASSERT_TRUE(test_model); + auto readonly_model = test_model->GetModel(); + ASSERT_TRUE(readonly_model); + ASSERT_TRUE(readonly_model->subgraphs()); + ASSERT_GE(readonly_model->subgraphs()->size(), 1); + tflite::ModelT model; + readonly_model->UnPackTo(&model); + auto subgraph = model.subgraphs[0].get(); + auto conv_op = subgraph->operators.at(0).get(); + ASSERT_EQ( + GetBuiltinCode(model.operator_codes.at(conv_op->opcode_index).get()), + BuiltinOperator_CONV_2D); + int32_t weights_tensor_idx = conv_op->inputs[1]; + TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get(); + + EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32); + size_t float_buffer_size = + model.buffers.at(weights_tensor->buffer)->data.size(); + + TF_EXPECT_OK(SymmetricQuantizeTensor(&model, weights_tensor)); + + size_t quant_buffer_size = + model.buffers.at(weights_tensor->buffer)->data.size(); + EXPECT_EQ(weights_tensor->type, TensorType_INT8); + EXPECT_EQ(quant_buffer_size * 4, float_buffer_size); +} + +TEST_F(QuantizationUtilsTest, QuantizeFloat16Clamp) { + // Create data. + auto model = std::make_unique(); + auto subgraph = std::make_unique(); + auto tensor = std::make_unique(); + auto buffer = std::make_unique(); + constexpr int kNumElements = 6; + const std::vector weights = {2.0, 1.0, 65504., 65505, -65504., -99999}; + auto weights_reinterpreted_data = + reinterpret_cast(weights.data()); + buffer->data.assign(weights_reinterpreted_data, + weights_reinterpreted_data + weights.size() * 4); + tensor->buffer = 0; + tensor->shape = {1, kNumElements}; + + // Wire the model. + model->subgraphs.push_back(std::move(subgraph)); + model->subgraphs[0]->tensors.push_back(std::move(tensor)); + model->buffers.push_back(std::move(buffer)); + + // Call and verify. + TF_EXPECT_OK(QuantizeTensorFloat16(model.get(), + model->subgraphs[0]->tensors[0].get())); + auto weightsf16 = reinterpret_cast( + model->buffers[model->subgraphs[0]->tensors[0]->buffer]->data.data()); + std::vector wf32(kNumElements); + std::transform(weightsf16, weightsf16 + 6, wf32.begin(), + [](Eigen::half a) { return static_cast(a); }); + + EXPECT_THAT(wf32, + ElementsAreArray({2.0, 1.0, 65504., 65504., -65504., -65504.})); + EXPECT_EQ(model->subgraphs[0]->tensors[0]->type, TensorType_FLOAT16); +} + +TEST_F(QuantizationUtilsTest, QuantizeFloat16) { + // Conv model has weights between 0 and 10. + // Quantize the weights tensor. + ASSERT_TRUE(g_test_model_dir != nullptr); + ASSERT_FALSE(g_test_model_dir->empty()); + auto test_model = ReadConvModel(); + ASSERT_TRUE(test_model); + auto readonly_model = test_model->GetModel(); + ASSERT_TRUE(readonly_model); + ASSERT_TRUE(readonly_model->subgraphs()); + ASSERT_GE(readonly_model->subgraphs()->size(), 1); + tflite::ModelT model; + readonly_model->UnPackTo(&model); + auto subgraph = model.subgraphs[0].get(); + auto conv_op = subgraph->operators.at(0).get(); + ASSERT_EQ( + GetBuiltinCode(model.operator_codes.at(conv_op->opcode_index).get()), + BuiltinOperator_CONV_2D); + int32_t weights_tensor_idx = conv_op->inputs[1]; + TensorT* weights_tensor = subgraph->tensors.at(weights_tensor_idx).get(); + + EXPECT_EQ(weights_tensor->type, TensorType_FLOAT32); + size_t float_buffer_size = + model.buffers.at(weights_tensor->buffer)->data.size(); + + TF_EXPECT_OK(QuantizeTensorFloat16(&model, weights_tensor)); + + size_t quant_buffer_size = + model.buffers.at(weights_tensor->buffer)->data.size(); + EXPECT_EQ(weights_tensor->type, TensorType_FLOAT16); + EXPECT_EQ(quant_buffer_size * 2, float_buffer_size); +} + +TEST_F(QuantizationUtilsTest, AddQuantizationParams) { + // Create data. + auto model = std::make_unique(); + auto subgraph = std::make_unique(); + auto tensor = std::make_unique(); + auto buffer = std::make_unique(); + const std::vector scales = {0.5, 1.0, 1.5}; + const std::vector zero_points = {5, 10, 15}; + const int32_t quantizated_dimension = 3; + const std::vector buffer_data = {1, 2, 3, 4}; + const int32_t buffer_size = 4; + tensor->buffer = 0; + + // Wire the model. + model->subgraphs.push_back(std::move(subgraph)); + model->subgraphs[0]->tensors.push_back(std::move(tensor)); + model->buffers.push_back(std::move(buffer)); + + // Call and verify. + TF_EXPECT_OK(AddQuantizationParams(scales, zero_points, quantizated_dimension, + buffer_data.data(), buffer_size, + TensorType_INT8, model.get(), + model->subgraphs[0]->tensors[0].get())); + EXPECT_THAT(model->subgraphs[0]->tensors[0]->quantization->scale, + ElementsAreArray(scales)); + EXPECT_THAT(model->subgraphs[0]->tensors[0]->quantization->zero_point, + ElementsAreArray(zero_points)); + EXPECT_THAT(model->buffers[model->subgraphs[0]->tensors[0]->buffer]->data, + ElementsAreArray(buffer_data)); + EXPECT_EQ(model->subgraphs[0]->tensors[0]->type, TensorType_INT8); +} + + +} // namespace +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +int main(int argc, char** argv) { + std::string model_file; + const std::vector flag_list = { + tsl::Flag("test_model_file", &model_file, + "Path to test tflite model file."), + }; + + const bool parse_result = tsl::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + std::cerr << "Required test_model_file\n"; + std::abort(); + } + g_test_model_dir = new std::string(tsl::io::Dirname(model_file)); + ::tsl::port::InitMain(argv[0], &argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc new file mode 100644 index 00000000000000..b2d6fe97280174 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc @@ -0,0 +1,751 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "tensorflow/core/platform/logging.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { +namespace { + +using absl::flat_hash_set; +using mlir::lite::toco_legacy:: + CustomOpMap; // Use this instead of mlir::lite::CustomOpMap because that + // uses mlir::lite::CustomOpInfo in + // tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.h, + // and we need mlir::lite::toco_legacy::CustomOpInfo, in + // tensorflow/compiler/mlir/lite/quantization/lite/optimize/quantize_weights.h +using tflite::BufferT; +using tflite::BuiltinOperator; +using tflite::BuiltinOperator_BATCH_MATMUL; +using tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM; +using tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN; +using tflite::BuiltinOperator_CONV_2D; +using tflite::BuiltinOperator_CUSTOM; +using tflite::BuiltinOperator_DEPTHWISE_CONV_2D; +using tflite::BuiltinOperator_EMBEDDING_LOOKUP; +using tflite::BuiltinOperator_FULLY_CONNECTED; +using tflite::BuiltinOperator_GATHER; +using tflite::BuiltinOperator_LSTM; +using tflite::BuiltinOperator_RNN; +using tflite::BuiltinOperator_SVDF; +using tflite::BuiltinOperator_TRANSPOSE_CONV; +using tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM; +using tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN; +using tflite::FinishModelBuffer; +using tflite::GetBuiltinCode; +using tflite::Model; +using tflite::ModelT; +using tflite::OperatorCodeT; +using tflite::OperatorT; +using tflite::SubGraphT; +using tflite::TensorT; +using tflite::TensorType_FLOAT32; +using tflite::TensorType_INT8; + +struct ConsumerOpInfo { + OperatorT* op; + // The index of the op in the operators vector. + int32_t op_idx; + // The index of the tensor to quantize in subgraph->tensors. + int32_t op_input_idx; +}; + +struct TensorPerChannel { + TensorT* t; + bool is_per_channel; + int channel_dim; +}; + +// The default minimum number of elements a weights array must have to be +// quantized by this transformation. +const int kWeightsMinNumElementsDefault = 1024; + +// Redefined from tensorflow/lite/core/c/common.h as local const int instead of +// discouraged #define macro. +const int kTfLiteOptionalTensor = -1; + +// Convert the MLIR CustomOpMap from the TFlite CustomOpMap as their member +// variables differ. +void ConstructMLIRCustomOpMap(mlir::lite::CustomOpMap& mlir_map, + const CustomOpMap& tflite_map) { + for (const auto& entry : tflite_map) { + mlir_map[entry.first].quantizable_input_indices = + entry.second.quantizable_input_indices; + mlir_map[entry.first].is_weight_only = !entry.second.is_hybrid; + mlir_map[entry.first].no_side_effect = true; + } +} + +// Gets the operators that consume tensor_idx. +std::vector GetTensorConsumers(const ModelT* model, + const SubGraphT* subgraph, + int32_t tensor_idx) { + // TODO(suharshs): If this proves to be too slow, avoid calling it per tensor, + // instead doing one sweep for the entire model. + std::vector consumer_ops; + for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) { + OperatorT* op = subgraph->operators[op_idx].get(); + if (op == nullptr) { + continue; + } + for (size_t i = 0; i < op->inputs.size(); ++i) { + if (op->inputs[i] == tensor_idx) { + consumer_ops.push_back( + {op, static_cast(op_idx), static_cast(i)}); + } + } + } + return consumer_ops; +} + +// Gets the list of op->inputs indices of the weights inputs to be quantized for +// the provided op. +std::vector GetWeightInputIndices(const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map) { + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); + if (builtin_op_code == BuiltinOperator_CUSTOM) { + const std::string custom_code = op_code->custom_code; + const auto& custom_op_info = custom_op_map.find(custom_code); + if (custom_op_info != custom_op_map.end()) { + return custom_op_info->second.quantizable_input_indices; + } + } else if (builtin_op_code == BuiltinOperator_CONV_2D || + builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D || + builtin_op_code == BuiltinOperator_FULLY_CONNECTED || + builtin_op_code == BuiltinOperator_BATCH_MATMUL || + builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP || + builtin_op_code == BuiltinOperator_TRANSPOSE_CONV) { + return {1}; + } else if (builtin_op_code == BuiltinOperator_SVDF) { + // tensorflow/lite/kernels/svdf.cc + return {1, 2}; + } else if (builtin_op_code == BuiltinOperator_LSTM || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) { + // tensorflow/lite/kernels/lstm.cc + // tensorflow/lite/kernels/unidirectional_sequence_lstm.cc + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16}; + } else if (builtin_op_code == BuiltinOperator_RNN || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + // tensorflow/lite/kernels/basic_rnn.cc + // tensorflow/lite/kernels/unidirectional_sequence_rnn.cc + return {1, 2}; + } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) { + // tensorflow/lite/kernels/bidirectional_sequence_lstm.cc + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 33, 40, 41, 42, 43, 44, 45, 46, 47}; + } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) { + // tensorflow/lite/kernels/bidirectional_sequence_rnn.cc + return {1, 2, 4, 5, 6, 8, 9, 10, 11}; + } else if (builtin_op_code == BuiltinOperator_GATHER) { + // tensorflow/lite/kernels/gather.cc + return {0}; + } + return {}; +} + +// Checks that a specific input can be quantized. +bool IsQuantizedInput(const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map, int op_input_idx) { + const auto quantized_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + return std::find(std::begin(quantized_input_indices), + std::end(quantized_input_indices), + op_input_idx) != std::end(quantized_input_indices); +} + +// Returns true if the operator supports hybrid evaluation. +bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme) { + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); + // Operations that support hybrid evaluation. + bool eval_hybrid = false; + if (builtin_op_code == BuiltinOperator_CUSTOM) { + const std::string custom_code = op_code->custom_code; + const auto custom_op_info = custom_op_map.find(custom_code); + if (custom_op_info == custom_op_map.end()) { + return {}; + } else { + return custom_op_info->second.is_hybrid; + } + } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED || + builtin_op_code == BuiltinOperator_BATCH_MATMUL || + builtin_op_code == BuiltinOperator_CONV_2D || + builtin_op_code == BuiltinOperator_SVDF || + builtin_op_code == BuiltinOperator_RNN || + builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || + builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + eval_hybrid = true; + } else if (builtin_op_code == BuiltinOperator_LSTM) { + const tflite::LSTMOptionsT* options = op->builtin_options.AsLSTMOptions(); + // Only lstm kernel_type full supports hybrid evaluation. + if (options->kernel_type == tflite::LSTMKernelType_FULL) { + eval_hybrid = true; + } + } else if (builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + eval_hybrid = use_updated_hybrid_scheme; + } + return eval_hybrid; +} + +// Returns true if all of the op's inputs are quantized. +bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op, + const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map) { + std::vector op_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + for (const int32_t op_input_idx : op_input_indices) { + int32_t tensor_idx = op->inputs[op_input_idx]; + + if (tensor_idx == -1) { + // Optional tensor. + continue; + } + + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + + if (tensor->type != TensorType_INT8) { + return false; + } + } + return true; +} + +// Inserts Tensors for each input tensor of op that should be +// quantized into tensor_map. +absl::Status InsertQuantizableInputTensorsFromOperator( + const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + absl::flat_hash_map* tensor_map, + int subgraph_index, bool use_updated_hybrid_scheme) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + auto builtin_code = GetBuiltinCode(op_code); + + std::vector op_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + for (const int32_t op_input_idx : op_input_indices) { + int32_t tensor_idx = op->inputs[op_input_idx]; + if (tensor_idx == -1) { + LOG(INFO) << "Skipping optional tensor input " << op_input_idx + << " of operation " << EnumNameBuiltinOperator(builtin_code); + continue; + } + + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + if (tensor->type != TensorType_FLOAT32) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " that is not type float."; + continue; + } + + uint64_t num_elements; + if (!mlir::lite::toco_legacy::NumElements(*tensor, &num_elements).ok()) { + return absl::InternalError("Error in quantization_utils NumElements"); + } + if (num_elements < weights_min_num_elements) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has fewer than " << weights_min_num_elements + << " elements (" << num_elements << ")."; + continue; + } + + // Some tensors may have a null buffer vector, indicating an intermediate + // array. + if (model->buffers[tensor->buffer]->data.data() == nullptr) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has no allocated buffer."; + continue; + } + + if (builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + tensor_map->insert({tensor_idx, + {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, + /*dim=*/3}}); + } else if (builtin_code == BuiltinOperator_CONV_2D) { + tensor_map->insert({tensor_idx, + {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, + /*dim=*/0}}); + } else { + switch (builtin_code) { + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: + op->builtin_options.AsBidirectionalSequenceLSTMOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: + op->builtin_options.AsBidirectionalSequenceRNNOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_FULLY_CONNECTED: + op->builtin_options.AsFullyConnectedOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_BATCH_MATMUL: + op->builtin_options.AsBatchMatMulOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_LSTM: + op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_RNN: + op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_SVDF: + op->builtin_options.AsSVDFOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + op->builtin_options.AsUnidirectionalSequenceLSTMOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: + op->builtin_options.AsSequenceRNNOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + default: + break; + } + tensor_map->insert({tensor_idx, {tensor, /*is_per_channel=*/false}}); + } + } + + return absl::OkStatus(); +} + +// Updates operator code versions for the operators with INT8 inputs. +void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) { + for (int i = 0, end = model->operator_codes.size(); i < end; ++i) { + const BuiltinOperator& op_code = + GetBuiltinCode(model->operator_codes[i].get()); + if (op_code == BuiltinOperator_RNN || + op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || + op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || + op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 3 : 2; + } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || + op_code == BuiltinOperator_EMBEDDING_LOOKUP) { + model->operator_codes[i]->version = 3; + } else if (op_code == BuiltinOperator_LSTM) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 3; + } else if (op_code == BuiltinOperator_CONV_2D) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2; + } else if (op_code == BuiltinOperator_FULLY_CONNECTED) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3; + } else if (op_code == BuiltinOperator_BATCH_MATMUL) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 1; + } else if (op_code == BuiltinOperator_SVDF) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2; + } else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + model->operator_codes[i]->version = 6; + } + } +} + +// Returns true if the op in consumer_op_infos can pass through quantization. +bool IsQuantizationPassThroughOps( + const ModelT* model, const std::vector& consumer_op_infos) { + if (consumer_op_infos.size() != 1) { + return false; + } + const OperatorT* consumer_op = consumer_op_infos.front().op; + const BuiltinOperator op_code = + GetBuiltinCode(model->operator_codes[consumer_op->opcode_index].get()); + return op_code == BuiltinOperator_GATHER || + op_code == BuiltinOperator_EMBEDDING_LOOKUP; +} + +// Copies quantization parameters from input to output and returns consumers of +// the output tensor as a tuple with values: +// - index of the output tensor +// - pointer to the output tensor +// - vector of consumers ops. +std::tuple> +PassQuantizationAndGetConsumers( + const ModelT* model, const SubGraphT* subgraph, + const std::vector& consumer_op_infos, + const CustomOpMap& custom_op_map) { + const OperatorT* op = consumer_op_infos.front().op; + const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + if (op->outputs.size() != 1) { + LOG(ERROR) + << "An op that passes quantization has more than one quantized output"; + return std::make_tuple(-1, nullptr, std::vector()); + } + const int32_t output_tensor_idx = op->outputs.front(); + const auto input_idx = GetWeightInputIndices(op_code, custom_op_map); + if (input_idx.size() != 1) { + LOG(ERROR) + << "An op that passes quantization has more than one quantized input"; + return std::make_tuple(-1, nullptr, std::vector()); + } + const int32_t input_tensor_idx = op->inputs[input_idx.front()]; + + // Propagate quantization params. + const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get(); + TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get(); + if (!output_tensor->quantization) { + output_tensor->quantization = + std::make_unique(); + } + *output_tensor->quantization = *input_tensor->quantization; + output_tensor->type = TensorType_INT8; + return std::make_tuple( + output_tensor_idx, output_tensor, + GetTensorConsumers(model, subgraph, output_tensor_idx)); +} + +inline bool IsOpDenylisted(const flat_hash_set& op_denylist, + const BuiltinOperator op_code) { + return op_denylist.find(op_code) != op_denylist.end(); +} + +absl::Status QuantizeWeightsInt8( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + bool use_hybrid_evaluation, uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme, + const absl::flat_hash_set& op_denylist = {}) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + for (int subgraph_index = 0, end = model->subgraphs.size(); + subgraph_index < end; ++subgraph_index) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + + absl::flat_hash_map tensor_map; + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + absl::Status status = InsertQuantizableInputTensorsFromOperator( + model.get(), op, weights_min_num_elements, custom_op_map, &tensor_map, + subgraph_index, use_updated_hybrid_scheme); + if (!status.ok()) return status; + } + + for (std::pair tensor_pair : tensor_map) { + // Quantize the tensor. + if (tensor_pair.second.is_per_channel) { + if (!mlir::lite::toco_legacy::SymmetricQuantizeTensorPerChannel( + model.get(), tensor_pair.second.t, + tensor_pair.second.channel_dim) + .ok()) { + return absl::InternalError( + "SymmetricQuantizeTensorPerChannel failed"); + } + } else { + if (!mlir::lite::toco_legacy::SymmetricQuantizeTensor( + model.get(), tensor_pair.second.t) + .ok()) { + return absl::InternalError("SymmetricQuantizeTensor failed"); + } + } + } + + // Examine the tensor consumers to determine which require dequantize ops. + for (const auto& tensor_pair : tensor_map) { + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second.t; + std::vector consumer_op_infos = + GetTensorConsumers(model.get(), subgraph, tensor_idx); + if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) { + std::tie(tensor_idx, tensor, consumer_op_infos) = + PassQuantizationAndGetConsumers(model.get(), subgraph, + consumer_op_infos, custom_op_map); + if (tensor_idx < 0) { + // Error message is already logged by PassQuantizationAndGetConsumers. + return absl::InternalError("PassQuantizationAndGetConsumers failed"); + } + } + + std::vector dequant_op_infos; // Ops that need dequants. + for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) { + OperatorT* consumer_op = consumer_op_info.op; + const OperatorCodeT* consumer_op_code = + model->operator_codes[consumer_op->opcode_index].get(); + // If the op is a hybrid op and all the required tensors are quantized, + // we have no further work to do, but for all ops that require + // dequantization we need to add a Dequantize op. + bool eval_hybrid = + use_hybrid_evaluation && + !IsOpDenylisted(op_denylist, GetBuiltinCode(consumer_op_code)) && + IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map, + use_updated_hybrid_scheme) && + CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code, + custom_op_map) && + IsQuantizedInput(consumer_op_code, custom_op_map, + consumer_op_info.op_input_idx); + if (!eval_hybrid) { + dequant_op_infos.push_back(consumer_op_info); + } + } + + // Check if this tensor is an output tensor. + int32_t output_index = -1; + for (int32_t i = 0; i < subgraph->outputs.size(); ++i) { + if (subgraph->outputs[i] == tensor_idx) { + output_index = i; + break; + } + } + + // If no ops require dequant and it is not output, we are done for this + // tensor. + if (dequant_op_infos.empty() && output_index < 0) { + continue; + } + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + const std::string dequant_name = tensor->name + "_dequantize"; + mlir::lite::toco_legacy::MakeTensor( + dequant_name, tensor->shape, tensor->shape_signature, + TensorType_FLOAT32, &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + mlir::lite::toco_legacy::MakeDequantizeOperator( + model.get(), &dequantize_op, tensor_idx, dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + // Update output name. + if (output_index >= 0) { + subgraph->outputs[output_index] = dequantize_output_idx; + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. + subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + } + + // Update the modified operator code versions. + UpdateInt8OperatorVersions(model.get(), use_updated_hybrid_scheme); + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + + return absl::OkStatus(); +} + +absl::Status QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + for (int subgraph_index = 0, end = model->subgraphs.size(); + subgraph_index < end; ++subgraph_index) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + + absl::flat_hash_map tensor_map; + for (int i = 0, sub_end = subgraph->operators.size(); i < sub_end; ++i) { + OperatorT* op = subgraph->operators[i].get(); + for (auto tensor_idx : op->inputs) { + // Skip optional tensors. + if (tensor_idx == kTfLiteOptionalTensor) { + continue; + } + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + BufferT* buffer = model->buffers[tensor->buffer].get(); + if (buffer == nullptr) { + return absl::InternalError("Buffer is null"); + } + // Quantize tensors that have data to quantize. + bool is_constant = !model->buffers[tensor->buffer].get()->data.empty(); + if (tensor->type == TensorType_FLOAT32 && is_constant) { + tensor_map.insert({tensor_idx, tensor}); + } + } + } + + // The hash map ensures that we quantize each tensor exactly once. + for (std::pair tensor_pair : tensor_map) { + // Quantize the tensor. + if (!mlir::lite::toco_legacy::QuantizeTensorFloat16(model.get(), + tensor_pair.second) + .ok()) { + return absl::InternalError("QuantizeTensorFloat16 failed"); + } + + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second; + std::vector dequant_op_infos = + GetTensorConsumers(model.get(), subgraph, tensor_idx); + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + const std::string dequant_name = tensor->name + "_dequantize"; + mlir::lite::toco_legacy::MakeTensor( + dequant_name, tensor->shape, tensor->shape_signature, + TensorType_FLOAT32, &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + mlir::lite::toco_legacy::MakeDequantizeOperator( + model.get(), &dequantize_op, tensor_idx, dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. + subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + } + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + return absl::OkStatus(); +} +} // namespace + +namespace internal { +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + bool use_hybrid_evaluation, + QuantizerType quantizer_type) { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + return mlir::lite::QuantizeWeights( + builder, input_model, weights_min_num_elements, use_hybrid_evaluation); + } + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} +} // namespace internal + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + return mlir::lite::QuantizeWeights(builder, input_model, + weights_min_num_elements); + } + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, BufferType quant_type, + bool use_updated_hybrid_scheme, + QuantizerType quantizer_type) { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + return mlir::lite::QuantizeWeights(builder, input_model, + (mlir::lite::BufferType)quant_type, + use_updated_hybrid_scheme); + } + switch (quant_type) { + case BufferType::QUANTIZED_INT8: { + mlir::lite::toco_legacy::CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, + kWeightsMinNumElementsDefault, custom_op_map, + use_updated_hybrid_scheme); + } + case BufferType::QUANTIZED_FLOAT16: + return QuantizeWeightsFloat16(builder, input_model); + } +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + mlir::lite::CustomOpMap mlir_custom_op_map; + ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map); + return mlir::lite::QuantizeWeights( + builder, input_model, weights_min_num_elements, mlir_custom_op_map); + } + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme, + const flat_hash_set& op_denylist, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + mlir::lite::CustomOpMap mlir_custom_op_map; + ConstructMLIRCustomOpMap(mlir_custom_op_map, custom_op_map); + return mlir::lite::QuantizeWeights( + builder, input_model, weights_min_num_elements, mlir_custom_op_map, + use_updated_hybrid_scheme, op_denylist); + } + return QuantizeWeightsInt8(builder, input_model, + /*use_hybrid_evaluation=*/true, + weights_min_num_elements, custom_op_map, + use_updated_hybrid_scheme, op_denylist); +} + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h new file mode 100644 index 00000000000000..039c18d8e1d256 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h @@ -0,0 +1,109 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" + +namespace mlir { +namespace lite { +namespace toco_legacy { + +using ::tflite::BuiltinOperator; +using ::tflite::Model; + +// Supported resulting types from quantization process. +enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 }; +enum class QuantizerType { OLD_QUANTIZER, MLIR_QUANTIZER }; + +// Stores information about how to quantize a user-specified custom operation. +struct CustomOpInfo { + std::vector quantizable_input_indices; + bool is_hybrid; +}; + +// Map from custom op code to custom op quantization information. +using CustomOpMap = std::unordered_map; + +// This macro is for internal use for conversions requiring previous behavior. +#ifdef TFLITE_USE_PREVIOUS_HYBRID_SCHEME +// Use asymmetric quantized activations and per-channel quantized weights. +constexpr bool kUseUpdatedHybridSchemeDefault = false; +#else +// Use symmetric quantized activations and per-channel quantized weights. +constexpr bool kUseUpdatedHybridSchemeDefault = true; +#endif + +// Quantizes input_model and populates the provided builder with the new model. +// By default only weights tensors weight more than 1024 elements will be +// quantized. +// +// A tflite::Model can be obtained from the builder with: +// const uint8_t* buffer = builder->GetBufferPointer(); +// tflite::Model* model = GetModel(buffer); +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + BufferType quant_type = BufferType::QUANTIZED_INT8, + bool use_updated_hybrid_scheme = kUseUpdatedHybridSchemeDefault, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but only weights with greater than or equal +// weights_min_num_elements elements will be quantized. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but with entry point of quantizing custom ops. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +// Same as above, but if use updated_hybrid_scheme is false, +// use previous quantization scheme. Optional op_denylist argument +// disables hybrid evaluation for provided BuiltinOperators. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme, + const absl::flat_hash_set& op_denylist = {}, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); + +namespace internal { +// If use_hybrid_evaluation is false, will disable using hybrid eval for +// operations that support it. +// +// We use this internal QuantizeWeights call to test models with hybrid +// evaluation disabled. +absl::Status QuantizeWeights( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + uint64_t weights_min_num_elements, bool use_hybrid_evaluation, + QuantizerType quantizer_type = QuantizerType::OLD_QUANTIZER); +} // namespace internal + +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_TOCO_LEGACY_QUANTIZE_WEIGHTS_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_portable.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_portable.cc new file mode 100644 index 00000000000000..91030d4cf57e27 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_portable.cc @@ -0,0 +1,692 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// clang-format off +#include "tensorflow/lite/tools/toco_legacy/quantize_weights.h" +// clang-format on + +#include +#include +#include +#include + +#include "flatbuffers/flexbuffers.h" +#include "absl/container/flat_hash_map.h" +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +// #include "tensorflow/lite/context.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/model_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "tensorflow/lite/core/model.h" // to be replaced with unda's model_builder + +namespace tflite { +namespace optimize { + +namespace { + +struct ConsumerOpInfo { + OperatorT* op; + // The index of the op in the operators vector. + int32_t op_idx; + // The index of the tensor to quantize in subgraph->tensors. + int32_t op_input_idx; +}; + +struct TensorPerChannel { + TensorT* t; + bool is_per_channel; + int channel_dim; +}; + +// The default minimum number of elements a weights array must have to be +// quantized by this transformation. +const int kWeightsMinNumElementsDefault = 1024; + +// Gets the operators that consume tensor_idx. +std::vector GetTensorConsumers(const ModelT* model, + const SubGraphT* subgraph, + int32_t tensor_idx) { + // TODO(suharshs): If this proves to be too slow, avoid calling it per tensor, + // instead doing one sweep for the entire model. + std::vector consumer_ops; + for (size_t op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) { + OperatorT* op = subgraph->operators[op_idx].get(); + if (op == nullptr) { + continue; + } + for (size_t i = 0; i < op->inputs.size(); ++i) { + if (op->inputs[i] == tensor_idx) { + consumer_ops.push_back( + {op, static_cast(op_idx), static_cast(i)}); + } + } + } + return consumer_ops; +} + +// Gets the list of op->inputs indices of the weights inputs to be quantized for +// the provided op. +std::vector GetWeightInputIndices(const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map) { + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); + if (builtin_op_code == BuiltinOperator_CUSTOM) { + const std::string custom_code = op_code->custom_code; + const auto& custom_op_info = custom_op_map.find(custom_code); + if (custom_op_info != custom_op_map.end()) { + return custom_op_info->second.quantizable_input_indices; + } + } else if (builtin_op_code == BuiltinOperator_CONV_2D || + builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D || + builtin_op_code == BuiltinOperator_FULLY_CONNECTED || + builtin_op_code == BuiltinOperator_BATCH_MATMUL || + builtin_op_code == BuiltinOperator_EMBEDDING_LOOKUP || + builtin_op_code == BuiltinOperator_TRANSPOSE_CONV) { + return {1}; + } else if (builtin_op_code == BuiltinOperator_SVDF) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/svdf.cc + return {1, 2}; + } else if (builtin_op_code == BuiltinOperator_LSTM || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/lstm.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_lstm.cc + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16}; + } else if (builtin_op_code == BuiltinOperator_RNN || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/basic_rnn.cc + // https://www.tensorflow.org/code/tensorflow/lite/kernels/unidirectional_sequence_rnn.cc + return {1, 2}; + } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc + return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 33, 40, 41, 42, 43, 44, 45, 46, 47}; + } else if (builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc + return {1, 2, 4, 5, 6, 8, 9, 10, 11}; + } else if (builtin_op_code == BuiltinOperator_GATHER) { + // https://www.tensorflow.org/code/tensorflow/lite/kernels/gather.cc + return {0}; + } + return {}; +} + +// Checks that a specific input can be quantized. +bool IsQuantizedInput(const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map, int op_input_idx) { + const auto quantized_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + return std::find(std::begin(quantized_input_indices), + std::end(quantized_input_indices), + op_input_idx) != std::end(quantized_input_indices); +} + +// Returns true if the operator supports hybrid evaluation. +bool IsHybridEvaluationOp(const OperatorT* op, const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme) { + const BuiltinOperator builtin_op_code = GetBuiltinCode(op_code); + // Operations that support hybrid evaluation. + bool eval_hybrid = false; + if (builtin_op_code == BuiltinOperator_CUSTOM) { + const std::string custom_code = op_code->custom_code; + const auto custom_op_info = custom_op_map.find(custom_code); + if (custom_op_info == custom_op_map.end()) { + return {}; + } else { + return custom_op_info->second.is_hybrid; + } + } else if (builtin_op_code == BuiltinOperator_FULLY_CONNECTED || + builtin_op_code == BuiltinOperator_BATCH_MATMUL || + builtin_op_code == BuiltinOperator_CONV_2D || + builtin_op_code == BuiltinOperator_SVDF || + builtin_op_code == BuiltinOperator_RNN || + builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || + builtin_op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || + builtin_op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + eval_hybrid = true; + } else if (builtin_op_code == BuiltinOperator_LSTM) { + const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions(); + // Only lstm kernel_type full supports hybrid evaluation. + if (options->kernel_type == LSTMKernelType_FULL) { + eval_hybrid = true; + } + } else if (builtin_op_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + eval_hybrid = use_updated_hybrid_scheme; + } + return eval_hybrid; +} + +// Returns true if all of the op's inputs are quantized. +bool CheckAllOpInputsQuantized(const SubGraphT* subgraph, const OperatorT* op, + const OperatorCodeT* op_code, + const CustomOpMap& custom_op_map) { + std::vector op_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + for (const int32_t op_input_idx : op_input_indices) { + int32_t tensor_idx = op->inputs[op_input_idx]; + + if (tensor_idx == -1) { + // Optional tensor. + continue; + } + + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + + if (tensor->type != TensorType_INT8) { + return false; + } + } + return true; +} + +// Inserts Tensors for each input tensor of op that should be +// quantized into tensor_map. +TfLiteStatus InsertQuantizableInputTensorsFromOperator( + const ModelT* model, OperatorT* op, uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + absl::flat_hash_map* tensor_map, + int subgraph_index, bool use_updated_hybrid_scheme) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + auto builtin_code = GetBuiltinCode(op_code); + + std::vector op_input_indices = + GetWeightInputIndices(op_code, custom_op_map); + for (const int32_t op_input_idx : op_input_indices) { + int32_t tensor_idx = op->inputs[op_input_idx]; + if (tensor_idx == -1) { + LOG(INFO) << "Skipping optional tensor input " << op_input_idx + << " of operation " << EnumNameBuiltinOperator(builtin_code); + continue; + } + + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + if (tensor->type != TensorType_FLOAT32) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " that is not type float."; + continue; + } + + uint64_t num_elements; + TF_LITE_ENSURE_STATUS(utils::NumElements(*tensor, &num_elements)); + if (num_elements < weights_min_num_elements) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has fewer than " << weights_min_num_elements + << " elements (" << num_elements << ")."; + continue; + } + + // Some tensors may have a null buffer vector, indicating an intermediate + // array. + if (model->buffers[tensor->buffer]->data.data() == nullptr) { + LOG(INFO) << "Skipping quantization of tensor " << tensor->name + << " because it has no allocated buffer."; + continue; + } + + if (builtin_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + tensor_map->insert({tensor_idx, + {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, + /*dim=*/3}}); + } else if (builtin_code == BuiltinOperator_CONV_2D) { + tensor_map->insert({tensor_idx, + {tensor, /*is_per_channel=*/use_updated_hybrid_scheme, + /*dim=*/0}}); + } else { + switch (builtin_code) { + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: + op->builtin_options.AsBidirectionalSequenceLSTMOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: + op->builtin_options.AsBidirectionalSequenceRNNOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_FULLY_CONNECTED: + op->builtin_options.AsFullyConnectedOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_BATCH_MATMUL: + op->builtin_options.AsBatchMatMulOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_LSTM: + op->builtin_options.AsLSTMOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_RNN: + op->builtin_options.AsRNNOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_SVDF: + op->builtin_options.AsSVDFOptions()->asymmetric_quantize_inputs = + use_updated_hybrid_scheme; + break; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: + op->builtin_options.AsUnidirectionalSequenceLSTMOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: + op->builtin_options.AsSequenceRNNOptions() + ->asymmetric_quantize_inputs = use_updated_hybrid_scheme; + break; + default: + break; + } + tensor_map->insert({tensor_idx, {tensor, /*is_per_channel=*/false}}); + } + } + + return kTfLiteOk; +} + +// Updates operator code versions for the operators with INT8 inputs. +void UpdateInt8OperatorVersions(ModelT* model, bool use_updated_hybrid_scheme) { + for (int i = 0, end = model->operator_codes.size(); i < end; ++i) { + const BuiltinOperator& op_code = + GetBuiltinCode(model->operator_codes[i].get()); + if (op_code == BuiltinOperator_RNN || + op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN || + op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM || + op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 3 : 2; + } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM || + op_code == BuiltinOperator_EMBEDDING_LOOKUP) { + model->operator_codes[i]->version = 3; + } else if (op_code == BuiltinOperator_LSTM) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 3; + } else if (op_code == BuiltinOperator_CONV_2D) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 5 : 2; + } else if (op_code == BuiltinOperator_FULLY_CONNECTED) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 9 : 3; + } else if (op_code == BuiltinOperator_BATCH_MATMUL) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 1; + } else if (op_code == BuiltinOperator_SVDF) { + model->operator_codes[i]->version = use_updated_hybrid_scheme ? 4 : 2; + } else if (op_code == BuiltinOperator_DEPTHWISE_CONV_2D) { + model->operator_codes[i]->version = 6; + } + } +} + +// Returns true if the op in consumer_op_infos can pass through quantization. +bool IsQuantizationPassThroughOps( + const ModelT* model, const std::vector& consumer_op_infos) { + if (consumer_op_infos.size() != 1) { + return false; + } + const OperatorT* consumer_op = consumer_op_infos.front().op; + const BuiltinOperator op_code = + GetBuiltinCode(model->operator_codes[consumer_op->opcode_index].get()); + return op_code == BuiltinOperator_GATHER || + op_code == BuiltinOperator_EMBEDDING_LOOKUP; +} + +// Copies quantization parameters from input to output and returns consumers of +// the output tensor as a tuple with values: +// - index of the output tensor +// - pointer to the output tensor +// - vector of consumers ops. +std::tuple> +PassQuantizationAndGetConsumers( + const ModelT* model, const SubGraphT* subgraph, + const std::vector& consumer_op_infos, + const CustomOpMap& custom_op_map) { + const OperatorT* op = consumer_op_infos.front().op; + const OperatorCodeT* op_code = model->operator_codes[op->opcode_index].get(); + if (op->outputs.size() != 1) { + LOG(ERROR) + << "An op that passes quantization has more than one quantized output"; + return std::make_tuple(-1, nullptr, std::vector()); + } + const int32_t output_tensor_idx = op->outputs.front(); + const auto input_idx = GetWeightInputIndices(op_code, custom_op_map); + if (input_idx.size() != 1) { + LOG(ERROR) + << "An op that passes quantization has more than one quantized input"; + return std::make_tuple(-1, nullptr, std::vector()); + } + const int32_t input_tensor_idx = op->inputs[input_idx.front()]; + + // Propagate quantization params. + const TensorT* input_tensor = subgraph->tensors[input_tensor_idx].get(); + TensorT* output_tensor = subgraph->tensors[output_tensor_idx].get(); + if (!output_tensor->quantization) { + output_tensor->quantization = std::make_unique(); + } + *output_tensor->quantization = *input_tensor->quantization; + output_tensor->type = TensorType_INT8; + return std::make_tuple( + output_tensor_idx, output_tensor, + GetTensorConsumers(model, subgraph, output_tensor_idx)); +} + +inline bool IsOpDenylisted(const flat_hash_set& op_denylist, + const BuiltinOperator op_code) { + return op_denylist.find(op_code) != op_denylist.end(); +} + +absl::Status QuantizeWeightsInt8( + flatbuffers::FlatBufferBuilder* builder, const Model* input_model, + bool use_hybrid_evaluation, uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, bool use_updated_hybrid_scheme, + const flat_hash_set& op_denylist = {}) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + for (int subgraph_index = 0, end = model->subgraphs.size(); + subgraph_index < end; ++subgraph_index) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + + absl::flat_hash_map tensor_map; + for (int i = 0; i < subgraph->operators.size(); ++i) { + OperatorT* op = subgraph->operators[i].get(); + if (InsertQuantizableInputTensorsFromOperator( + model.get(), op, weights_min_num_elements, custom_op_map, + &tensor_map, subgraph_index, + use_updated_hybrid_scheme) != kTfLiteOk) { + return absl::InternalError( + "Failed to insert quantizable input tensors from operator"); + } + } + + for (std::pair tensor_pair : tensor_map) { + // Quantize the tensor. + if (tensor_pair.second.is_per_channel) { + if (utils::SymmetricQuantizeTensorPerChannel( + model.get(), tensor_pair.second.t, + tensor_pair.second.channel_dim, nullptr) != kTfLiteOk) { + return absl::InternalError("Failed to quantize tensor per channel"); + } + } else { + if (utils::SymmetricQuantizeTensor(model.get(), tensor_pair.second.t) != + kTfLiteOk) { + return absl::InternalError("Failed to quantize tensor"); + } + } + } + + // Examine the tensor consumers to determine which require dequantize ops. + for (const auto& tensor_pair : tensor_map) { + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second.t; + std::vector consumer_op_infos = + GetTensorConsumers(model.get(), subgraph, tensor_idx); + if (IsQuantizationPassThroughOps(model.get(), consumer_op_infos)) { + std::tie(tensor_idx, tensor, consumer_op_infos) = + PassQuantizationAndGetConsumers(model.get(), subgraph, + consumer_op_infos, custom_op_map); + if (tensor_idx < 0) { + // Error message is already logged by PassQuantizationAndGetConsumers. + return absl::InternalError( + "Failed to pass quantization and get consumers"); + } + } + + std::vector dequant_op_infos; // Ops that need dequants. + for (ConsumerOpInfo& consumer_op_info : consumer_op_infos) { + OperatorT* consumer_op = consumer_op_info.op; + const OperatorCodeT* consumer_op_code = + model->operator_codes[consumer_op->opcode_index].get(); + // If the op is a hybrid op and all the required tensors are quantized, + // we have no further work to do, but for all ops that require + // dequantization we need to add a Dequantize op. + bool eval_hybrid = + use_hybrid_evaluation && + !IsOpDenylisted(op_denylist, GetBuiltinCode(consumer_op_code)) && + IsHybridEvaluationOp(consumer_op, consumer_op_code, custom_op_map, + use_updated_hybrid_scheme) && + CheckAllOpInputsQuantized(subgraph, consumer_op, consumer_op_code, + custom_op_map) && + IsQuantizedInput(consumer_op_code, custom_op_map, + consumer_op_info.op_input_idx); + if (!eval_hybrid) { + dequant_op_infos.push_back(consumer_op_info); + } + } + + // Check if this tensor is an output tensor. + int32_t output_index = -1; + for (int32_t i = 0; i < subgraph->outputs.size(); ++i) { + if (subgraph->outputs[i] == tensor_idx) { + output_index = i; + break; + } + } + + // If no ops require dequant and it is not output, we are done for this + // tensor. + if (dequant_op_infos.empty() && output_index < 0) { + continue; + } + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + const string dequant_name = tensor->name + "_dequantize"; + utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature, + TensorType_FLOAT32, &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx, + dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + // Update output name. + if (output_index >= 0) { + subgraph->outputs[output_index] = dequantize_output_idx; + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. + subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + } + + // Update the modified operator code versions. + UpdateInt8OperatorVersions(model.get(), use_updated_hybrid_scheme); + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + + return absl::OkStatus(); +} + +absl::Status QuantizeWeightsFloat16(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model) { + std::unique_ptr model; + model.reset(input_model->UnPack()); + + for (int subgraph_index = 0, end = model->subgraphs.size(); + subgraph_index < end; ++subgraph_index) { + SubGraphT* subgraph = model->subgraphs.at(subgraph_index).get(); + + absl::flat_hash_map tensor_map; + for (int i = 0, sub_end = subgraph->operators.size(); i < sub_end; ++i) { + OperatorT* op = subgraph->operators[i].get(); + for (auto tensor_idx : op->inputs) { + // Skip optional tensors. + if (tensor_idx == kTfLiteOptionalTensor) { + continue; + } + TensorT* tensor = subgraph->tensors[tensor_idx].get(); + BufferT* buffer = model->buffers[tensor->buffer].get(); + if (buffer == nullptr) { + return absl::InternalError("Buffer is null"); + } + // Quantize tensors that have data to quantize. + bool is_constant = !model->buffers[tensor->buffer].get()->data.empty(); + if (tensor->type == TensorType_FLOAT32 && is_constant) { + tensor_map.insert({tensor_idx, tensor}); + } + } + } + + // The hash map ensures that we quantize each tensor exactly once. + for (std::pair tensor_pair : tensor_map) { + // Quantize the tensor. + if (utils::QuantizeTensorFloat16(model.get(), tensor_pair.second) != + kTfLiteOk) { + return absl::InternalError("QuantizeTensorFloat16 failed"); + } + + int32_t tensor_idx = tensor_pair.first; + TensorT* tensor = tensor_pair.second; + std::vector dequant_op_infos = + GetTensorConsumers(model.get(), subgraph, tensor_idx); + + // Create a new tensor to be the output of the dequantize op. + std::unique_ptr dequantize_output; + const string dequant_name = tensor->name + "_dequantize"; + utils::MakeTensor(dequant_name, tensor->shape, tensor->shape_signature, + TensorType_FLOAT32, &dequantize_output); + const int32_t dequantize_output_idx = subgraph->tensors.size(); + subgraph->tensors.push_back(std::move(dequantize_output)); + + // Create the Dequantize operation. + std::unique_ptr dequantize_op; + utils::MakeDequantizeOperator(model.get(), &dequantize_op, tensor_idx, + dequantize_output_idx); + + // Update the op_input of all the ops that need the created dequantize + // operation. + int32_t min_op_idx = subgraph->operators.size(); + for (ConsumerOpInfo& dequant_op_info : dequant_op_infos) { + dequant_op_info.op->inputs[dequant_op_info.op_input_idx] = + dequantize_output_idx; + min_op_idx = std::min(dequant_op_info.op_idx, min_op_idx); + } + + // Insert the newly created Dequantize operation before the earliest + // consumer, since TFLite requires operators to be topo-sorted. + subgraph->operators.insert(subgraph->operators.begin() + min_op_idx, + std::move(dequantize_op)); + } + } + + flatbuffers::Offset output_model_location = + Model::Pack(*builder, model.get()); + FinishModelBuffer(*builder, output_model_location); + return absl::OkStatus(); +} +} // namespace + +namespace internal { +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + bool use_hybrid_evaluation, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, use_hybrid_evaluation, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} +} // namespace internal + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, BufferType quant_type, + bool use_updated_hybrid_scheme, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + switch (quant_type) { + case BufferType::QUANTIZED_INT8: { + // By default we require that only weights with more than + // kWeightsMinSizeDefault elements are quantized. + CustomOpMap custom_op_map; + return QuantizeWeightsInt8(builder, input_model, true, + kWeightsMinNumElementsDefault, custom_op_map, + use_updated_hybrid_scheme); + } + case BufferType::QUANTIZED_FLOAT16: + return QuantizeWeightsFloat16(builder, input_model); + } +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + return QuantizeWeightsInt8(builder, input_model, true, + weights_min_num_elements, custom_op_map, + kUseUpdatedHybridSchemeDefault); +} + +absl::Status QuantizeWeights(flatbuffers::FlatBufferBuilder* builder, + const Model* input_model, + uint64_t weights_min_num_elements, + const CustomOpMap& custom_op_map, + bool use_updated_hybrid_scheme, + const flat_hash_set& op_denylist, + QuantizerType quantizer_type) { + if (quantizer_type == QuantizerType::MLIR_QUANTIZER) { + LOG(ERROR) << "Portable targets cannot use the MLIR quantizer."; + return absl::InternalError( + "Portable targets cannot use the MLIR quantizer."); + } + return QuantizeWeightsInt8(builder, input_model, + /*use_hybrid_evaluation=*/true, + weights_min_num_elements, custom_op_map, + use_updated_hybrid_scheme, op_denylist); +} + +} // namespace optimize +} // namespace tflite diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_test.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_test.cc new file mode 100644 index 00000000000000..7277e1dfbbe438 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights_test.cc @@ -0,0 +1,702 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "flatbuffers/vector.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/test_util.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" +#include "xla/tsl/util/command_line_flags.h" +#include "tensorflow/lite/core/model_builder.h" // TODO: b/321735756 - replace with mlir model_builder +#include "tsl/platform/init_main.h" +#include "tsl/platform/path.h" + +namespace { +std::string* g_test_model_dir = nullptr; +} // namespace + +namespace mlir { +namespace lite { +namespace toco_legacy { +namespace { + +using tflite::BuiltinOperator_CONV_2D; +using tflite::BuiltinOperator_CUSTOM; +using tflite::BuiltinOperator_DEQUANTIZE; +using tflite::FlatBufferModel; // to remove when mlir version is ready, from + // model.h +using tflite::GetModel; +using tflite::Model; +using tflite::TensorType_FLOAT16; +using tflite::TensorType_FLOAT32; +using tflite::TensorType_INT8; + +std::unique_ptr ReadTestModel() { + auto model_path = tsl::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kConvModelWith0Plus10Weights); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +std::unique_ptr ReadSharedWeightsTestModel() { + auto model_path = tsl::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithSharedWeights); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +std::unique_ptr ReadGatherTestModel() { + auto model_path = tsl::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kQuantizedWithGather); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +std::unique_ptr ReadCustomOpTestModel() { + auto model_path = tsl::io::JoinPath( + *g_test_model_dir, ::mlir::lite::internal::kModelWithCustomOp); + return FlatBufferModel::BuildFromFile(model_path.c_str()); +} + +template +std::vector GetAsVector(const flatbuffers::Vector* vec) { + return std::vector(vec->begin(), vec->end()); +} + +class QuantizeWeightsTest : public testing::Test { + protected: + QuantizeWeightsTest() = default; + + void LoadBasicModel() { + input_model_ = ReadTestModel(); + model_ = input_model_->GetModel(); + } + + void LoadSharedWeightsModel() { + input_model_ = ReadSharedWeightsTestModel(); + model_ = input_model_->GetModel(); + } + + void LoadGatherTestModel() { + input_model_ = ReadGatherTestModel(); + model_ = input_model_->GetModel(); + } + + void LoadCustomOpTestModel() { + input_model_ = ReadCustomOpTestModel(); + model_ = input_model_->GetModel(); + } + + std::unique_ptr input_model_; + const Model* model_; + + bool IsModelInputOrOutput(const Model* model, uint32_t tensor_idx) { + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto subgraph = model->subgraphs()->Get(subgraph_idx); + for (size_t i = 0; i < subgraph->inputs()->size(); ++i) { + if (subgraph->inputs()->Get(i) == tensor_idx) { + return true; + } + } + for (size_t i = 0; i < subgraph->outputs()->size(); ++i) { + if (subgraph->outputs()->Get(i) == tensor_idx) { + return true; + } + } + } + return false; + } + + // Returns the producer op code of the specified tensor_idx. + bool GetProducerOpCode(const Model* model, uint32_t subgraph_idx, + uint32_t tensor_idx, + tflite::BuiltinOperator* op_code) { + const auto subgraph = model->subgraphs()->Get(subgraph_idx); + for (size_t op_idx = 0; op_idx < subgraph->operators()->size(); ++op_idx) { + const auto op = subgraph->operators()->Get(op_idx); + for (size_t i = 0; i < op->outputs()->size(); ++i) { + if (op->outputs()->Get(i) == tensor_idx) { + const uint32_t op_code_idx = op->opcode_index(); + *op_code = GetBuiltinCode(model->operator_codes()->Get(op_code_idx)); + return true; + } + } + } + return false; + } +}; + +TEST_F(QuantizeWeightsTest, QuantizationSucceeds) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); +} + +TEST_F(QuantizeWeightsTest, WeightsMinNumElements) { + LoadBasicModel(); + // Make weights_min_size sufficiently large such that no quantization should + // happen, i.e. the original model is the same size as the old one. + flatbuffers::FlatBufferBuilder builder; + const uint64_t kWeightsMinNumElements = 1000000; + ASSERT_TRUE(QuantizeWeights(&builder, model_, kWeightsMinNumElements, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + subgraph_idx++) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size()); + for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + const auto float_tensor = float_graph->tensors()->Get(i); + // Everything should remain equal between the two graphs. + EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer()); + EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable()); + EXPECT_EQ(GetAsVector(quant_tensor->shape()), + GetAsVector(float_tensor->shape())); + EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str()); + EXPECT_EQ(quant_tensor->type(), float_tensor->type()); + } + } +} + +TEST_F(QuantizeWeightsTest, HybridConv) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + // Nothing should change. + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + subgraph_idx++) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size()); + // Make sure the graph only has one Conv operation. + ASSERT_EQ(quantized_graph->operators()->size(), 1); + const auto op = quantized_graph->operators()->Get(0); + const uint32_t op_code_idx = op->opcode_index(); + ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)), + BuiltinOperator_CONV_2D); + for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + const auto float_tensor = float_graph->tensors()->Get(i); + EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer()); + EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable()); + EXPECT_EQ(GetAsVector(quant_tensor->shape()), + GetAsVector(float_tensor->shape())); + EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str()); + // If the tensor is a weight, it should have type INT8, otherwise it + // should stay with type FLOAT32. + // If the tensor is a bias, it should have type FLOAT32. + if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->buffer() != 0) { + EXPECT_EQ(quant_tensor->type(), TensorType_INT8) + << quant_tensor->name()->str(); + auto shape = GetAsVector(quant_tensor->shape()); + if (kUseUpdatedHybridSchemeDefault) { + EXPECT_EQ(quant_tensor->quantization()->scale()->size(), shape[0]); + } else { + EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1); + } + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +TEST_F(QuantizeWeightsTest, DequantizeConv) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(internal::QuantizeWeights(&builder, model_, 0, + /*use_hybrid_evaluation=*/false, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + // The output graph should have an extra tensor from the added dequantize + // op. + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size() + 1); + // Check that a dequantize op exists. + int32_t dequant_input_idx = -1; + int32_t dequant_output_idx = -1; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) == + BuiltinOperator_DEQUANTIZE) { + dequant_input_idx = op->inputs()->Get(0); + dequant_output_idx = op->outputs()->Get(0); + } + } + ASSERT_GT(dequant_input_idx, -1); + ASSERT_GT(dequant_output_idx, -1); + for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + // If the tensor is a weight, it should have type INT8. + // If the tensor is a bias, it should have type FLOAT32. + // If the tensor is an input or output it should have type FLOAT32. + // The input to dequantize should be INT8, and all other tensors should be + // FLOAT32. + if (i == dequant_input_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_INT8); + } else if (i == dequant_output_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->buffer() != 0) { + // If it's a non-bias constant tensor, it must be the weight. + EXPECT_EQ(quant_tensor->type(), TensorType_INT8); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +TEST_F(QuantizeWeightsTest, DequantizeConvFloat16) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(QuantizeWeights(&builder, model_, BufferType::QUANTIZED_FLOAT16, + kUseUpdatedHybridSchemeDefault, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + // The output graph should have two extra tensors from the added dequantize + // op. + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size() + 2); + // Check that a dequantize op exists. + int32_t dequant_input_idx = -1; + int32_t dequant_output_idx = -1; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) == + BuiltinOperator_DEQUANTIZE) { + dequant_input_idx = op->inputs()->Get(0); + dequant_output_idx = op->outputs()->Get(0); + } + } + ASSERT_GT(dequant_input_idx, -1); + ASSERT_GT(dequant_output_idx, -1); + for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + // If the tensor is a weight, it should have type FLOAT16. + // If the tensor is a bias, it should have type FLOAT16. + // If the tensor is an input or output it should have type FLOAT32. + // The input to dequantize should be FLOAT16, and all other tensors should + // be FLOAT32. + if (i == dequant_input_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else if (i == dequant_output_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else if (quant_tensor->buffer() != 0) { + // If it's a non-bias constant tensor, it must be the weight. + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT16); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +TEST_F(QuantizeWeightsTest, SharedWeights_Hybrid) { + LoadSharedWeightsModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + uint32_t num_conv_ops = 0; + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == BuiltinOperator_CONV_2D) { + num_conv_ops++; + // Ensure that each convolution's weights tensor is now INT8. + const auto weights_tensor = + quantized_graph->tensors()->Get(op->inputs()->Get(1)); + EXPECT_EQ(weights_tensor->type(), TensorType_INT8); + } + } + } + // Ensure that there were exactly two convolutions in the model. + EXPECT_EQ(num_conv_ops, 2); +} + +TEST_F(QuantizeWeightsTest, SharedWeights_Dequantize) { + LoadSharedWeightsModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(internal::QuantizeWeights(&builder, model_, 0, + /*use_hybrid_evaluation*/ false, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + uint32_t num_conv_ops = 0; + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == BuiltinOperator_CONV_2D) { + num_conv_ops++; + // Ensure that each convolution's weights tensor is still FLOAT + // (the output of the dequantize). + uint32_t weights_tensor_index = op->inputs()->Get(1); + const auto weights_tensor = + quantized_graph->tensors()->Get(weights_tensor_index); + EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32); + + // Check that it comes from a dequantize operation. + BuiltinOperator producer_op_code; + ASSERT_TRUE(GetProducerOpCode(output_model, subgraph_idx, + weights_tensor_index, &producer_op_code)); + EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE); + } + } + } + // Ensure that there were exactly two convolutions in the model. + EXPECT_EQ(num_conv_ops, 2); +} + +TEST_F(QuantizeWeightsTest, VerifyGatherQuantization) { + LoadGatherTestModel(); + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE( + QuantizeWeights(&builder, model_, 0, QuantizerType::OLD_QUANTIZER).ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == tflite::BuiltinOperator_GATHER) { + uint32_t input_tensor_index = op->inputs()->Get(0); + const auto weights_tensor = + quantized_graph->tensors()->Get(input_tensor_index); + EXPECT_EQ(weights_tensor->type(), TensorType_INT8); + } + } + } +} + +TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationDequantize) { + LoadCustomOpTestModel(); + + // The custom op is not hybrid, and the second input is a constant that can + // be quantized. + CustomOpMap custom_op_map; + custom_op_map["CustomTestOp"] = { + .quantizable_input_indices = {1}, + .is_hybrid = false, + }; + + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + const auto quantized_graph = output_model->subgraphs()->Get(0); + // A dequantize op should be added. + ASSERT_EQ(quantized_graph->operators()->size(), + model_->subgraphs()->Get(0)->operators()->size() + 1); + int num_custom_ops_found = 0; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == BuiltinOperator_CUSTOM) { + uint32_t weights_tensor_index = op->inputs()->Get(1); + const auto weights_tensor = + quantized_graph->tensors()->Get(weights_tensor_index); + EXPECT_EQ(weights_tensor->type(), TensorType_FLOAT32); + + // Check that it comes from a dequantize operation. + BuiltinOperator producer_op_code; + ASSERT_TRUE(GetProducerOpCode(output_model, 0, weights_tensor_index, + &producer_op_code)); + EXPECT_EQ(producer_op_code, BuiltinOperator_DEQUANTIZE); + num_custom_ops_found++; + } + } + EXPECT_EQ(num_custom_ops_found, 1); +} + +TEST_F(QuantizeWeightsTest, VerifyCustomOpQuantizationHybrid) { + LoadCustomOpTestModel(); + + // The custom op is hybrid, and the second input is a constant that can + // be quantized. + CustomOpMap custom_op_map; + custom_op_map["CustomTestOp"] = { + .quantizable_input_indices = {1}, + .is_hybrid = true, + }; + + flatbuffers::FlatBufferBuilder builder; + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + const auto quantized_graph = output_model->subgraphs()->Get(0); + ASSERT_EQ(quantized_graph->operators()->size(), + model_->subgraphs()->Get(0)->operators()->size()); + int num_custom_ops_found = 0; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + const auto op_code = + GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)); + if (op_code == BuiltinOperator_CUSTOM) { + uint32_t weights_tensor_index = op->inputs()->Get(1); + const auto weights_tensor = + quantized_graph->tensors()->Get(weights_tensor_index); + EXPECT_EQ(weights_tensor->type(), TensorType_INT8); + num_custom_ops_found++; + } + } + EXPECT_EQ(num_custom_ops_found, 1); +} + +TEST_F(QuantizeWeightsTest, VerifyUpdatedHybridSchemeFalseQuantizationHybrid) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + const CustomOpMap custom_op_map; + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + /*use_updated_hybrid_scheme=*/false, + /*op_denylist=*/{}, QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + // Nothing should change. + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + subgraph_idx++) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size()); + // Make sure the graph only has one Conv operation. + ASSERT_EQ(quantized_graph->operators()->size(), 1); + const auto op = quantized_graph->operators()->Get(0); + const uint32_t op_code_idx = op->opcode_index(); + ASSERT_EQ(GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)), + BuiltinOperator_CONV_2D); + for (size_t i = 0; i < quantized_graph->tensors()->size(); i++) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + const auto float_tensor = float_graph->tensors()->Get(i); + EXPECT_EQ(quant_tensor->buffer(), float_tensor->buffer()); + EXPECT_EQ(quant_tensor->is_variable(), float_tensor->is_variable()); + EXPECT_EQ(GetAsVector(quant_tensor->shape()), + GetAsVector(float_tensor->shape())); + EXPECT_EQ(quant_tensor->name()->str(), float_tensor->name()->str()); + // If the tensor is a weight, it should have type INT8, otherwise it + // should stay with type FLOAT32. + // If the tensor is a bias, it should have type FLOAT32. + if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->buffer() != 0) { + EXPECT_EQ(quant_tensor->type(), TensorType_INT8) + << quant_tensor->name()->str(); + auto shape = GetAsVector(quant_tensor->shape()); + EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 1); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +TEST_F(QuantizeWeightsTest, DequantizeConvBlocklisted) { + LoadBasicModel(); + flatbuffers::FlatBufferBuilder builder; + const CustomOpMap custom_op_map; + ASSERT_TRUE(QuantizeWeights(&builder, model_, 0, custom_op_map, + /*use_updated_hybrid_scheme=*/true, + /*op_denylist*/ {BuiltinOperator_CONV_2D}, + QuantizerType::OLD_QUANTIZER) + .ok()); + + const uint8_t* buffer = builder.GetBufferPointer(); + const Model* output_model = GetModel(buffer); + ASSERT_TRUE(output_model); + + ASSERT_EQ(output_model->subgraphs()->size(), model_->subgraphs()->size()); + for (size_t subgraph_idx = 0; subgraph_idx < model_->subgraphs()->size(); + ++subgraph_idx) { + const auto quantized_graph = output_model->subgraphs()->Get(subgraph_idx); + const auto float_graph = model_->subgraphs()->Get(subgraph_idx); + // The output graph should have an extra tensor from the added dequantize + // op. + ASSERT_EQ(quantized_graph->tensors()->size(), + float_graph->tensors()->size() + 1); + // Check that a dequantize op exists. + int32_t dequant_input_idx = -1; + int32_t dequant_output_idx = -1; + for (size_t i = 0; i < quantized_graph->operators()->size(); ++i) { + const auto op = quantized_graph->operators()->Get(i); + const uint32_t op_code_idx = op->opcode_index(); + if (GetBuiltinCode(output_model->operator_codes()->Get(op_code_idx)) == + BuiltinOperator_DEQUANTIZE) { + dequant_input_idx = op->inputs()->Get(0); + dequant_output_idx = op->outputs()->Get(0); + } + } + ASSERT_GT(dequant_input_idx, -1); + ASSERT_GT(dequant_output_idx, -1); + for (size_t i = 0; i < quantized_graph->tensors()->size(); ++i) { + const auto quant_tensor = quantized_graph->tensors()->Get(i); + // If the tensor is a weight, it should have type INT8. + // If the tensor is a bias, it should have type FLOAT32. + // If the tensor is an input or output it should have type FLOAT32. + // The input to dequantize should be INT8, and all other tensors should be + // FLOAT32. + if (i == dequant_input_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_INT8); + // The dequantize should still be quantized per-channel + EXPECT_EQ(quant_tensor->quantization()->scale()->size(), 5); + EXPECT_EQ(quant_tensor->quantization()->quantized_dimension(), 0); + } else if (i == dequant_output_idx) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (IsModelInputOrOutput(output_model, i)) { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->name()->str() == "conv_bias") { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } else if (quant_tensor->buffer() != 0) { + // If it's a non-bias constant tensor, it must be the weight. + EXPECT_EQ(quant_tensor->type(), TensorType_INT8); + } else { + EXPECT_EQ(quant_tensor->type(), TensorType_FLOAT32); + } + } + } +} + +} // namespace +} // namespace toco_legacy +} // namespace lite +} // namespace mlir + +int main(int argc, char** argv) { + std::string model_file; + const std::vector flag_list = { + tsl::Flag("test_model_file", &model_file, + "Path to test tflite model file."), + }; + + const bool parse_result = tsl::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + std::cerr << "Required test_model_file\n"; + std::abort(); + } + g_test_model_dir = new std::string(tsl::io::Dirname(model_file)); + ::tsl::port::InitMain(argv[0], &argc, &argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/mlir/lite/schema/BUILD b/tensorflow/compiler/mlir/lite/schema/BUILD index 7dd8eecb66d935..14e80f4f363a15 100644 --- a/tensorflow/compiler/mlir/lite/schema/BUILD +++ b/tensorflow/compiler/mlir/lite/schema/BUILD @@ -11,7 +11,15 @@ package( ) exports_files( - srcs = ["schema.fbs"], + srcs = [ + "schema.fbs", + "schema_v0.fbs", + "schema_v1.fbs", + "schema_v2.fbs", + "schema_v3.fbs", + "schema_v3a.fbs", + "schema_v3b.fbs", + ], ) filegroup( diff --git a/tensorflow/lite/schema/schema_v0.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v0.fbs similarity index 100% rename from tensorflow/lite/schema/schema_v0.fbs rename to tensorflow/compiler/mlir/lite/schema/schema_v0.fbs diff --git a/tensorflow/lite/schema/schema_v1.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v1.fbs similarity index 100% rename from tensorflow/lite/schema/schema_v1.fbs rename to tensorflow/compiler/mlir/lite/schema/schema_v1.fbs diff --git a/tensorflow/lite/schema/schema_v2.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v2.fbs similarity index 100% rename from tensorflow/lite/schema/schema_v2.fbs rename to tensorflow/compiler/mlir/lite/schema/schema_v2.fbs diff --git a/tensorflow/lite/schema/schema_v3.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3.fbs similarity index 100% rename from tensorflow/lite/schema/schema_v3.fbs rename to tensorflow/compiler/mlir/lite/schema/schema_v3.fbs diff --git a/tensorflow/lite/schema/schema_v3a.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3a.fbs similarity index 100% rename from tensorflow/lite/schema/schema_v3a.fbs rename to tensorflow/compiler/mlir/lite/schema/schema_v3a.fbs diff --git a/tensorflow/lite/schema/schema_v3c.fbs b/tensorflow/compiler/mlir/lite/schema/schema_v3c.fbs similarity index 100% rename from tensorflow/lite/schema/schema_v3c.fbs rename to tensorflow/compiler/mlir/lite/schema/schema_v3c.fbs diff --git a/tensorflow/compiler/mlir/lite/sparsity/BUILD b/tensorflow/compiler/mlir/lite/sparsity/BUILD index 92e6aaadff6d46..7c15ac2756e6ba 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/BUILD +++ b/tensorflow/compiler/mlir/lite/sparsity/BUILD @@ -54,9 +54,9 @@ tf_cc_test( ], deps = [ ":sparsify_model", + "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", - "//tensorflow/lite/core:model_builder", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@flatbuffers", diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc index 0d1339d710d938..cc557b5f4d5112 100644 --- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc +++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model_test.cc @@ -27,9 +27,9 @@ limitations under the License. #include #include "absl/status/status.h" #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h" -#include "tensorflow/lite/core/model_builder.h" namespace mlir { namespace lite { @@ -41,7 +41,7 @@ TEST(SparsifyModelTest, MetadataIsAddedToOutputModel) { std::string expected_value = "test_data"; // Load input model - auto input_fbm = tflite::FlatBufferModel::BuildFromFile( + auto input_fbm = mlir::TFL::FlatBufferModelAbslError::BuildFromFile( "tensorflow/compiler/mlir/lite/sparsity/testdata/" "sparse_tensor.bin"); tflite::ModelT input_model; @@ -60,7 +60,7 @@ TEST(SparsifyModelTest, MetadataIsAddedToOutputModel) { // Sparsify and create output model flatbuffers::FlatBufferBuilder output_builder; ASSERT_TRUE(SparsifyModel(input_model, &output_builder).ok()); - auto output_fbm = tflite::FlatBufferModel::BuildFromBuffer( + auto output_fbm = mlir::TFL::FlatBufferModelAbslError::BuildFromBuffer( reinterpret_cast(output_builder.GetCurrentBufferPointer()), output_builder.GetSize()); tflite::ModelT output_model; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 3fe45b85d7bba9..067d1cc185c2c6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -549,7 +549,6 @@ cc_library( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], alwayslink = 1, @@ -641,10 +640,15 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:custom_call", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:dot_general", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:gather", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:iota", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:pad", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce_window", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:slice", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:sort", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:util", "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -688,9 +692,11 @@ cc_library( deps = [ ":passes_inc_gen", ":prepare_hlo_inc_gen", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:conv_util", "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:pad_util", - "@llvm-project//llvm:Support", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:reduce_window", + "//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions:slice", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -797,6 +803,39 @@ cc_library( ], ) +cc_library( + name = "lift_callsite_loc_caller", + srcs = ["transforms/torch/lift_callsite_loc_caller_pass.cc"], + copts = ["-Ithird_party"], + deps = [ + ":passes_inc_gen", + ":prepare_hlo", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], + alwayslink = True, +) + +cc_library( + name = "build_stablehlo_composite", + srcs = ["transforms/torch/build_stablehlo_composite_pass.cc"], + copts = ["-Ithird_party"], + deps = [ + ":passes_inc_gen", + ":prepare_hlo", + "@com_google_absl//absl/strings", + "@jsoncpp_git//:jsoncpp", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], + alwayslink = True, +) + cc_library( name = "composite_lowering", srcs = [ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir index 7db47a1a3e7703..a06886a8d4688a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/optimize.mlir @@ -408,3 +408,107 @@ func.func @testConvertReshapeDotRhsToBatchedDot(%arg0: tensor<1x72x72xf32>, %arg // CHECK-SAME: >}> : (tensor<1x72x72xf32>, tensor<1x72x128xf32>) -> tensor<1x72x128xf32> // CHECK: return %[[R]] : tensor<1x72x128xf32> } + +// ----- + +// CHECK-LABEL: broadcast_reshape_one_non_unit_dimnsion +func.func @broadcast_reshape_one_non_unit_dimnsion(%arg0: tensor<1x1x1x63xf32>) -> tensor<32x1x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x63xf32> + return %1 : tensor<32x1x63xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<1x1x1x63xf32>) -> tensor<63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<63xf32>) -> tensor<32x1x63xf32> +// CHECK: return %1 : tensor<32x1x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_one_non_unit_dimnsion_trailing_zeros +func.func @broadcast_reshape_one_non_unit_dimnsion_trailing_zeros(%arg0: tensor<63x1x1x1xf32>) -> tensor<63x1x2xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<63x1x1x1xf32>) -> tensor<63x1x1x2xf32> + %1 = mhlo.reshape %0 : (tensor<63x1x1x2xf32>) -> tensor<63x1x2xf32> + return %1 : tensor<63x1x2xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<63x1x1x1xf32>) -> tensor<63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<63xf32>) -> tensor<63x1x2xf32> +// CHECK: return %1 : tensor<63x1x2xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_multiple_non_unit_dimension +func.func @broadcast_reshape_multiple_non_unit_dimension(%arg0: tensor<1x2x1x63xf32>) -> tensor<2x3x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x2x1x63xf32>) -> tensor<1x2x3x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x2x3x63xf32>) -> tensor<2x3x63xf32> + return %1 : tensor<2x3x63xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<1x2x1x63xf32>) -> tensor<2x63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>}> : (tensor<2x63xf32>) -> tensor<2x3x63xf32> +// CHECK: return %1 : tensor<2x3x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_multiple_non_unit_dimension_unsorted_broadcast_dims +func.func @broadcast_reshape_multiple_non_unit_dimension_unsorted_broadcast_dims(%arg0: tensor<1x2x1x63xf32>) -> tensor<3x2x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 2, 1, 3]> : tensor<4xi64>}> : (tensor<1x2x1x63xf32>) -> tensor<3x1x2x63xf32> + %1 = mhlo.reshape %0 : (tensor<3x1x2x63xf32>) -> tensor<3x2x63xf32> + return %1 : tensor<3x2x63xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<1x2x1x63xf32>) -> tensor<2x63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<2x63xf32>) -> tensor<3x2x63xf32> +// CHECK: return %1 : tensor<3x2x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_broadcast_increases_rank +func.func @broadcast_reshape_broadcast_increases_rank(%arg0: tensor<1x2x1x63xf32>) -> tensor<2x3x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 4]> : tensor<4xi64>}> : (tensor<1x2x1x63xf32>) -> tensor<1x2x3x1x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x2x3x1x63xf32>) -> tensor<2x3x63xf32> + return %1 : tensor<2x3x63xf32> +} + +// CHECK: %0 = mhlo.reshape %arg0 : (tensor<1x2x1x63xf32>) -> tensor<2x63xf32> +// CHECK: %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[0, 2]> : tensor<2xi64>}> : (tensor<2x63xf32>) -> tensor<2x3x63xf32> +// CHECK: return %1 : tensor<2x3x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_not_same_non_unit_dims +func.func @broadcast_reshape_not_same_non_unit_dims(%arg0: tensor<63x1x1x1xf32>) -> tensor<2x1x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<63x1x1x1xf32>) -> tensor<63x1x1x2xf32> + %1 = mhlo.reshape %0 : (tensor<63x1x1x2xf32>) -> tensor<2x1x63xf32> + return %1 : tensor<2x1x63xf32> +} + +// CHECK: %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<63x1x1x1xf32>) -> tensor<63x1x1x2xf32> +// CHECK: %1 = mhlo.reshape %0 : (tensor<63x1x1x2xf32>) -> tensor<2x1x63xf32> +// CHECK: return %1 : tensor<2x1x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_multi_use +func.func @broadcast_reshape_multi_use(%arg0: tensor<1x1x1x63xf32>) -> (tensor<32x1x63xf32>, tensor<1x32x1x63xf32>) { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x63xf32> + return %1, %0 : tensor<32x1x63xf32>, tensor<1x32x1x63xf32> +} + +// CHECK: %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> +// CHECK: %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x63xf32> + +// ----- + +// CHECK-LABEL: broadcast_reshape_rank_increase +func.func @broadcast_reshape_rank_increase(%arg0: tensor<1x1x1x63xf32>) -> tensor<32x1x1x1x1x63xf32> { + %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> + %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x1x1x1x63xf32> + return %1 : tensor<32x1x1x1x1x63xf32> +} + +// CHECK: %0 = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}> : (tensor<1x1x1x63xf32>) -> tensor<1x32x1x63xf32> +// CHECK: %1 = mhlo.reshape %0 : (tensor<1x32x1x63xf32>) -> tensor<32x1x1x1x1x63xf32> + + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir index 411680f28b2c97..38fc6d57d93015 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir @@ -15,6 +15,75 @@ func.func @main(%arg0: tensor) -> tensor { // 2D //=-- +// CHECK-LABEL: transpose_conv2d_same_padding_nchw_ihwo +func.func @transpose_conv2d_same_padding_nchw_ihwo(%input: tensor<1x2x256x256xf32>, %filter:tensor<2x2x4x4xf32>) -> tensor<1x2x512x512xf32> { + %1 = mhlo.convolution(%input, %filter) + dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], + window = {pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x2x256x256xf32>, tensor<2x2x4x4xf32>) -> tensor<1x2x512x512xf32> + func.return %1 : tensor<1x2x512x512xf32> +} + +// CHECK: %[[TRANSPOSED_INPUT:.*]] = "mhlo.transpose"(%arg0) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 2, 3, 1] +// CHECK: %[[TRANSPOSED_KERNEL:.*]] = "mhlo.transpose"(%arg1) +// CHECK-SAME: permutation +// CHECK-SAME: [1, 2, 3, 0] +// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution(%[[TRANSPOSED_INPUT]], %[[TRANSPOSED_KERNEL]]) +// CHECK-SAME: [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// CHECK: "mhlo.transpose"(%[[CONV_OUT]]) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 3, 1, 2] + +// CHECK-LABEL: transpose_conv2d_same_padding_nchw_oihw +func.func @transpose_conv2d_same_padding_nchw_oihw(%input: tensor<1x2x256x256xf32>, %filter:tensor<2x2x4x4xf32>) -> tensor<1x2x512x512xf32> { + %0 = mhlo.convolution(%input, %filter) + dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], + window = {pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2]} { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<1x2x256x256xf32>, tensor<2x2x4x4xf32>) -> tensor<1x2x512x512xf32> + func.return %0 : tensor<1x2x512x512xf32> +} + +// CHECK: %[[TRANSPOSED_INPUT:.*]] = "mhlo.transpose"(%arg0) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 2, 3, 1] +// CHECK: %[[TRANSPOSED_KERNEL:.*]] = "mhlo.transpose"(%arg1) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 2, 3, 1] +// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution(%[[TRANSPOSED_INPUT]], %[[TRANSPOSED_KERNEL]]) +// CHECK-SAME: [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// CHECK: "mhlo.transpose"(%[[CONV_OUT]]) +// CHECK-SAME: permutation +// CHECK-SAME: [0, 3, 1, 2] + +// ----- + +// CHECK-LABEL: depthwise_transpose_conv2d_same_padding_nchw_hwoi +func.func @depthwise_transpose_conv2d_same_padding_nchw_hwoi(%input: tensor<1x2x20x20xf32>, %filter:tensor<8x8x2x1xf32>) -> tensor<1x2x80x80xf32> { + %1 = mhlo.convolution(%input, %filter) + dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], + window = {pad = [[5, 5], [5, 5]], lhs_dilate = [4, 4]} + {batch_group_count = 1 : i64, feature_group_count = 2 : i64} + : (tensor<1x2x20x20xf32>, tensor<8x8x2x1xf32>) -> tensor<1x2x80x80xf32> + func.return %1 : tensor<1x2x80x80xf32> + + // CHECK: %0 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<1x2x20x20xf32>) -> tensor<1x20x20x2xf32> + // CHECK: %1 = "mhlo.transpose"(%arg1) <{permutation = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : (tensor<8x8x2x1xf32>) -> tensor<2x8x8x1xf32> + // CHECK: %2 = "mhlo.slice"(%0) <{limit_indices = dense<[1, 20, 20, 1]> : tensor<4xi64>, start_indices = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<4xi64>}> : (tensor<1x20x20x2xf32>) -> tensor<1x20x20x1xf32> + // CHECK: %3 = "mhlo.slice"(%1) <{limit_indices = dense<[1, 8, 8, 1]> : tensor<4xi64>, start_indices = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<4xi64>}> : (tensor<2x8x8x1xf32>) -> tensor<1x8x8x1xf32> + // CHECK: %4 = mhlo.convolution(%2, %3) dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], window = {pad = {{\[\[}}5, 5], [5, 5]], lhs_dilate = [4, 4]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x20x20x1xf32>, tensor<1x8x8x1xf32>) -> tensor<1x80x80x1xf32> + // CHECK: %5 = "mhlo.slice"(%0) <{limit_indices = dense<[1, 20, 20, 2]> : tensor<4xi64>, start_indices = dense<[0, 0, 0, 1]> : tensor<4xi64>, strides = dense<1> : tensor<4xi64>}> : (tensor<1x20x20x2xf32>) -> tensor<1x20x20x1xf32> + // CHECK: %6 = "mhlo.slice"(%1) <{limit_indices = dense<[2, 8, 8, 1]> : tensor<4xi64>, start_indices = dense<[1, 0, 0, 0]> : tensor<4xi64>, strides = dense<1> : tensor<4xi64>}> : (tensor<2x8x8x1xf32>) -> tensor<1x8x8x1xf32> + // CHECK: %7 = mhlo.convolution(%5, %6) dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], window = {pad = {{\[\[}}5, 5], [5, 5]], lhs_dilate = [4, 4]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x20x20x1xf32>, tensor<1x8x8x1xf32>) -> tensor<1x80x80x1xf32> + // CHECK: %8 = "mhlo.concatenate"(%4, %7) <{dimension = 3 : i64}> : (tensor<1x80x80x1xf32>, tensor<1x80x80x1xf32>) -> tensor<1x80x80x2xf32> + // CHECK: %9 = "mhlo.transpose"(%8) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<1x80x80x2xf32>) -> tensor<1x2x80x80xf32> + // CHECK: return %9 : tensor<1x2x80x80xf32> +} + // CHECK-LABEL: conv2d_nhwc_ohwi_nhwc func.func @conv2d_nhwc_ohwi_nhwc(%input: tensor<1x256x256x3xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> { %0 = mhlo.convolution(%input, %filter) @@ -141,105 +210,180 @@ func.func @conv2d_nchw_oihw_nchw(%input: tensor<1x3x256x256xf32>, %filter: tenso // ----- -// 1D -//=-- +// CHECK-LABEL: conv2d_nhwc_ohwi_nhwc_padded +func.func @conv2d_nhwc_ohwi_nhwc_padded(%input: tensor<1x254x254x3xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> { + %0 = "mhlo.convolution"(%input, %filter) { + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]>, + batch_group_count = 1 : i64, + feature_group_count = 1 : i64, + window_strides = dense<1> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 1]> : tensor<2xi64>, + lhs_dilation = dense<[1, 1]> : tensor<2xi64> + } : (tensor<1x254x254x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> + func.return %0 : tensor<1x256x256x2xf32> +} -// CHECK-LABEL: conv1d_nsc_osi_nsc -func.func @conv1d_nsc_osi_nsc(%arg0: tensor<16x32x256xf32>, %arg1: tensor<256x1x256xf32>) -> tensor<16x32x256xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) { +// CHECK: %[[PADDED_LHS:.*]] = "mhlo.pad" +// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<0> +// CHECK: mhlo.convolution(%[[PADDED_LHS]] +// CHECK-SAME: pad +// CHECK-SAME: [0, 0], [0, 0] +// CHECK-SAME: (tensor<1x256x256x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> + +// ----- + +// CHECK-LABEL: conv2d_nhwc_ohwi_nhwc_asymmetric_padded +func.func @conv2d_nhwc_ohwi_nhwc_asymmetric_padded(%input: tensor<1x255x255x3xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> { + %0 = "mhlo.convolution"(%input, %filter) { + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]>, batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv<[b, 0, f]x[o, 0, i]->[b, 0, f]>, - feature_group_count = 1 : i64 - } : (tensor<16x32x256xf32>, tensor<256x1x256xf32>) -> tensor<16x32x256xf32> - func.return %0 : tensor<16x32x256xf32> + feature_group_count = 1 : i64, + window_strides = dense<1> : tensor<2xi64>, + padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, + rhs_dilation = dense<[1, 1]> : tensor<2xi64>, + lhs_dilation = dense<[1, 1]> : tensor<2xi64> + } : (tensor<1x255x255x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> + func.return %0 : tensor<1x256x256x2xf32> } -// CHECK-NOT: transpose -// CHECK: [b, 0, f]x[o, 0, i]->[b, 0, f] -// CHECK-NOT: transpose +// CHECK: %[[PADDED_LHS:.*]] = "mhlo.pad" +// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<0> +// CHECK-SAME: interior_padding = dense<0> +// CHECK: mhlo.convolution(%[[PADDED_LHS]] +// CHECK-SAME: pad +// CHECK-SAME: [0, 0], [0, 0] +// CHECK-SAME: (tensor<1x256x256x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> + // ----- -// CHECK-LABEL: conv1d_ncs_osi_nsc -func.func @conv1d_ncs_osi_nsc(%arg0: tensor<16x256x32xf32>, %arg1: tensor<256x1x256xf32>) -> tensor<16x32x256xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) { +// CHECK-LABEL: conv2d_nchw_ohwi_nhwc_padded +func.func @conv2d_nchw_ohwi_nhwc_padded(%input: tensor<1x3x253x249xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> { + %0 = mhlo.convolution(%input, %filter) + dim_numbers = [b, f, 0, 1]x[o, 0, 1, i]->[b, 0, 1, f], + window = {stride = [1, 1], pad = [[1, 2], [3, 4]]} { batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv<[b, f, 0]x[o, 0, i]->[b, 0, f]>, feature_group_count = 1 : i64 - } : (tensor<16x256x32xf32>, tensor<256x1x256xf32>) -> tensor<16x32x256xf32> - func.return %0 : tensor<16x32x256xf32> + } : (tensor<1x3x253x249xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> + func.return %0 : tensor<1x256x256x2xf32> } -// CHECK: %[[TRANSPOSED_INPUT:.*]] = "mhlo.transpose"(%arg0) -// CHECK-SAME: permutation -// CHECK-SAME: [0, 2, 1] -// CHECK: mhlo.convolution(%[[TRANSPOSED_INPUT]], %arg1) -// CHECK-SAME: [b, 0, f]x[o, 0, i]->[b, 0, f] -// CHECK-NOT: transpose +// Want to ensure that we transpose before padding input (which this test does implicitly). + +// CHECK: %[[PADDED_LHS:.*]] = "mhlo.pad" +// CHECK-SAME: edge_padding_high = dense<[0, 2, 4, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 1, 3, 0]> +// CHECK-SAME: interior_padding = dense<0> +// CHECK: mhlo.convolution(%[[PADDED_LHS]], %arg1) +// CHECK-SAME: pad +// CHECK-SAME: [0, 0], [0, 0] +// CHECK-SAME: (tensor<1x256x256x3xf32>, tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> // ----- -// CHECK-LABEL: conv1d_nsc_sio_nsc -func.func @conv1d_nsc_sio_nsc(%arg0: tensor<16x32x256xf32>, %arg1: tensor<1x256x256xf32>) -> tensor<16x32x256xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) { +// CHECK-LABEL: conv2d_nchw_ohwi_nhwc_padded_dilated_lhs +func.func @conv2d_nchw_ohwi_nhwc_padded_dilated_lhs(%input: tensor<1x64x64x256xf32>, %filter: tensor<64x2x2x256xf32>) -> tensor<1x128x128x64xf32> { + %0 = "mhlo.convolution"(%input, %filter) { batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, - feature_group_count = 1 : i64 - } : (tensor<16x32x256xf32>, tensor<1x256x256xf32>) -> tensor<16x32x256xf32> - func.return %0 : tensor<16x32x256xf32> + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + lhs_dilation = dense<2> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<1> : tensor<2xi64>, + window_reversal = dense : tensor<2xi1>, + window_strides = dense<1> : tensor<2xi64>} : + (tensor<1x64x64x256xf32>, tensor<64x2x2x256xf32>) -> tensor<1x128x128x64xf32> + func.return %0 : tensor<1x128x128x64xf32> +} + +// CHECK-NOT: mhlo.pad +// CHECK: mhlo.convolution +// CHECK-SAME: pad +// CHECK-SAME: [1, 1], [1, 1] +// CHECK-SAME: lhs_dilate = [2, 2] + +// ----- + +// CHECK-LABEL: depthwise_conv2d_nhwc_ohwi_nhwc +func.func @depthwise_conv2d_nhwc_ohwi_nhwc(%arg0: tensor<1x10x10x207xf32>, %arg1: tensor<3312x3x3x1xf32>) -> tensor<1x8x8x3312xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]>, + feature_group_count = 207 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<1x10x10x207xf32>, tensor<3312x3x3x1xf32>) -> tensor<1x8x8x3312xf32> + func.return %0 : tensor<1x8x8x3312xf32> } // CHECK: %[[TRANSPOSED_KERNEL:.*]] = "mhlo.transpose"(%arg1) // CHECK-SAME: permutation -// CHECK-SAME: [2, 0, 1] +// CHECK-SAME: [3, 1, 2, 0] // CHECK: mhlo.convolution(%arg0, %[[TRANSPOSED_KERNEL]]) -// CHECK-SAME: [b, 0, f]x[o, 0, i]->[b, 0, f] +// CHECK-SAME: [b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f] // CHECK-NOT: transpose // ----- -// CHECK-LABEL: conv1d_nsc_osi_ncs -func.func @conv1d_nsc_osi_ncs(%arg0: tensor<16x32x256xf32>, %arg1: tensor<256x1x256xf32>) -> tensor<16x256x32xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) { + +// CHECK-LABEL: depthwise_conv2d_nchw_ihwo_nhwc +func.func @depthwise_conv2d_nchw_ihwo_nhwc(%arg0: tensor<1x207x10x10xf32>, %arg1: tensor<1x3x3x3312xf32>) -> tensor<1x8x8x3312xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv<[b, 0, f]x[o, 0, i]->[b, f, 0]>, - feature_group_count = 1 : i64 - } : (tensor<16x32x256xf32>, tensor<256x1x256xf32>) -> tensor<16x256x32xf32> - func.return %0 : tensor<16x256x32xf32> + dimension_numbers = #mhlo.conv<[b, f, 0, 1]x[i, 0, 1, o]->[b, 0, 1, f]>, + feature_group_count = 207 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<1x207x10x10xf32>, tensor<1x3x3x3312xf32>) -> tensor<1x8x8x3312xf32> + func.return %0 : tensor<1x8x8x3312xf32> } -// CHECK-NOT: transpose -// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution -// CHECK-SAME: [b, 0, f]x[o, 0, i]->[b, 0, f] -// CHECK: "mhlo.transpose"(%[[CONV_OUT]]) +// CHECK: %[[TRANSPOSED_INPUT:.*]] = "mhlo.transpose"(%arg0) // CHECK-SAME: permutation -// CHECK-SAME: [0, 2, 1] - +// CHECK-SAME: [0, 2, 3, 1] +// CHECK: mhlo.convolution(%[[TRANSPOSED_INPUT]], %arg1) +// CHECK-SAME: [b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f] +// CHECK-NOT: transpose // ----- -// CHECK-LABEL: conv1d_ncs_ois_ncs -func.func @conv1d_ncs_ois_ncs(%arg0: tensor<16x256x32xf32>, %arg1: tensor<256x256x1xf32>) -> tensor<16x256x32xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) { +// CHECK-LABEL: depthwise_conv2d_nchw_ihwo_nhwc_padded +func.func @depthwise_conv2d_nchw_ihwo_nhwc_padded(%arg0: tensor<1x207x8x8xf32>, %arg1: tensor<1x3x3x3312xf32>) -> tensor<1x8x8x3312xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv<[b, f, 0]x[o, i, 0]->[b, f, 0]>, - feature_group_count = 1 : i64 - } : (tensor<16x256x32xf32>, tensor<256x256x1xf32>) -> tensor<16x256x32xf32> - func.return %0 : tensor<16x256x32xf32> + dimension_numbers = #mhlo.conv<[b, f, 0, 1]x[i, 0, 1, o]->[b, 0, 1, f]>, + feature_group_count = 207 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<1> : tensor<2x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<1x207x8x8xf32>, tensor<1x3x3x3312xf32>) -> tensor<1x8x8x3312xf32> + func.return %0 : tensor<1x8x8x3312xf32> } // CHECK: %[[TRANSPOSED_INPUT:.*]] = "mhlo.transpose"(%arg0) // CHECK-SAME: permutation -// CHECK-SAME: [0, 2, 1] -// CHECK: %[[TRANSPOSED_KERNEL:.*]] = "mhlo.transpose"(%arg1) -// CHECK-SAME: permutation -// CHECK-SAME: [0, 2, 1] -// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution(%[[TRANSPOSED_INPUT]], %[[TRANSPOSED_KERNEL]]) -// CHECK-SAME: [b, 0, f]x[o, 0, i]->[b, 0, f] -// CHECK: "mhlo.transpose"(%[[CONV_OUT]]) -// CHECK-SAME: permutation -// CHECK-SAME: [0, 2, 1] - +// CHECK-SAME: [0, 2, 3, 1] +// CHECK: %[[PADDED_LHS:.*]] = "mhlo.pad"(%[[TRANSPOSED_INPUT]] +// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<0> +// CHECK: mhlo.convolution(%[[PADDED_LHS]], %arg1) +// CHECK-SAME: [b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f] +// CHECK-SAME: pad = +// CHECK-SAME: [0, 0], [0, 0] // ----- @@ -320,6 +464,119 @@ func.func @conv3d_ndhwc_dhwio_ncdhw(%arg0: tensor<1x8x8x32x207xf32>, %arg1: tens // ----- +// CHECK-LABEL: conv3d_ndhwc_dhwio_ndhwc_padded +func.func @conv3d_ndhwc_dhwio_ndhwc_padded(%arg0: tensor<1x6x6x30x207xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<1x6x6x1x16xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f]>, + feature_group_count = 1 : i64, + padding = dense<1> : tensor<3x2xi64>} : + (tensor<1x6x6x30x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<1x6x6x1x16xf32> + func.return %0 : tensor<1x6x6x1x16xf32> +} + +// CHECK: %[[PADDED_LHS:.*]] = "mhlo.pad" +// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<0> +// CHECK: mhlo.convolution(%[[PADDED_LHS]], %arg1) +// CHECK-SAME: pad = +// CHECK-SAME: [0, 0], [0, 0], [0, 0] +// CHECK-SAME: (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<1x6x6x1x16xf32> + +// ----- + +// CHECK-LABEL: conv3d_ncdhw_dhwio_ndhwc_padded +func.func @conv3d_ncdhw_dhwio_ndhwc_padded(%arg0: tensor<1x207x6x6x30xf32>, %arg1: tensor<3x3x32x207x16xf32>) -> tensor<1x6x6x1x16xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, f, 0, 1, 2]x[0, 1, 2, i, o]->[b, 0, 1, 2, f]>, + feature_group_count = 1 : i64, + padding = dense<1> : tensor<3x2xi64>} : + (tensor<1x207x6x6x30xf32>, tensor<3x3x32x207x16xf32>) -> tensor<1x6x6x1x16xf32> + func.return %0 : tensor<1x6x6x1x16xf32> +} + +// CHECK: %[[PADDED_LHS:.*]] = "mhlo.pad" +// CHECK-SAME: edge_padding_high = dense<[0, 1, 1, 1, 0]> +// CHECK-SAME: edge_padding_low = dense<[0, 1, 1, 1, 0]> +// CHECK-SAME: interior_padding = dense<0> +// CHECK: mhlo.convolution(%[[PADDED_LHS]], %arg1) +// CHECK-SAME: pad = +// CHECK-SAME: [0, 0], [0, 0], [0, 0] +// CHECK-SAME: (tensor<1x8x8x32x207xf32>, tensor<3x3x32x207x16xf32>) -> tensor<1x6x6x1x16xf32> + +// ----- + +// 1D +//=-- + +// CHECK-LABEL: conv1d_nsc_osi_nsc +func.func @conv1d_nsc_osi_nsc(%arg0: tensor<16x32x256xf32>, %arg1: tensor<256x1x256xf32>) -> tensor<16x32x256xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, f]x[o, 0, i]->[b, 0, f]>, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<1xi64>, + padding = dense<0> : tensor<1x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<1> : tensor<1xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor<16x32x256xf32>, tensor<256x1x256xf32>) -> tensor<16x32x256xf32> + func.return %0 : tensor<16x32x256xf32> +} + +// CHECK: %[[RESHAPED_LHS:.*]] = mhlo.reshape %arg0 +// CHECK: %[[RESHAPED_RHS:.*]] = mhlo.reshape %arg1 +// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution(%[[RESHAPED_LHS]], %[[RESHAPED_RHS]]) dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// CHECK: mhlo.reshape %[[CONV_OUT]] + +// ----- + +// CHECK-LABEL: conv1d_nsc_sio_nsc +func.func @conv1d_nsc_sio_nsc(%arg0: tensor<16x32x256xf32>, %arg1: tensor<1x256x256xf32>) -> tensor<16x32x256xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, + feature_group_count = 1 : i64, + lhs_dilation = dense<1> : tensor<1xi64>, + padding = dense<0> : tensor<1x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<1> : tensor<1xi64>, + window_strides = dense<1> : tensor<1xi64> + } : (tensor<16x32x256xf32>, tensor<1x256x256xf32>) -> tensor<16x32x256xf32> + func.return %0 : tensor<16x32x256xf32> +} + +// CHECK: %[[RESHAPED_LHS:.*]] = mhlo.reshape %arg0 +// CHECK: %[[RESHAPED_RHS:.*]] = mhlo.reshape %arg1 +// CHECK: %[[TPOSED_RHS:.*]] = "mhlo.transpose"(%[[RESHAPED_RHS]]) <{permutation = dense<[3, 0, 1, 2]> : tensor<4xi64>}> : (tensor<1x1x256x256xf32>) -> tensor<256x1x1x256xf32> +// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution(%[[RESHAPED_LHS]], %[[TPOSED_RHS]]) dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// CHECK: mhlo.reshape %[[CONV_OUT]] + +// ----- + + +// CHECK-LABEL: conv1d_ncs_osi_nsc_padded +func.func @conv1d_ncs_osi_nsc_padded(%arg0: tensor<16x256x30xf32>, %arg1: tensor<256x1x256xf32>) -> tensor<16x32x256xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, f, 0]x[o, 0, i]->[b, 0, f]>, + feature_group_count = 1 : i64, + padding = dense<1> : tensor<1x2xi64> + } : (tensor<16x256x30xf32>, tensor<256x1x256xf32>) -> tensor<16x32x256xf32> + func.return %0 : tensor<16x32x256xf32> +} + +// CHECK: %[[RESHAPED_LHS:.*]] = mhlo.reshape %arg0 : (tensor<16x256x30xf32>) -> tensor<16x256x30x1xf32> +// CHECK: %[[RESHAPED_RHS:.*]] = mhlo.reshape %arg1 : (tensor<256x1x256xf32>) -> tensor<256x1x1x256xf32> +// CHECK: %[[TPOSED_LHS:.*]] = "mhlo.transpose"(%0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<16x256x30x1xf32>) -> tensor<16x30x1x256xf32> +// CHECK: %[[PADDED_LHS:.*]] = "mhlo.pad"(%[[TPOSED_LHS]], %cst) <{edge_padding_high = dense<[0, 1, 0, 0]> : tensor<4xi64>, edge_padding_low = dense<[0, 1, 0, 0]> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>}> : (tensor<16x30x1x256xf32>, tensor) -> tensor<16x32x1x256xf32> +// CHECK: %[[CONV_OUT:.*]] = mhlo.convolution(%[[PADDED_LHS]], %[[RESHAPED_RHS]]) dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// CHECK: mhlo.reshape %[[CONV_OUT]] : (tensor<16x32x1x256xf32>) -> tensor<16x32x256xf32> + +// ----- + //===----------------------------------------------------------------------===// // mhlo.pad //===----------------------------------------------------------------------===// @@ -334,11 +591,11 @@ func.func @pad_2d(%arg0: tensor<3x3xf32>, %arg1: tensor) -> tensor<4x3xf32> func.return %0 : tensor<4x3xf32> } -// CHECK: mhlo.slice +// CHECK: mhlo.slice // CHECK-SAME: limit_indices = dense<3> // CHECK-SAME: start_indices = dense<[0, 1]> // CHECK-SAME: (tensor<3x3xf32>) -> tensor<3x2xf32> -// CHECK: mhlo.pad +// CHECK: mhlo.pad // CHECK-SAME: edge_padding_high = dense<1> // CHECK-SAME: edge_padding_low = dense<0> // CHECK-SAME: (tensor<3x2xf32>, tensor) -> tensor<4x3xf32> @@ -355,11 +612,11 @@ func.func @pad_2d_negative(%arg0: tensor<3x3xf32>, %arg1: tensor) -> tensor func.return %0 : tensor<1x2xf32> } -// CHECK: mhlo.slice +// CHECK: mhlo.slice // CHECK-SAME: limit_indices = dense<[2, 3]> // CHECK-SAME: start_indices = dense<1> // CHECK-SAME: (tensor<3x3xf32>) -> tensor<1x2xf32> -// CHECK-NOT: mhlo.pad +// CHECK-NOT: mhlo.pad // ----- @@ -373,12 +630,115 @@ func.func @pad_3d_mixed(%arg0: tensor<3x3x3xf32>, %arg1: tensor) -> tensor< func.return %0 : tensor<3x3x3xf32> } -// CHECK: mhlo.slice +// CHECK: mhlo.slice // CHECK-SAME: limit_indices = dense<[3, 2, 3]> // CHECK-SAME: start_indices = dense<[1, 0, 0]> // CHECK-SAME: (tensor<3x3x3xf32>) -> tensor<2x2x3xf32> -// CHECK: mhlo.pad +// CHECK: mhlo.pad // CHECK-SAME: edge_padding_high = dense<[1, 0, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 1, 0]> // CHECK-SAME: (tensor<2x2x3xf32>, tensor) -> tensor<3x3x3xf32> +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.reduce_window +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: reduce_window_valid_channel_first +func.func @reduce_window_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> { + // "0xFF800000" represents -INF for f32. + %0 = mhlo.constant dense<0xFF800000> : tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %2 = mhlo.maximum %arg1, %arg2 : tensor + mhlo.return %2 : tensor + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, + window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<4x3x16x16xf32>, tensor) -> tensor<4x3x7x7xf32> + func.return %1 : tensor<4x3x7x7xf32> +} + +// CHECK: %[[INIT_CST:.*]] = mhlo.constant dense<0xFF800000> : tensor +// CHECK: %[[TPOSE_IN:.*]] = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32> +// CHECK: %[[RW:.*]] = "mhlo.reduce_window"(%[[TPOSE_IN]], %[[INIT_CST]]) +// CHECK-SAME: window_dimensions = dense<[1, 3, 3, 1]> +// CHECK-SAME: window_strides = dense<[1, 2, 2, 1]> +// CHECK: %3 = "mhlo.transpose"(%[[RW]]) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32> + +// ----- + +// CHECK-LABEL: reduce_window_same_channel_first +func.func @reduce_window_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> { + // "0xFF800000" represents -INF for f32. + %0 = mhlo.constant dense<0xFF800000> : tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.maximum %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 0], [0, 1], [0, 1]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, + window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor<4x3x16x16xf32>, tensor) -> tensor<4x3x8x8xf32> + func.return %1 : tensor<4x3x8x8xf32> +} + +// CHECK: %[[INIT_CST:.*]] = mhlo.constant dense<0xFF800000> : tensor +// CHECK: %[[TPOSE_IN:.*]] = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32> +// CHECK: %[[RW:.*]] = "mhlo.reduce_window"(%[[TPOSE_IN]], %[[INIT_CST]]) +// CHECK-SAME: padding +// CHECK-SAME: [0, 0], [0, 1], [0, 1], [0, 0] +// CHECK-SAME: window_dimensions = dense<[1, 3, 3, 1]> +// CHECK-SAME: window_strides = dense<[1, 2, 2, 1]> +// CHECK: %3 = "mhlo.transpose"(%[[RW]]) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.dynamic_slice +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: dynamic_slice +func.func @dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<4x2xf32> { + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + func.return %0 : tensor<4x2xf32> +} + +// CHECK: mhlo.dynamic_slice +// CHECK-SAME: (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + +// ----- + +// CHECK-LABEL: dynamic_slice_ui32 +func.func @dynamic_slice_ui32(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<4x2xf32> { + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + func.return %0 : tensor<4x2xf32> +} + +// CHECK: mhlo.dynamic_slice +// CHECK-SAME: (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + +// CHECK-LABEL: dynamic_slice_ui64 +func.func @dynamic_slice_ui64(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<4x2xf32> { + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + func.return %0 : tensor<4x2xf32> +} + +// CHECK: mhlo.dynamic_slice +// CHECK-SAME: (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + +// ----- + +// CHECK-LABEL: dynamic_slice_i64 +func.func @dynamic_slice_i64(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<4x2xf32> { + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + func.return %0 : tensor<4x2xf32> +} + +// CHECK: mhlo.dynamic_slice +// CHECK-SAME: (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index fdd1f98b9756b9..60e40cf1082419 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -72,13 +72,13 @@ func.func @dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) func.return %0 : tensor<3x5x1x4xf32> } -// CHECK: %[[TRANSPOSED_0:.*]] = "tfl.transpose" -// CHECK: %[[TRANSPOSED_1:.*]] = "tfl.transpose" -// CHECK-NEXT: %[[RESHAPED_0:.*]] = mhlo.reshape %[[TRANSPOSED_0]] -// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %[[TRANSPOSED_1]] -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> -// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] -// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32> +// CHECK: %[[TRANSPOSED_0:.*]] = "tfl.transpose" +// CHECK: %[[TRANSPOSED_1:.*]] = "tfl.transpose" +// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%[[TRANSPOSED_0]] +// CHECK: %[[RESHAPED_1:.*]] = "tfl.reshape"(%[[TRANSPOSED_1]] +// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<3x5x12xf32>, tensor<3x12x4xf32>) -> tensor<3x5x4xf32> +// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]] +// CHECK: return %[[RESHAPED_BMM]] : tensor<3x5x1x4xf32> // ----- @@ -96,11 +96,10 @@ func.func @dot_general_repeated(%arg0: tensor<1x1x1024xf32>, %arg1: tensor<1024x func.return %0 : tensor<1x1x1024xf32> } -// CHECK: %[[RESHAPED_0:.*]] = mhlo.reshape %arg0 -// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1 -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32> -// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] -// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32> +// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%arg0 +// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x1024xf32> +// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]] +// CHECK: return %[[RESHAPED_BMM]] : tensor<1x1x1024xf32> // ----- @@ -115,11 +114,10 @@ func.func @dot_general_int8(%arg0: tensor<256xi8>, %arg1: tensor<256x8xi8>) -> t func.return %0 : tensor<8xi32> } -// CHECK: %[[RESHAPED_0:.*]] = mhlo.reshape %arg0 -// CHECK-NEXT: %[[RESHAPED_1:.*]] = mhlo.reshape %arg1 -// CHECK-NEXT: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %[[RESHAPED_1]]) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32> -// CHECK-NEXT: %[[RESHAPED_BMM:.*]] = mhlo.reshape %[[BMM_0]] -// CHECK-NEXT: return %[[RESHAPED_BMM]] : tensor<8xi32> +// CHECK: %[[RESHAPED_0:.*]] = "tfl.reshape"(%arg0 +// CHECK: %[[BMM_0:.*]] = "tfl.batch_matmul"(%[[RESHAPED_0]], %arg1) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : {{.*}} -> tensor<1x8xi32> +// CHECK: %[[RESHAPED_BMM:.*]] = "tfl.reshape"(%[[BMM_0]] +// CHECK: return %[[RESHAPED_BMM]] : tensor<8xi32> // ----- @@ -135,29 +133,30 @@ func.func @dot_general_dynamic_rhs_out_dim(%arg0: tensor<4x4x256xf32>, %arg1: te func.return %0 : tensor<4x4x?xf32> } -// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32> -// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> -// CHECK-NEXT: %3 = mhlo.reshape %arg0 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32> -// CHECK-NEXT: %4 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %7 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%4, %5, %7) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %9 = "tfl.unsorted_segment_prod"(%4, %6, %7) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %9, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %12 = mhlo.dynamic_reshape %2, %11 : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> -// CHECK-NEXT: %13 = "tfl.batch_matmul"(%3, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32> -// CHECK-NEXT: %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> -// CHECK-NEXT: %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64> -// CHECK-NEXT: %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32> -// CHECK-NEXT: %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %21 = mhlo.dynamic_reshape %13, %20 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> -// CHECK-NEXT: return %21 : tensor<4x4x?xf32> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 2, 1]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %1 = "tfl.cast"(%0) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> +// CHECK: %3 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> +// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %10 = "tfl.concatenation"(%9, %8, %7) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %11 = "tfl.cast"(%10) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %12 = "tfl.reshape"(%2, %11) : (tensor<4x256x?xf32>, tensor<3xi32>) -> tensor<4x256x?xf32> +// CHECK: %13 = "tfl.batch_matmul"(%arg0, %12) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x256xf32>, tensor<4x256x?xf32>) -> tensor<4x4x?xf32> +// CHECK: %14 = "tfl.shape"(%arg0) : (tensor<4x4x256xf32>) -> tensor<3xi32> +// CHECK: %15 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> +// CHECK: %16 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK: %17 = "tfl.gather"(%14, %16) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK: %18 = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK: %19 = "tfl.gather"(%15, %18) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<3xi32>, tensor<1xi64>) -> tensor<1xi32> +// CHECK: %20 = "tfl.concatenation"(%17, %19) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %21 = "tfl.cast"(%20) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %22 = "tfl.reshape"(%13, %21) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> +// CHECK: return %22 : tensor<4x4x?xf32> // ----- @@ -173,43 +172,45 @@ func.func @dot_general_dynamic_batch_dim(%arg0: tensor<2x?x2x3xf32>, %arg1: tens func.return %0 : tensor<2x?x2x4xf32> } -// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> -// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> -// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> -// CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> -// CHECK-NEXT: %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK-NEXT: %12 = mhlo.dynamic_reshape %arg0, %11 : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32> -// CHECK-NEXT: %13 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %19 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> -// CHECK-NEXT: %20 = "tfl.gather"(%13, %19) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %21 = "tfl.concatenation"(%20, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK-NEXT: %22 = mhlo.dynamic_reshape %2, %21 : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> -// CHECK-NEXT: %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32> -// CHECK-NEXT: %24 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %25 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK-NEXT: %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> -// CHECK-NEXT: %28 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64> -// CHECK-NEXT: %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32> -// CHECK-NEXT: %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32> -// CHECK-NEXT: %31 = mhlo.dynamic_reshape %23, %30 : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32> -// CHECK-NEXT: return %31 : tensor<2x?x2x4xf32> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 1, 3, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> +// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x?x4x3xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> +// CHECK: %3 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> +// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK: %10 = "tfl.gather"(%3, %9) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK: %11 = "tfl.concatenation"(%10, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %12 = "tfl.cast"(%11) : (tensor<4xi32>) -> tensor<4xi32> +// CHECK: %13 = "tfl.reshape"(%arg0, %12) : (tensor<2x?x2x3xf32>, tensor<4xi32>) -> tensor<2x?x2x3xf32> +// CHECK: %14 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> +// CHECK-DAG: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %16 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %17 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %18 = "tfl.unsorted_segment_prod"(%14, %15, %17) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %19 = "tfl.unsorted_segment_prod"(%14, %16, %17) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %20 = "tfl.pseudo_const"() <{value = dense<[0, 1]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK: %21 = "tfl.gather"(%14, %20) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK: %22 = "tfl.concatenation"(%21, %19, %18) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<2xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %23 = "tfl.cast"(%22) : (tensor<4xi32>) -> tensor<4xi32> +// CHECK: %24 = "tfl.reshape"(%2, %23) : (tensor<2x?x3x4xf32>, tensor<4xi32>) -> tensor<2x?x3x4xf32> +// CHECK: %25 = "tfl.batch_matmul"(%13, %24) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x2x3xf32>, tensor<2x?x3x4xf32>) -> tensor<2x?x2x4xf32> +// CHECK: %26 = "tfl.shape"(%arg0) : (tensor<2x?x2x3xf32>) -> tensor<4xi32> +// CHECK: %27 = "tfl.shape"(%arg1) : (tensor<2x?x4x3xf32>) -> tensor<4xi32> +// CHECK: %28 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %29 = "tfl.gather"(%26, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> +// CHECK: %30 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK: %31 = "tfl.gather"(%27, %30) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<1xi64>) -> tensor<1xi32> +// CHECK: %32 = "tfl.concatenation"(%29, %31) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %33 = "tfl.cast"(%32) : (tensor<4xi32>) -> tensor<4xi32> +// CHECK: %34 = "tfl.reshape"(%25, %33) : (tensor<2x?x2x4xf32>, tensor<4xi32>) -> tensor<2x?x2x4xf32> +// CHECK: return %34 : tensor<2x?x2x4xf32> // ----- - // CHECK-LABEL: dot_general_dynamic_lhs_rhs_out_dims func.func @dot_general_dynamic_lhs_rhs_out_dims(%arg0: tensor<2x2x?x3xf32>, %arg1: tensor<2x4x?x3xf32>) -> tensor<2x2x?x4x?xf32> { %0 = "mhlo.dot_general"(%arg0, %arg1) { @@ -222,37 +223,40 @@ func.func @dot_general_dynamic_lhs_rhs_out_dims(%arg0: tensor<2x2x?x3xf32>, %arg func.return %0 : tensor<2x2x?x4x?xf32> } -// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> -// CHECK-NEXT: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> -// CHECK-NEXT: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32> -// CHECK-NEXT: %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %11 = mhlo.dynamic_reshape %arg0, %10 : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32> -// CHECK-NEXT: %12 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %13 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %14 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %16 = "tfl.unsorted_segment_prod"(%12, %13, %15) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %17 = "tfl.unsorted_segment_prod"(%12, %14, %15) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %18 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %19 = "tfl.concatenation"(%18, %17, %16) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %20 = mhlo.dynamic_reshape %2, %19 : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32> -// CHECK-NEXT: %21 = "tfl.batch_matmul"(%11, %20) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32> -// CHECK-NEXT: %22 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %23 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> -// CHECK-NEXT: %24 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK-NEXT: %25 = "tfl.gather"(%22, %24) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> -// CHECK-NEXT: %26 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64> -// CHECK-NEXT: %27 = "tfl.gather"(%23, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> -// CHECK-NEXT: %28 = "tfl.concatenation"(%25, %27) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32> -// CHECK-NEXT: %29 = mhlo.dynamic_reshape %21, %28 : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32> -// CHECK-NEXT: return %29 : tensor<2x2x?x4x?xf32> +// CHECK: %0 = "tfl.pseudo_const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %1 = "tfl.cast"(%0) : (tensor<4xi64>) -> tensor<4xi32> +// CHECK: %2 = "tfl.transpose"(%arg1, %1) : (tensor<2x4x?x3xf32>, tensor<4xi32>) -> tensor<2x3x4x?xf32> +// CHECK: %3 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> +// CHECK-DAG: %4 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %5 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %6 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %7 = "tfl.unsorted_segment_prod"(%3, %4, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %8 = "tfl.unsorted_segment_prod"(%3, %5, %6) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %9 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %10 = "tfl.concatenation"(%9, %7, %8) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %11 = "tfl.cast"(%10) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %12 = "tfl.reshape"(%arg0, %11) : (tensor<2x2x?x3xf32>, tensor<3xi32>) -> tensor<2x?x3xf32> +// CHECK: %13 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> +// CHECK-DAG: %14 = "tfl.pseudo_const"() <{value = dense<[-1, 0, 0, -1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %15 = "tfl.pseudo_const"() <{value = dense<[-1, -1, -1, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK-DAG: %16 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %17 = "tfl.unsorted_segment_prod"(%13, %14, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %18 = "tfl.unsorted_segment_prod"(%13, %15, %16) : (tensor<4xi32>, tensor<4xi32>, tensor) -> tensor<1xi32> +// CHECK: %19 = "tfl.pseudo_const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %20 = "tfl.concatenation"(%19, %18, %17) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %21 = "tfl.cast"(%20) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %22 = "tfl.reshape"(%2, %21) : (tensor<2x3x4x?xf32>, tensor<3xi32>) -> tensor<2x3x?xf32> +// CHECK: %23 = "tfl.batch_matmul"(%12, %22) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2x?x3xf32>, tensor<2x3x?xf32>) -> tensor<2x?x?xf32> +// CHECK: %24 = "tfl.shape"(%arg0) : (tensor<2x2x?x3xf32>) -> tensor<4xi32> +// CHECK: %25 = "tfl.shape"(%arg1) : (tensor<2x4x?x3xf32>) -> tensor<4xi32> +// CHECK: %26 = "tfl.pseudo_const"() <{value = dense<[0, 1, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %27 = "tfl.gather"(%24, %26) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<3xi64>) -> tensor<3xi32> +// CHECK: %28 = "tfl.pseudo_const"() <{value = dense<[1, 2]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK: %29 = "tfl.gather"(%25, %28) <{axis = 0 : i32, batch_dims = 0 : i32}> : (tensor<4xi32>, tensor<2xi64>) -> tensor<2xi32> +// CHECK: %30 = "tfl.concatenation"(%27, %29) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32> +// CHECK: %31 = "tfl.cast"(%30) : (tensor<5xi32>) -> tensor<5xi32> +// CHECK: %32 = "tfl.reshape"(%23, %31) : (tensor<2x?x?xf32>, tensor<5xi32>) -> tensor<2x2x?x4x?xf32> +// CHECK: return %32 : tensor<2x2x?x4x?xf32 // ----- @@ -268,27 +272,28 @@ func.func @dot_general_dynamic_contracting_dim(%arg0: tensor<4x4x?xf32>, %arg1: func.return %0 : tensor<4x4x256xf32> } -// CHECK: %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32> -// CHECK-NEXT: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %7 = "tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %8 = mhlo.dynamic_reshape %arg0, %7 : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> -// CHECK-NEXT: %9 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> -// CHECK-NEXT: %10 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %11 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK-NEXT: %12 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor -// CHECK-NEXT: %13 = "tfl.unsorted_segment_prod"(%9, %10, %12) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %14 = "tfl.unsorted_segment_prod"(%9, %11, %12) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> -// CHECK-NEXT: %15 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> -// CHECK-NEXT: %16 = "tfl.concatenation"(%15, %14, %13) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> -// CHECK-NEXT: %17 = mhlo.dynamic_reshape %arg1, %16 : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32> -// CHECK-NEXT: %18 = "tfl.batch_matmul"(%8, %17) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32> -// CHECK-NEXT: %19 = mhlo.reshape %18 : (tensor<4x4x256xf32>) -> tensor<4x4x256xf32> -// CHECK-NEXT: return %19 : tensor<4x4x256xf32> +// CHECK: %0 = "tfl.shape"(%arg0) : (tensor<4x4x?xf32>) -> tensor<3xi32> +// CHECK-DAG: %1 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %2 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %3 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %4 = "tfl.unsorted_segment_prod"(%0, %1, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %5 = "tfl.unsorted_segment_prod"(%0, %2, %3) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %6 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %7 = "tfl.concatenation"(%6, %4, %5) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %8 = "tfl.cast"(%7) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %9 = "tfl.reshape"(%arg0, %8) : (tensor<4x4x?xf32>, tensor<3xi32>) -> tensor<4x4x?xf32> +// CHECK: %10 = "tfl.shape"(%arg1) : (tensor<4x?x256xf32>) -> tensor<3xi32> +// CHECK-DAG: %11 = "tfl.pseudo_const"() <{value = dense<[-1, -1, 0]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %12 = "tfl.pseudo_const"() <{value = dense<[-1, 0, -1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK-DAG: %13 = "tfl.pseudo_const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %14 = "tfl.unsorted_segment_prod"(%10, %11, %13) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %15 = "tfl.unsorted_segment_prod"(%10, %12, %13) : (tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1xi32> +// CHECK: %16 = "tfl.pseudo_const"() <{value = dense<4> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %17 = "tfl.concatenation"(%16, %15, %14) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> +// CHECK: %18 = "tfl.cast"(%17) : (tensor<3xi32>) -> tensor<3xi32> +// CHECK: %19 = "tfl.reshape"(%arg1, %18) : (tensor<4x?x256xf32>, tensor<3xi32>) -> tensor<4x?x256xf32> +// CHECK: %20 = "tfl.batch_matmul"(%9, %19) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x4x?xf32>, tensor<4x?x256xf32>) -> tensor<4x4x256xf32> +// CHECK: return %20 : tensor<4x4x256xf32> // ----- @@ -318,14 +323,10 @@ func.func @argmax(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32 func.return %4#0, %4#1 : tensor<4x32xf32>, tensor<4x32xi32> } -// CHECK: %0 = mhlo.constant dense<0xFF800000> : tensor -// CHECK-DAG: %1 = mhlo.constant dense<0> : tensor -// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> -// CHECK: %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> -// CHECK: %cst = arith.constant dense<2> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> -// CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> -// CHECK: return %4, %5 : tensor<4x32xf32>, tensor<4x32xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<2> : tensor<1xi32> +// CHECK: %[[REDUCE:.*]] = "tfl.reduce_max"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> +// CHECK: %[[ARG:.*]] = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> +// CHECK: return %[[REDUCE]], %[[ARG]] : tensor<4x32xf32>, tensor<4x32xi32> // ----- @@ -410,12 +411,11 @@ func.func @argmax_bool(%arg0: tensor<2xi1>) -> tensor { return %3#1 : tensor } -// CHECK: %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> // CHECK-DAG: %1 = mhlo.constant dense : tensor -// CHECK: %2 = mhlo.constant dense<0> : tensor -// CHECK: %cst = arith.constant dense<0> : tensor<1xi32> -// CHECK: %3 = "tfl.reduce_any"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor -// CHECK: %4 = "tfl.arg_max"(%arg0, %cst) : (tensor<2xi1>, tensor<1xi32>) -> tensor +// CHECK-DAG: %2 = mhlo.constant dense<0> : tensor +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<1xi32> +// CHECK: %3 = "tfl.reduce_any"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor +// CHECK: %4 = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<2xi1>, tensor<1xi32>) -> tensor // CHECK: return %4 : tensor // ----- @@ -442,14 +442,10 @@ func.func @argmin(%arg0: tensor<4x32x256xf32>) -> (tensor<4x32xf32>, tensor<4x32 func.return %4#0, %4#1 : tensor<4x32xf32>, tensor<4x32xi32> } -// CHECK-DAG: %0 = mhlo.constant dense<0x7F800000> : tensor -// CHECK: %1 = mhlo.constant dense<0> : tensor -// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<256xi32> -// CHECK: %3 = "mhlo.broadcast_in_dim"(%2) <{broadcast_dimensions = dense<2> : tensor<1xi64>}> : (tensor<256xi32>) -> tensor<4x32x256xi32> -// CHECK: %cst = arith.constant dense<2> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> -// CHECK: %5 = "tfl.arg_min"(%arg0, %cst) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> -// CHECK: return %4, %5 : tensor<4x32xf32>, tensor<4x32xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<2> : tensor<1xi32> +// CHECK: %[[REDUCE:.*]] = "tfl.reduce_min"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xf32> +// CHECK: %[[ARG:.*]] = "tfl.arg_min"(%arg0, %[[CST]]) : (tensor<4x32x256xf32>, tensor<1xi32>) -> tensor<4x32xi32> +// CHECK: return %[[REDUCE]], %[[ARG]] : tensor<4x32xf32>, tensor<4x32xi32> // ----- @@ -474,14 +470,10 @@ func.func @argmin_i16(%arg0: tensor<2xi16>) -> (tensor, tensor) { func.return %4#0, %4#1 : tensor, tensor } -// CHECK: %0 = mhlo.constant dense : tensor -// CHECK: %1 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> -// CHECK-DAG: %2 = mhlo.constant dense<32767> : tensor -// CHECK: %3 = mhlo.constant dense<0> : tensor -// CHECK: %cst = arith.constant dense<0> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_min"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi16>, tensor<1xi32>) -> tensor -// CHECK: %5 = "tfl.arg_min"(%arg0, %cst) : (tensor<2xi16>, tensor<1xi32>) -> tensor -// CHECK: return %4, %5 : tensor, tensor +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<1xi32> +// CHECK: %[[REDUCE:.*]] = "tfl.reduce_min"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<2xi16>, tensor<1xi32>) -> tensor +// CHECK: %[[ARG:.*]] = "tfl.arg_min"(%arg0, %[[CST]]) : (tensor<2xi16>, tensor<1xi32>) -> tensor +// CHECK: return %[[REDUCE]], %[[ARG]] : tensor, tensor // ----- @@ -535,12 +527,11 @@ func.func @argmin_bool(%arg0: tensor<2xi1>) -> tensor { return %3#1 : tensor } -// CHECK: %0 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<2xi32> // CHECK-DAG: %1 = mhlo.constant dense : tensor -// CHECK: %2 = mhlo.constant dense<0> : tensor -// CHECK: %cst = arith.constant dense<0> : tensor<1xi32> -// CHECK: %3 = "tfl.reduce_all"(%arg0, %cst) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor -// CHECK: %4 = "tfl.arg_min"(%arg0, %cst) : (tensor<2xi1>, tensor<1xi32>) -> tensor +// CHECK-DAG: %2 = mhlo.constant dense<0> : tensor +// CHECK: %[[CST:.*]] = arith.constant dense<0> : tensor<1xi32> +// CHECK: %3 = "tfl.reduce_all"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<2xi1>, tensor<1xi32>) -> tensor +// CHECK: %4 = "tfl.arg_min"(%arg0, %[[CST]]) : (tensor<2xi1>, tensor<1xi32>) -> tensor // CHECK: return %4 : tensor // ----- @@ -567,14 +558,10 @@ func.func @argmax_with_reshaped_iota(%arg0: tensor<1x32x1xf32>) -> (tensor<1x1xf func.return %4#0, %4#1 : tensor<1x1xf32>, tensor<1x1xi32> } -// CHECK-DAG: %0 = mhlo.constant dense<0xFF800000> : tensor -// CHECK: %1 = mhlo.constant dense<0> : tensor -// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<32xi32> -// CHECK: %3 = mhlo.reshape %2 : (tensor<32xi32>) -> tensor<1x32x1xi32> -// CHECK: %cst = arith.constant dense<1> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32> -// CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32> -// CHECK: return %4, %5 : tensor<1x1xf32>, tensor<1x1xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<1> : tensor<1xi32> +// CHECK: %[[REDUCE:.*]] = "tfl.reduce_max"(%arg0, %[[CST]]) <{keep_dims = false}> : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xf32> +// CHECK: %[[ARG:.*]] = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<1x32x1xf32>, tensor<1xi32>) -> tensor<1x1xi32> +// CHECK: return %[[REDUCE]], %[[ARG]] : tensor<1x1xf32>, tensor<1x1xi32> // ----- @@ -597,14 +584,9 @@ func.func @pytorch_argmax(%arg0: tensor<1x9xi32>) -> tensor<1xi32> { func.return %4#1 : tensor<1xi32> } -// CHECK: %0 = mhlo.constant dense<0> : tensor -// CHECK-DAG: %1 = mhlo.constant dense<-2147483648> : tensor -// CHECK: %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<9xi32> -// CHECK: %3 = mhlo.reshape %2 : (tensor<9xi32>) -> tensor<1x9xi32> -// CHECK: %cst = arith.constant dense<1> : tensor<1xi32> -// CHECK: %4 = "tfl.reduce_max"(%arg0, %cst) <{keep_dims = false}> : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> -// CHECK: %5 = "tfl.arg_max"(%arg0, %cst) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> -// CHECK: return %5 : tensor<1xi32> +// CHECK: %[[CST:.*]] = arith.constant dense<1> : tensor<1xi32> +// CHECK: %[[ARG:.*]] = "tfl.arg_max"(%arg0, %[[CST]]) : (tensor<1x9xi32>, tensor<1xi32>) -> tensor<1xi32> +// CHECK: return %[[ARG]] : tensor<1xi32> // ----- @@ -618,11 +600,11 @@ func.func @cbrt_f32(%arg0: tensor<1x32x1xf32>) -> tensor<1x32x1xf32> { func.return %0 : tensor<1x32x1xf32> } -// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor -// CHECK: %cst_0 = arith.constant dense<3.000000e+00> : tensor -// CHECK: %0 = tfl.div %cst, %cst_0 {fused_activation_function = "NONE"} : tensor -// CHECK: %1 = tfl.pow(%arg0, %0) : (tensor<1x32x1xf32>, tensor) -> tensor<1x32x1xf32> -// CHECK: return %1 : tensor<1x32x1xf32> +// CHECK-DAG: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK-DAG: %cst_0 = arith.constant dense<3.000000e+00> : tensor +// CHECK: %0 = tfl.div %cst, %cst_0 {fused_activation_function = "NONE"} : tensor +// CHECK: %1 = tfl.pow(%arg0, %0) : (tensor<1x32x1xf32>, tensor) -> tensor<1x32x1xf32> +// CHECK: return %1 : tensor<1x32x1xf32> // ----- @@ -636,6 +618,100 @@ func.func @cbrt_f64(%arg0: tensor<1x32x1xf64>) -> tensor<1x32x1xf64> { // ----- +//===----------------------------------------------------------------------===// +// mhlo.(dynamic)reshape +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: reshape +func.func @reshape(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<2x3xf32>) -> tensor<3x2xf32> + func.return %0 : tensor<3x2xf32> +} + +// CHECK: %cst = arith.constant dense<[3, 2]> : tensor<2xi64> +// CHECK: %0 = "tfl.cast"(%cst) : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + +// ----- + +// CHECK-LABEL: dynamic_reshape_i32 +func.func @dynamic_reshape_i32(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi32>) -> tensor { + %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor + func.return %0 : tensor +} + +// CHECK: %0 = "tfl.cast"(%arg1) : (tensor<2xi32>) -> tensor<2xi32> +// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor + +// ----- + +// CHECK-LABEL: dynamic_reshape_i64 +func.func @dynamic_reshape_i64(%arg0: tensor<2x3xf32>, %arg1: tensor<2xi64>) -> tensor { + %0 = "mhlo.dynamic_reshape"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<2xi64>) -> tensor + func.return %0 : tensor +} + +// CHECK: %0 = "tfl.cast"(%arg1) : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %1 = "tfl.reshape"(%arg0, %0) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo binary bit-wise ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: logical_and +func.func @logical_and(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { + %0 = mhlo.and %arg0, %arg1 : tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.logical_and +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: bitwise_and +func.func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.and %arg0, %arg1 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK: mhlo.and +// CHECK-NOT: tfl + +// ----- + +// CHECK-LABEL: logical_or +func.func @logical_or(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> tensor<2xi1> { + %0 = mhlo.or %arg0, %arg1 : tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.logical_or +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: bitwise_or +func.func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.or %arg0, %arg1 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK: mhlo.or +// CHECK-NOT: tfl + +// ----- + +// CHECK-LABEL: logical_xor +func.func @logical_xor(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { + %0 = mhlo.xor %arg0, %arg1 : tensor<4xi1> + func.return %0 : tensor<4xi1> +} + +// ----- + //===----------------------------------------------------------------------===// // mhlo.convolution //===----------------------------------------------------------------------===// @@ -644,6 +720,73 @@ func.func @cbrt_f64(%arg0: tensor<1x32x1xf64>) -> tensor<1x32x1xf64> { // 2D //=--- +// CHECK-LABEL: transpose_conv2d_valid_padding_odd +func.func @transpose_conv2d_valid_padding_odd(%arg0: tensor<1x200x198x4xf32>, %arg1: tensor<4x4x4x4xf32>) -> tensor<1x402x398x4xf32> { + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], + window = {pad = [[3, 3], [3, 3]],lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x200x198x4xf32>, tensor<4x4x4x4xf32>) -> tensor<1x402x398x4xf32> + func.return %0 : tensor<1x402x398x4xf32> + // CHECK %cst = arith.constant dense<0.000000e+00> : tensor<4xf32> + // CHECK %cst_0 = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK %0 = "tfl.reverse_v2"(%arg1, %cst_0) : (tensor<4x4x4x4xf32>, tensor<2xi32>) -> tensor<4x4x4x4xf32> + // CHECK %cst_1 = arith.constant dense<[1, 402, 398, 4]> : tensor<4xi32> + // CHECK %1 = "tfl.transpose_conv"(%cst_1, %0, %arg0, %cst) <{fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<4x4x4x4xf32>, tensor<1x200x198x4xf32>, tensor<4xf32>) -> tensor<1x402x398x4xf32> + // CHECK return %1 : tensor<1x402x398x4xf32> +} + +// CHECK-LABEL: transpose_conv2d_same_padding +func.func @transpose_conv2d_same_padding(%input: tensor<1x256x256x2xf32>, %filter:tensor<2x4x4x2xf32>) -> tensor<1x512x512x2xf32> { + %0 = mhlo.convolution(%input, %filter) + dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], + window = {pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x256x256x2xf32>, tensor<2x4x4x2xf32>) -> tensor<1x512x512x2xf32> + func.return %0 : tensor<1x512x512x2xf32> + // CHECK %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> + // CHECK %cst_0 = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK %0 = "tfl.reverse_v2"(%arg1, %cst_0) : (tensor<2x4x4x2xf32>, tensor<2xi32>) -> tensor<2x4x4x2xf32> + // CHECK %1 = "tfl.pseudo_const"() <{value = dense<[1, 512, 512, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK %2 = "tfl.transpose_conv"(%1, %0, %arg0, %cst) <{fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<2x4x4x2xf32>, tensor<1x256x256x2xf32>, tensor<2xf32>) -> tensor<1x512x512x2xf32> + // CHECK return %2 : tensor<1x512x512x2xf32> +} + +// ----- + +// CHECK-LABEL: transpose_conv2d_valid_padding +func.func @transpose_conv2d_valid_padding(%input: tensor<1x256x256x2xf32>, %filter:tensor<2x4x4x2xf32>) -> tensor<1x514x514x2xf32> { + %0 = mhlo.convolution(%input, %filter) + dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], + window = {pad = [[3, 3], [3, 3]], lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x256x256x2xf32>, tensor<2x4x4x2xf32>) -> tensor<1x514x514x2xf32> + func.return %0 : tensor<1x514x514x2xf32> + // CHECK %cst = arith.constant dense<0.000000e+00> : tensor<2xf32> + // CHECK %cst_0 = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK %0 = "tfl.reverse_v2"(%arg1, %cst_0) : (tensor<2x4x4x2xf32>, tensor<2xi32>) -> tensor<2x4x4x2xf32> + // CHECK %1 = "tfl.pseudo_const"() <{value = dense<[1, 514, 514, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> + // CHECK %2 = "tfl.transpose_conv"(%1, %0, %arg0, %cst) <{fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<2x4x4x2xf32>, tensor<1x256x256x2xf32>, tensor<2xf32>) -> tensor<1x514x514x2xf32> + // CHECK return %2 : tensor<1x514x514x2xf32> +} + +// ----- + +// CHECK-LABEL: transpose_conv2d_valid_padding_equal_strides +func.func @transpose_conv2d_valid_padding_equal_strides(%arg0: tensor<1x200x198x3xf32>, %arg1: tensor<3x3x3x3xf32>) -> tensor<1x401x397x3xf32> { + %0 = mhlo.convolution(%arg0, %arg1) + dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], + window = {pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2]} + {batch_group_count = 1 : i64, feature_group_count = 1 : i64} + : (tensor<1x200x198x3xf32>, tensor<3x3x3x3xf32>) -> tensor<1x401x397x3xf32> + func.return %0 : tensor<1x401x397x3xf32> + // CHECK %cst = arith.constant dense<0.000000e+00> : tensor<3xf32> + // CHECK %cst_0 = arith.constant dense<[1, 2]> : tensor<2xi32> + // CHECK %0 = "tfl.reverse_v2"(%arg1, %cst_0) : (tensor<3x3x3x3xf32>, tensor<2xi32>) -> tensor<3x3x3x3xf32> + // CHECK %cst_1 = arith.constant dense<[1, 401, 397, 3]> : tensor<4xi32> + // CHECK %1 = "tfl.transpose_conv"(%cst_1, %0, %arg0, %cst) <{fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4xi32>, tensor<3x3x3x3xf32>, tensor<1x200x198x3xf32>, tensor<3xf32>) -> tensor<1x401x397x3xf32> + // CHECK return %1 : tensor<1x401x397x3xf32> +} // CHECK-LABEL: conv2d_nhwc_ohwi_nhwc func.func @conv2d_nhwc_ohwi_nhwc(%input: tensor<1x256x256x3xf32>, %filter: tensor<2x1x1x3xf32>) -> tensor<1x256x256x2xf32> { %0 = mhlo.convolution(%input, %filter) @@ -826,27 +969,86 @@ func.func @conv2d_nhwc_ohwi_nhwc_dynamic_batch(%input: tensor, // ----- -// TODO: b/351437662 - Add support for depthwise conv. -// CHECK-LABEL: depthwise_conv2d_nhwc_ohwi_nhwc -func.func @depthwise_conv2d_nhwc_ohwi_nhwc(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3312x3x3x1xf32>) -> tensor<1x8x8x3312xf32> { +// CHECK-LABEL: depthwise_conv2d_nhwc_ihwo_nhwc +func.func @depthwise_conv2d_nhwc_ihwo_nhwc(%arg0: tensor<1x10x10x207xf32>, %arg1: tensor<1x3x3x207xf32>) -> tensor<1x8x8x207xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) { batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]>, + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f]>, feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, - padding = dense<1> : tensor<2x2xi64>, + padding = dense<0> : tensor<2x2xi64>, precision_config = [#mhlo, #mhlo], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64> - } : (tensor<1x8x8x207xf32>, tensor<3312x3x3x1xf32>) -> tensor<1x8x8x3312xf32> - func.return %0 : tensor<1x8x8x3312xf32> + } : (tensor<1x10x10x207xf32>, tensor<1x3x3x207xf32>) -> tensor<1x8x8x207xf32> + func.return %0 : tensor<1x8x8x207xf32> } -// CHECK-NOT: tfl +// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst) <{depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x10x10x207xf32>, tensor<1x3x3x207xf32>, tensor<207xf32>) -> tensor<1x8x8x207xf32> +// CHECK: return %0 + +// ----- + +// CHECK-LABEL: depthwise_conv2d_nhwc_ihwo_nhwc_strided +func.func @depthwise_conv2d_nhwc_ihwo_nhwc_strided(%arg0: tensor<1x10x10x207xf32>, %arg1: tensor<1x3x3x207xf32>) -> tensor<1x4x4x207xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f]>, + feature_group_count = 207 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<1> : tensor<2xi64>, + window_strides = dense<2> : tensor<2xi64> + } : (tensor<1x10x10x207xf32>, tensor<1x3x3x207xf32>) -> tensor<1x4x4x207xf32> + func.return %0 : tensor<1x4x4x207xf32> +} + +// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst) <{depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x10x10x207xf32>, tensor<1x3x3x207xf32>, tensor<207xf32>) -> tensor<1x4x4x207xf32> +// CHECK: return %0 + +// ----- + +// CHECK-LABEL: depthwise_conv2d_nhwc_ihwo_nhwc_dilated +func.func @depthwise_conv2d_nhwc_ihwo_nhwc_dilated(%arg0: tensor<1x10x10x207xf32>, %arg1: tensor<1x3x3x207xf32>) -> tensor<1x6x6x207xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f]>, + feature_group_count = 207 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<2> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<1x10x10x207xf32>, tensor<1x3x3x207xf32>) -> tensor<1x6x6x207xf32> + func.return %0 : tensor<1x6x6x207xf32> +} + +// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst) <{depth_multiplier = 1 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x10x10x207xf32>, tensor<1x3x3x207xf32>, tensor<207xf32>) -> tensor<1x6x6x207xf32> +// CHECK: return %0 + +// ----- + +// CHECK-LABEL: depthwise_conv2d_nhwc_ihwo_nhwc_non_trivial_depth_multiplier +func.func @depthwise_conv2d_nhwc_ihwo_nhwc_non_trivial_depth_multiplier(%arg0: tensor<1x10x10x207xf32>, %arg1: tensor<1x3x3x3519xf32>) -> tensor<1x6x6x3519xf32> { + %0 = "mhlo.convolution"(%arg0, %arg1) { + batch_group_count = 1 : i64, + dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f]>, + feature_group_count = 207 : i64, + lhs_dilation = dense<1> : tensor<2xi64>, + padding = dense<0> : tensor<2x2xi64>, + precision_config = [#mhlo, #mhlo], + rhs_dilation = dense<2> : tensor<2xi64>, + window_strides = dense<1> : tensor<2xi64> + } : (tensor<1x10x10x207xf32>, tensor<1x3x3x3519xf32>) -> tensor<1x6x6x3519xf32> + func.return %0 : tensor<1x6x6x3519xf32> +} + +// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst) <{depth_multiplier = 17 : i32, dilation_h_factor = 2 : i32, dilation_w_factor = 2 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x10x10x207xf32>, tensor<1x3x3x3519xf32>, tensor<3519xf32>) -> tensor<1x6x6x3519xf32> +// CHECK: return %0 // ----- -// TODO: b/351437662 - Add support for conv to resize. // CHECK-LABEL: conv2d_resize_perferred_nhwc_hwoi_nhwc func.func @conv2d_resize_perferred_nhwc_hwoi_nhwc(%arg0: tensor<1x56x1248x16xf32>, %arg1: tensor<16x3x1x1xf32>) -> tensor<1x111x1248x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) { @@ -860,13 +1062,13 @@ func.func @conv2d_resize_perferred_nhwc_hwoi_nhwc(%arg0: tensor<1x56x1248x16xf32 window_strides = dense<[1, 1]> : tensor<2xi64> } : (tensor<1x56x1248x16xf32>, tensor<16x3x1x1xf32>) -> tensor<1x111x1248x16xf32> func.return %0 : tensor<1x111x1248x16xf32> + // CHECK %0 = "tfl.pseudo_const"() <{value = dense<[111, 1248]> : tensor<2xi32>}> : () -> tensor<2xi32> + // CHECK %1 = "tfl.resize_bilinear"(%arg0, %0) <{align_corners = false, half_pixel_centers = false}> : (tensor<1x56x1248x16xf32>, tensor<2xi32>) -> tensor<1x111x1248x16xf32> + // CHECK return %1 : tensor<1x111x1248x16xf32> } -// CHECK-NOT: tfl - // ----- -// TODO: b/351437662 - Add support for conv to resize. // CHECK-LABEL: conv2d_to_resize_nhwc_hwoi_nhwc func.func @conv2d_to_resize_nhwc_hwoi_nhwc(%arg0: tensor<1x56x624x16xf32>, %arg1: tensor<16x1x257x1xf32>) -> tensor<1x56x904x16xf32> { %0 = "mhlo.convolution"(%arg0, %arg1) { @@ -880,52 +1082,11 @@ func.func @conv2d_to_resize_nhwc_hwoi_nhwc(%arg0: tensor<1x56x624x16xf32>, %arg1 window_strides = dense<[1, 89]> : tensor<2xi64> } : (tensor<1x56x624x16xf32>, tensor<16x1x257x1xf32>) -> tensor<1x56x904x16xf32> func.return %0 : tensor<1x56x904x16xf32> + // CHECK %0 = "tfl.pseudo_const"() <{value = dense<[56, 904]> : tensor<2xi32>}> : () -> tensor<2xi32> + // CHECK %1 = "tfl.resize_bilinear"(%arg0, %0) <{align_corners = true, half_pixel_centers = false}> : (tensor<1x56x624x16xf32>, tensor<2xi32>) -> tensor<1x56x904x16xf32> + // CHECK return %1 : tensor<1x56x904x16xf32> } -// CHECK-NOT: tfl - -// ----- - -// TODO: b/351437662 - Add support for feature groups. -// CHECK-LABEL: group_conv2d_nhwc_ohwi_nhwc -func.func @group_conv2d_nhwc_ohwi_nhwc(%arg0: tensor<1x14x14x2240xf32>, %arg1: tensor<2240x3x3x112xf32>) -> tensor<1x7x7x2240xf32> { - %0 = "mhlo.convolution"(%arg0, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = #mhlo.conv<[b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f]>, - feature_group_count = 20 : i64, - lhs_dilation = dense<1> : tensor<2xi64>, - padding = dense<1> : tensor<2x2xi64>, - precision_config = [#mhlo, #mhlo], - rhs_dilation = dense<1> : tensor<2xi64>, - window_reversal = dense : tensor<2xi1>, - window_strides = dense<2> : tensor<2xi64> - } : (tensor<1x14x14x2240xf32>, tensor<2240x3x3x112xf32>) -> tensor<1x7x7x2240xf32> - func.return %0 : tensor<1x7x7x2240xf32> -} - -// CHECK-NOT: tfl - -// ----- - -// CHECK-LABEL: conv2d_nhwc_ohwi_nhwc_trivial_in_channels -func.func @conv2d_nhwc_ohwi_nhwc_trivial_in_channels(%input: tensor<1x256x256x1xf32>, %filter: tensor<2x1x1x1xf32>) -> tensor<1x256x256x2xf32> { - %0 = mhlo.convolution(%input, %filter) - dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], - window = {stride = [1, 1], pad = [[0, 0], [0, 0]]} { - batch_group_count = 1 : i64, - feature_group_count = 1 : i64, - window_strides = dense<1> : tensor<2xi64>, - padding = dense<0> : tensor<2x2xi64>, - rhs_dilation = dense<[1, 1]> : tensor<2xi64>, - lhs_dilation = dense<[1, 1]> : tensor<2xi64> - } : (tensor<1x256x256x1xf32>, tensor<2x1x1x1xf32>) -> tensor<1x256x256x2xf32> - func.return %0 : tensor<1x256x256x2xf32> -} - -// NOTE: This case is depthwise. - -// CHECK-NOT: tfl - // ----- // @@ -1455,7 +1616,7 @@ func.func @gather_nd(%arg0: tensor<98x128xf32>, %arg1: tensor<4x64xi32>) -> tens func.return %0 : tensor<4x64x128xf32> } -// CHECK: %[[VAL_0:.*]] = mhlo.reshape %arg1 : (tensor<4x64xi32>) -> tensor<4x64x1xi32> +// CHECK: %[[VAL_0:.*]] = "tfl.reshape"(%arg1, %0) : (tensor<4x64xi32>, tensor<3xi32>) -> tensor<4x64x1xi32 // CHECK: %[[VAL_1:.*]] = "tfl.gather_nd"(%arg0, %[[VAL_0]]) : (tensor<98x128xf32>, tensor<4x64x1xi32>) -> tensor<4x64x128xf32> // ----- @@ -1564,3 +1725,1344 @@ func.func @gather_scalar_dynamic_indices(%arg0: tensor<256000xf32>, %arg1: tenso } // CHECK: %0 = "tfl.gather_nd"(%arg0, %arg1) : (tensor<256000xf32>, tensor) -> tensor + +// ----- + +//===------------------------------------------------------------------------=== +// mhlo.reduce_window -> avg pool +//===------------------------------------------------------------------------=== + +// CHECK-LABEL: avgpool_same_channel_first +func.func @avgpool_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<4x16x16x3xf32> + %1 = mhlo.constant dense<0.000000e+00> : tensor + %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32> + %3 = "mhlo.reduce_window"(%2, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %8 : tensor + }) : (tensor<4x16x16x3xf32>, tensor) -> tensor<4x8x8x3xf32> + %4 = "mhlo.transpose"(%3) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32> + %5 = "mhlo.reduce_window"(%0, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %8 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %8 : tensor + }) : (tensor<4x16x16x3xf32>, tensor) -> tensor<4x8x8x3xf32> + %6 = "mhlo.transpose"(%5) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32> + %7 = mhlo.divide %4, %6 : tensor<4x3x8x8xf32> + return %7 : tensor<4x3x8x8xf32> +} + +// CHECK: %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0 +// CHECK-SAME: (tensor<4x3x16x16xf32>, tensor<4xi32>) -> tensor<4x16x16x3xf32> +// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x8x8x3xf32> +// CHECK: %[[TPOSED_OUT:.*]] = "tfl.transpose"(%[[POOL_OUT]] +// CHECK-SAME: (tensor<4x8x8x3xf32>, tensor<4xi32>) -> tensor<4x3x8x8xf32> +// CHECK: return %[[TPOSED_OUT]] + +// ----- + +// CHECK-LABEL: avgpool_valid_channel_first +func.func @avgpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> { + %0 = mhlo.constant dense<9.000000e+00> : tensor<4x3x7x7xf32> + %1 = mhlo.constant dense<0.000000e+00> : tensor + %2 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32> + %3 = "mhlo.reduce_window"(%2, %1) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %6 : tensor + }) : (tensor<4x16x16x3xf32>, tensor) -> tensor<4x7x7x3xf32> + %4 = "mhlo.transpose"(%3) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32> + %5 = mhlo.divide %4, %0 : tensor<4x3x7x7xf32> + return %5 : tensor<4x3x7x7xf32> +} + +// CHECK: %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0 +// CHECK-SAME: (tensor<4x3x16x16xf32>, tensor<4xi32>) -> tensor<4x16x16x3xf32> +// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x7x7x3xf32> +// CHECK: %[[TPOSED_OUT:.*]] = "tfl.transpose"(%[[POOL_OUT]] +// CHECK-SAME: (tensor<4x7x7x3xf32>, tensor<4xi32>) -> tensor<4x3x7x7xf32> +// CHECK: return %[[TPOSED_OUT]] + +// ----- + +func.func @avgpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { + %0 = mhlo.constant dense<0.0> : tensor + %1 = mhlo.constant dense<9.0> : tensor<4x7x7x8xf32> + %2 = "mhlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %5 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%5) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x7x7x8xf32> + %3 = mhlo.divide %2, %1 : tensor<4x7x7x8xf32> + func.return %3 : tensor<4x7x7x8xf32> +} + +// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> +// CHECK: return %[[POOL_OUT]] + +// ----- + +// CHECK-LABEL: avgpool_valid_broadcasted_divisor +func.func @avgpool_valid_broadcasted_divisor(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { + %0 = mhlo.constant dense<0.0> : tensor + %1 = mhlo.constant dense<9.0> : tensor + %2 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x7x7x8xf32> + %3 = "mhlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %5 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%5) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<0> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x7x7x8xf32> + %4 = mhlo.divide %3, %2 : tensor<4x7x7x8xf32> + func.return %4 : tensor<4x7x7x8xf32> +} + +// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> +// CHECK: return %[[POOL_OUT]] + +// ----- + +// CHECK-LABEL: avgpool_valid_rw +func.func @avgpool_valid_rw(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { + %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32> + %1 = mhlo.constant dense<0.0> : tensor + %2 = "mhlo.reduce_window"(%arg0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x7x7x8xf32> + %3 = "mhlo.reduce_window"(%0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x7x7x8xf32> + %4 = mhlo.divide %2, %3 : tensor<4x7x7x8xf32> + func.return %4 : tensor<4x7x7x8xf32> +} + +// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> +// CHECK: return %[[POOL_OUT]] + +// ----- + +// CHECK-LABEL: avgpool_valid_rw_broadcasted_const_lhs +func.func @avgpool_valid_rw_broadcasted_const_lhs(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { + %0 = mhlo.constant dense<1.0> : tensor + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<> : tensor<0xi64>}> : (tensor) -> tensor<4x16x16x8xf32> + %2 = mhlo.constant dense<0.0> : tensor + %3 = "mhlo.reduce_window"(%arg0, %2) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x7x7x8xf32> + %4 = "mhlo.reduce_window"(%1, %2) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x7x7x8xf32> + %5 = mhlo.divide %3, %4 : tensor<4x7x7x8xf32> + func.return %5 : tensor<4x7x7x8xf32> +} + +// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> +// CHECK: return %[[POOL_OUT]] + +// ----- + +// CHECK-LABEL: avgpool_same +func.func @avgpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> { + %0 = mhlo.constant dense<1.0> : tensor<4x16x16x8xf32> + %1 = mhlo.constant dense<0.0> : tensor + %2 = "mhlo.reduce_window"(%arg0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x8x8x8xf32> + %3 = "mhlo.reduce_window"(%0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x8x8x8xf32> + %4 = mhlo.divide %2, %3 : tensor<4x8x8x8xf32> + func.return %4 : tensor<4x8x8x8xf32> +} + +// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> +// CHECK: return %[[POOL_OUT]] + +// ----- + +// CHECK-LABEL: avgpool_reshape_broadcast +func.func @avgpool_reshape_broadcast(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<1x16x16x1xf32> + %1 = mhlo.constant dense<0.000000e+00> : tensor + %2 = "mhlo.reduce_window"(%arg0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %7 : tensor + }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x8x8x8xf32> + %3 = "mhlo.reduce_window"(%0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = mhlo.add %arg1, %arg2 : tensor + mhlo.return %7 : tensor + }) {base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<1x16x16x1xf32>, tensor) -> tensor<1x8x8x1xf32> + %4 = mhlo.reshape %3 : (tensor<1x8x8x1xf32>) -> tensor<8x8xf32> + %5 = "mhlo.broadcast_in_dim"(%4) <{broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}> : (tensor<8x8xf32>) -> tensor<4x8x8x8xf32> + %6 = mhlo.divide %2, %5 : tensor<4x8x8x8xf32> + return %6 : tensor<4x8x8x8xf32> +} + +// CHECK: %[[POOL_OUT:.*]] = "tfl.average_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> +// CHECK: return %[[POOL_OUT]] + +// ----- + +//===------------------------------------------------------------------------=== +// mhlo.reduce_window -> max pool +//===------------------------------------------------------------------------=== + +// CHECK-LABEL: maxpool_same +func.func @maxpool_same(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> { + // "0xFF800000" represents -INF for f32. + %0 = mhlo.constant dense<0xFF800000> : tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.maximum %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x8x8x8xf32> + func.return %1 : tensor<4x8x8x8xf32> +} + +// CHECK: %1 = "tfl.max_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x8x8x8xf32> + +// ----- + +// CHECK-LABEL: maxpool_valid +func.func @maxpool_valid(%arg0: tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> { + // "0xFF800000" represents -INF for f32. + %0 = mhlo.constant dense<0xFF800000> : tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %6 = mhlo.maximum %arg1, %arg2 : tensor + "mhlo.return"(%6) : (tensor) -> () + }) { + base_dilations = dense<1> : tensor<4xi64>, + padding = dense<[[0, 0], [0, 0], [0, 0], [0, 0]]> : tensor<4x2xi64>, + window_dilations = dense<1> : tensor<4xi64>, + window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, + window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>} : (tensor<4x16x16x8xf32>, tensor) -> tensor<4x7x7x8xf32> + func.return %1 : tensor<4x7x7x8xf32> +} + +// CHECK: %1 = "tfl.max_pool_2d"(%arg0) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x8xf32>) -> tensor<4x7x7x8xf32> + +// ----- + +// CHECK-LABEL: maxpool_valid_channel_first +func.func @maxpool_valid_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x7x7xf32> { + // "0xFF800000" represents -INF for f32. + %0 = mhlo.constant dense<0xFF800000> : tensor + %1 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32> + %2 = "mhlo.reduce_window"(%1, %0) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %4 = mhlo.maximum %arg1, %arg2 : tensor + mhlo.return %4 : tensor + }) : (tensor<4x16x16x3xf32>, tensor) -> tensor<4x7x7x3xf32> + %3 = "mhlo.transpose"(%2) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x7x7x3xf32>) -> tensor<4x3x7x7xf32> + return %3 : tensor<4x3x7x7xf32> +} + +// CHECK: %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0 +// CHECK: "tfl.max_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x7x7x3xf32> +// CHECK: return +// CHECK-SAME: tensor<4x3x7x7xf32> + +// ----- + +// CHECK-LABEL: maxpool_same_channel_first +func.func @maxpool_same_channel_first(%arg0: tensor<4x3x16x16xf32>) -> tensor<4x3x8x8xf32> { + // "0xFF800000" represents -INF for f32. + %0 = mhlo.constant dense<0xFF800000> : tensor + %1 = "mhlo.transpose"(%arg0) <{permutation = dense<[0, 2, 3, 1]> : tensor<4xi64>}> : (tensor<4x3x16x16xf32>) -> tensor<4x16x16x3xf32> + %2 = "mhlo.reduce_window"(%1, %0) <{base_dilations = dense<1> : tensor<4xi64>, padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 3, 3, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %4 = mhlo.maximum %arg1, %arg2 : tensor + mhlo.return %4 : tensor + }) : (tensor<4x16x16x3xf32>, tensor) -> tensor<4x8x8x3xf32> + %3 = "mhlo.transpose"(%2) <{permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>}> : (tensor<4x8x8x3xf32>) -> tensor<4x3x8x8xf32> + return %3 : tensor<4x3x8x8xf32> +} + +// CHECK: %[[TPOSED_ARG0:.*]] = "tfl.transpose"(%arg0 +// CHECK: "tfl.max_pool_2d"(%[[TPOSED_ARG0]]) <{filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<4x16x16x3xf32>) -> tensor<4x8x8x3xf32> +// CHECK: return +// CHECK-SAME: tensor<4x3x8x8xf32> + +// ----- + +//===------------------------------------------------------------------------=== +// mhlo.reduce_window -> tfl.cumsum +//===------------------------------------------------------------------------=== + +// CHECK-LABEL: reduce_window_sum +func.func @reduce_window_sum(%arg0: tensor<4x12xf32>) -> tensor<4x12xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor + %1 = "mhlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %2 = mhlo.add %arg1, %arg2 : tensor + "mhlo.return"(%2) : (tensor) -> () + }) {base_dilations = dense<1> : tensor<2xi64>, padding = dense<[[3, 0], [0, 0]]> : tensor<2x2xi64>, window_dilations = dense<1> : tensor<2xi64>, window_dimensions = dense<[4, 1]> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : (tensor<4x12xf32>, tensor) -> tensor<4x12xf32> + func.return %1 : tensor<4x12xf32> +} + +// CHECK: %[[AXIS:.*]] = arith.constant dense<0> : tensor +// CHECK: "tfl.cumsum"(%arg0, %[[AXIS]]) <{exclusive = false, reverse = false}> : (tensor<4x12xf32>, tensor) -> tensor<4x12xf32> + + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.slice +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: slice +func.func @slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { + %0 = "mhlo.slice"(%arg0) <{limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}> : (tensor<1x4672xf32>) -> tensor<1x519xf32> + func.return %0 : tensor<1x519xf32> +} + +// CHECK: %[[CST:.*]] = arith.constant dense<[0, 4153]> : tensor<2xi64> +// CHECK: %[[CST_0:.*]] = arith.constant dense<[1, 4672]> : tensor<2xi64> +// CHECK: %[[CST_1:.*]] = arith.constant dense<1> : tensor<2xi64> +// CHECK: %[[VAL_0:.*]] = "tfl.cast"(%[[CST]]) : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %[[VAL_1:.*]] = "tfl.cast"(%[[CST_0]]) : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %[[VAL_2:.*]] = "tfl.cast"(%[[CST_1]]) : (tensor<2xi64>) -> tensor<2xi32> +// CHECK: %[[VAL_3:.*]] = "tfl.strided_slice"(%arg0, %[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) <{begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, offset = false, shrink_axis_mask = 0 : i32}> : (tensor<1x4672xf32>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<1x519xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.sort +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: sort_to_topk_iota_broadcast +func.func @sort_to_topk_iota_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) { + %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<6xi32> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32> + %2:2 = "mhlo.sort"(%arg0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): + %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%3) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) + func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32> +} + +// CHECK: arith.constant dense<6> : tensor +// CHECK: %[[CST:.*]] = arith.constant dense<6> : tensor +// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %[[CST]]) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32>) + +// ----- + +// CHECK-LABEL: sort_to_topk_iota_cst_broadcast +func.func @sort_to_topk_iota_cst_broadcast(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) { + %0 = mhlo.constant dense<[0, 1, 2, 3, 4, 5]> : tensor<6xi32> + %1 = "mhlo.broadcast_in_dim"(%0) <{broadcast_dimensions = dense<[1]> : tensor<1xi64>, name = "broadcast.0"}> : (tensor<6xi32>) -> tensor<3x6xi32> + %2:2 = "mhlo.sort"(%arg0, %1) ({ + ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): + %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%3) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) + func.return %2#0, %2#1 : tensor<3x6xf32>, tensor<3x6xi32> +} + +// CHECK: %[[CST:.*]] = arith.constant dense<6> : tensor +// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %[[CST]]) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32>) + +// ----- + +// CHECK-LABEL: sort_to_topk_const +func.func @sort_to_topk_const(%arg0: tensor<3x6xf32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) { + %0 = mhlo.constant dense<[[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]> : tensor<3x6xi32> + %1:2 = "mhlo.sort"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor, %arg3: tensor, %arg4: tensor): + %3 = "mhlo.compare"(%arg1, %arg2) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor, tensor) -> tensor + "mhlo.return"(%3) : (tensor) -> () + }) {dimension = 1 : i64, is_stable = true} : (tensor<3x6xf32>, tensor<3x6xi32>) -> (tensor<3x6xf32>, tensor<3x6xi32>) + func.return %1#0, %1#1 : tensor<3x6xf32>, tensor<3x6xi32> +} + +// CHECK: %[[CST:.*]] = arith.constant dense<6> : tensor +// CHECK: %values, %indices = "tfl.topk_v2"(%arg0, %[[CST]]) : (tensor<3x6xf32>, tensor) -> (tensor<3x6xf32>, tensor<3x6xi32> + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.iota +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: iota_1d +func.func @iota_1d() -> tensor<123xf32> { + %0 = "mhlo.iota"() <{ iota_dimension = 0 : i64 }> : () -> tensor<123xf32> + func.return %0 : tensor<123xf32> +} + +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : tensor +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<1.230000e+02> : tensor +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<1.000000e+00> : tensor +// CHECK: "tfl.range"(%[[CST_1]], %[[CST_2]], %[[CST_3]]) : (tensor, tensor, tensor) -> tensor<123xf32> + +// ----- + +// CHECK-LABEL: iota_3d +func.func @iota_3d() -> tensor<5x7x9xi32> { + %0 = "mhlo.iota"() <{ iota_dimension = 1 : i64 }> : () -> tensor<5x7x9xi32> + func.return %0 : tensor<5x7x9xi32> +} + +// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<7> : tensor +// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<1> : tensor +// CHECK: %[[RANGE:.*]] = "tfl.range"(%[[CST_1]], %[[CST_2]], %[[CST_3]]) : (tensor, tensor, tensor) -> tensor<7xi32> +// CHECK: %[[CST_4:.*]] = arith.constant dense<[1, 7, 1]> : tensor<3xi64> +// CHECK: %[[CAST:.*]] = "tfl.cast"(%[[CST_4]]) : (tensor<3xi64>) -> tensor<3xi32> +// CHECK: %[[RESHAPE:.*]] = "tfl.reshape"(%[[RANGE]], %[[CAST]]) : (tensor<7xi32>, tensor<3xi32>) -> tensor<1x7x1xi32> +// CHECK: %[[CST_5:.*]] = arith.constant dense<[5, 7, 9]> : tensor<3xi64> +// CHECK: "tfl.broadcast_to"(%[[RESHAPE]], %[[CST_5]]) : (tensor<1x7x1xi32>, tensor<3xi64>) -> tensor<5x7x9xi32> + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.dynamic_slice +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: dynamic_slice +func.func @dynamic_slice(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<4x2xf32> { + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + func.return %0 : tensor<4x2xf32> +} + +// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_IS_3:.*]] = arith.constant dense<3> : tensor +// CHECK: %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor, tensor) -> tensor +// CHECK: %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_3]], %[[MAX_1]]) : (tensor, tensor) -> tensor +// CHECK: %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor +// CHECK: %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor, tensor) -> tensor +// CHECK: %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor, tensor) -> tensor<2xi32> +// CHECK: %[[SLICE_SIZE:.*]] = arith.constant dense<[4, 2]> : tensor<2xi64> +// CHECK: "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<4x2xf32> + +// ----- + +// CHECK-LABEL: dynamic_slice_i64 +func.func @dynamic_slice_i64(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<4x2xf32> { + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<[4, 2]> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<4x2xf32> + func.return %0 : tensor<4x2xf32> +} + +// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_IS_3:.*]] = arith.constant dense<3> : tensor +// CHECK: %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor, tensor) -> tensor +// CHECK: %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_3]], %[[MAX_1]]) : (tensor, tensor) -> tensor +// CHECK: %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor +// CHECK: %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor, tensor) -> tensor +// CHECK: %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor, tensor) -> tensor<2xi64> +// CHECK: %[[SLICE_SIZE:.*]] = arith.constant dense<[4, 2]> : tensor<2xi64> +// CHECK: "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi64>, tensor<2xi64>) -> tensor<4x2xf32> + +// ----- + +// CHECK-LABEL: dynamic_slice_splat_sizes +func.func @dynamic_slice_splat_sizes(%arg0: tensor<7x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<2x2xf32> { + %0 = "mhlo.dynamic_slice"(%arg0, %arg1, %arg2) <{slice_sizes = dense<2> : tensor<2xi64>}> : (tensor<7x3xf32>, tensor, tensor) -> tensor<2x2xf32> + func.return %0 : tensor<2x2xf32> +} + +// CHECK-DAG: %[[CST_IS_0:.*]] = arith.constant dense<0> : tensor +// CHECK-DAG: %[[CST_IS_5:.*]] = arith.constant dense<5> : tensor +// CHECK: %[[MAX_1:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg1) : (tensor, tensor) -> tensor +// CHECK: %[[MIN_1:.*]] = "tfl.minimum"(%[[CST_IS_5]], %[[MAX_1]]) : (tensor, tensor) -> tensor +// CHECK: %[[CST_IS_1:.*]] = arith.constant dense<1> : tensor +// CHECK: %[[MAX_2:.*]] = "tfl.maximum"(%[[CST_IS_0]], %arg2) : (tensor, tensor) -> tensor +// CHECK: %[[MIN_2:.*]] = "tfl.minimum"(%[[CST_IS_1]], %[[MAX_2]]) : (tensor, tensor) -> tensor +// CHECK: %[[PACK:.*]] = "tfl.pack"(%[[MIN_1]], %[[MIN_2]]) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor, tensor) -> tensor<2xi32> +// CHECK: %[[SLICE_SIZE:.*]] = arith.constant dense<2> : tensor<2xi64> +// CHECK: "tfl.slice"(%arg0, %[[PACK]], %[[SLICE_SIZE]]) : (tensor<7x3xf32>, tensor<2xi32>, tensor<2xi64>) -> tensor<2x2xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.dynamic_update_slice +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: dynamic_update_slice +func.func @dynamic_update_slice(%arg0: tensor<28x1x100xf32>, %arg1: tensor<1x1x100xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<28x1x100xf32> { + %0 = "mhlo.dynamic_update_slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor, tensor, tensor) -> tensor<28x1x100xf32> + func.return %0 : tensor<28x1x100xf32> +} + +// CHECK: %0 = "tfl.pack"(%arg2, %arg3, %arg4) <{axis = 0 : i32, values_count = 3 : i32}> : (tensor, tensor, tensor) -> tensor<3xi32> +// CHECK: %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor<28x1x100xf32>, tensor<1x1x100xf32>, tensor<3xi32>) -> tensor<28x1x100xf32> + +// ----- + +// CHECK-LABEL: dynamic_update_slice_inputs_have_dynamic_dim +func.func @dynamic_update_slice_inputs_have_dynamic_dim(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor { + %0 = mhlo.dynamic_update_slice %arg0, %arg1, %arg2, %arg3 : (tensor, tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: %0 = "tfl.pack"(%arg2, %arg3) <{axis = 0 : i32, values_count = 2 : i32}> : (tensor, tensor) -> tensor<2xi32> +// CHECK: %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor, tensor, tensor<2xi32>) -> tensor + +// ----- + +// CHECK-LABEL: dynamic_update_slice_operand_has_dynamic_dim +func.func @dynamic_update_slice_operand_has_dynamic_dim(%arg0: tensor<1x?x256xf32>, %arg1: tensor<1x1x256xf32>, %arg2: tensor, %arg3: tensor, %arg4: tensor) -> tensor<1x?x256xf32> { + %0 = mhlo.dynamic_update_slice %arg0, %arg1, %arg2, %arg3, %arg4 : (tensor<1x?x256xf32>, tensor<1x1x256xf32>, tensor, tensor, tensor) -> tensor<1x?x256xf32> + func.return %0 : tensor<1x?x256xf32> +} + +// CHECK: %0 = "tfl.pack"(%arg2, %arg3, %arg4) <{axis = 0 : i32, values_count = 3 : i32}> : (tensor, tensor, tensor) -> tensor<3xi32> +// CHECK: %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor<1x?x256xf32>, tensor<1x1x256xf32>, tensor<3xi32>) -> tensor<1x?x256xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// rounding +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: round +func.func @round(%arg0: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = mhlo.constant dense<2.000000e+00> : tensor<8x128xf32> + %1 = mhlo.constant dense<5.000000e-01> : tensor<8x128xf32> + %2 = mhlo.constant dense<1.000000e+00> : tensor<8x128xf32> + %3 = "mhlo.floor"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32> + %4 = mhlo.subtract %arg0, %3 : tensor<8x128xf32> + %5 = "mhlo.compare"(%4, %1) {comparison_direction = #mhlo} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1> + %6 = "mhlo.compare"(%4, %1) {comparison_direction = #mhlo} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1> + %7 = mhlo.multiply %arg0, %1 : tensor<8x128xf32> + %8 = "mhlo.floor"(%7) : (tensor<8x128xf32>) -> tensor<8x128xf32> + %9 = mhlo.multiply %8, %0 : tensor<8x128xf32> + %10 = mhlo.subtract %3, %9 : tensor<8x128xf32> + %11 = "mhlo.compare"(%10, %2) {comparison_direction = #mhlo} : (tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xi1> + %12 = mhlo.and %6, %11 : tensor<8x128xi1> + %13 = mhlo.or %5, %12 : tensor<8x128xi1> + %14 = mhlo.add %3, %2 : tensor<8x128xf32> + %15 = "mhlo.select"(%13, %14, %3) : (tensor<8x128xi1>, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + func.return %15 : tensor<8x128xf32> +} + +// CHECK: "tfl.round"(%arg0) : (tensor<8x128xf32>) -> tensor<8x128xf32> + +// ----- + +// CHECK-LABEL: floor_mod_float +func.func @floor_mod_float(%arg0: tensor<192x8xf32>, %arg1: tensor<192x8xf32>) -> tensor<192x8xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xf32> + %1 = mhlo.remainder %arg0, %arg1 : tensor<192x8xf32> + %2 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %3 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1> + %5 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %6 = mhlo.and %4, %5 : tensor<192x8xi1> + %7 = mhlo.add %1, %arg1 : tensor<192x8xf32> + %8 = "mhlo.select"(%6, %7, %1) : (tensor<192x8xi1>, tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32> + func.return %8 : tensor<192x8xf32> +} + +// CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32> + +// ----- + +// CHECK-LABEL: floor_mod_int +func.func @floor_mod_int(%arg0: tensor<192x8xi32>, %arg1: tensor<192x8xi32>) -> tensor<192x8xi32> { + %0 = mhlo.constant dense<0> : tensor<192x8xi32> + %1 = mhlo.remainder %arg0, %arg1 : tensor<192x8xi32> + %2 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %3 = "mhlo.compare"(%arg1, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %4 = "mhlo.compare"(%2, %3) {comparison_direction = #mhlo} : (tensor<192x8xi1>, tensor<192x8xi1>) -> tensor<192x8xi1> + %5 = "mhlo.compare"(%1, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %6 = mhlo.and %4, %5 : tensor<192x8xi1> + %7 = mhlo.add %1, %arg1 : tensor<192x8xi32> + %8 = "mhlo.select"(%6, %7, %1) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32> + func.return %8 : tensor<192x8xi32> +} + +// CHECK: "tfl.floor_mod"(%arg0, %arg1) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32> + +// ----- + +// CHECK-LABEL: floor_mod_float_cst +func.func @floor_mod_float_cst(%arg0: tensor<192x8xf32>) -> tensor<192x8xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor<192x8xf32> + %1 = mhlo.constant dense<2.000000e+00> : tensor<192x8xf32> + %2 = mhlo.remainder %arg0, %1 : tensor<192x8xf32> + %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %4 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xi1> + %5 = mhlo.and %3, %4 : tensor<192x8xi1> + %6 = mhlo.add %2, %1 : tensor<192x8xf32> + %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32> + func.return %7 : tensor<192x8xf32> +} + +// CHECK: %cst = arith.constant dense<2.000000e+00> : tensor<192x8xf32> +// CHECK: "tfl.floor_mod"(%arg0, %cst) : (tensor<192x8xf32>, tensor<192x8xf32>) -> tensor<192x8xf32> + +// ----- + +// CHECK-LABEL: floor_mod_int_cst +func.func @floor_mod_int_cst(%arg0: tensor<192x8xi32>) -> tensor<192x8xi32> { + %0 = mhlo.constant dense<0> : tensor<192x8xi32> + %1 = mhlo.constant dense<2> : tensor<192x8xi32> + %2 = mhlo.remainder %arg0, %1 : tensor<192x8xi32> + %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %4 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi1> + %5 = mhlo.and %3, %4 : tensor<192x8xi1> + %6 = mhlo.add %2, %1 : tensor<192x8xi32> + %7 = "mhlo.select"(%5, %6, %2) : (tensor<192x8xi1>, tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32> + func.return %7 : tensor<192x8xi32> +} + +// CHECK: %cst = arith.constant dense<2> : tensor<192x8xi32> +// CHECK: "tfl.floor_mod"(%arg0, %cst) : (tensor<192x8xi32>, tensor<192x8xi32>) -> tensor<192x8xi32> + +// ----- + +// CHECK-LABEL: floor_div +func.func @floor_div(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32> + %1 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32> + %2 = mhlo.remainder %arg0, %arg1 : tensor<10x10xf32> + %3 = "mhlo.compare"(%2, %0) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %4 = "mhlo.sign"(%arg1) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %5 = "mhlo.sign"(%2) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %6 = "mhlo.compare"(%4, %5) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %7 = mhlo.and %3, %6 : tensor<10x10xi1> + %8 = mhlo.subtract %arg0, %2 : tensor<10x10xf32> + %9 = mhlo.divide %8, %arg1 : tensor<10x10xf32> + %10 = mhlo.add %9, %1 : tensor<10x10xf32> + %11 = "mhlo.select"(%7, %10, %9) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %12 = "mhlo.round_nearest_afz"(%11) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %13 = "mhlo.tuple"(%12) : (tensor<10x10xf32>) -> tuple> + func.return %12 : tensor<10x10xf32> +} + +// CHECK: tfl.floor_div %arg0, %arg1 : tensor<10x10xf32 + +// ----- + +// CHECK-LABEL: floor_div_cst +func.func @floor_div_cst(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32> + %1 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32> + %2 = mhlo.constant dense<1.000000e+00> : tensor<10x10xf32> + %3 = mhlo.constant dense<5.000000e-01> : tensor<10x10xf32> + %4 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32> + %5 = mhlo.remainder %arg0, %0 : tensor<10x10xf32> + %6 = "mhlo.compare"(%5, %1) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %7 = "mhlo.sign"(%5) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %8 = "mhlo.compare"(%2, %7) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %9 = mhlo.and %6, %8 : tensor<10x10xi1> + %10 = mhlo.subtract %arg0, %5 : tensor<10x10xf32> + %11 = mhlo.multiply %10, %3 : tensor<10x10xf32> + %12 = mhlo.add %11, %4 : tensor<10x10xf32> + %13 = "mhlo.select"(%9, %12, %11) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %14 = "mhlo.round_nearest_afz"(%13) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %15 = "mhlo.tuple"(%14) : (tensor<10x10xf32>) -> tuple> + func.return %14 : tensor<10x10xf32> +} + +// CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32> +// CHECK: tfl.floor_div %arg0, %[[CST]] : tensor<10x10xf32> + +// ----- + +// CHECK-LABEL: floor_div_cst2 +func.func @floor_div_cst2(%arg0: tensor<10x10xf32>) -> tensor<10x10xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<10x10xf32> + %1 = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor<10x10xf32> + %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x10xf32> + %4 = mhlo.remainder %arg0, %1 : tensor<10x10xf32> + %5 = "mhlo.compare"(%4, %2) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %6 = "mhlo.sign"(%4) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %7 = "mhlo.compare"(%0, %6) {comparison_direction = #mhlo} : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xi1> + %8 = mhlo.and %5, %7 : tensor<10x10xi1> + %9 = mhlo.subtract %arg0, %4 : tensor<10x10xf32> + %10 = mhlo.divide %9, %1 : tensor<10x10xf32> + %11 = mhlo.add %10, %3 : tensor<10x10xf32> + %12 = "mhlo.select"(%8, %11, %10) : (tensor<10x10xi1>, tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<10x10xf32> + %13 = "mhlo.round_nearest_afz"(%12) : (tensor<10x10xf32>) -> tensor<10x10xf32> + %14 = "mhlo.tuple"(%13) : (tensor<10x10xf32>) -> tuple> + func.return %13 : tensor<10x10xf32> +} + +// CHECK: %[[CST:.*]] = mhlo.constant dense<2.000000e+00> : tensor<10x10xf32> +// CHECK: tfl.floor_div %arg0, %[[CST]] : tensor<10x10xf32> + +// ----- + +// CHECK-LABEL: floor_div_broadcast_cst +func.func @floor_div_broadcast_cst(%arg0: tensor<10x8xf32>) -> tensor<10x8xf32> { + %0 = mhlo.constant dense<1.000000e+00> : tensor<10x8xf32> + %1 = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 4.000000e+00, 8.000000e+00, 1.600000e+01, 3.200000e+01, 6.400000e+01, 1.280000e+02]> : tensor<8xf32> + %2 = mhlo.constant dense<0.000000e+00> : tensor<10x8xf32> + %3 = mhlo.constant dense<-1.000000e+00> : tensor<10x8xf32> + %5 = "mhlo.broadcast_in_dim"(%1) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<8xf32>) -> tensor<10x8xf32> + %6 = mhlo.remainder %arg0, %5 : tensor<10x8xf32> + %7 = "mhlo.compare"(%6, %2) {comparison_direction = #mhlo} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1> + %8 = "mhlo.sign"(%6) : (tensor<10x8xf32>) -> tensor<10x8xf32> + %9 = "mhlo.compare"(%0, %8) {comparison_direction = #mhlo} : (tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xi1> + %10 = mhlo.and %7, %9 : tensor<10x8xi1> + %11 = mhlo.subtract %arg0, %6 : tensor<10x8xf32> + %12 = mhlo.divide %11, %5 : tensor<10x8xf32> + %13 = mhlo.add %12, %3 : tensor<10x8xf32> + %14 = "mhlo.select"(%10, %13, %12) : (tensor<10x8xi1>, tensor<10x8xf32>, tensor<10x8xf32>) -> tensor<10x8xf32> + %15 = "mhlo.round_nearest_afz"(%14) : (tensor<10x8xf32>) -> tensor<10x8xf32> + %16 = "mhlo.tuple"(%15) : (tensor<10x8xf32>) -> tuple> + func.return %15 : tensor<10x8xf32> +} + +// CHECK: %[[BCAST:.*]] = "mhlo.broadcast_in_dim"(%1) +// CHECK: tfl.floor_div %arg0, %[[BCAST]] : tensor<10x8xf32> + +// ----- + +//===----------------------------------------------------------------------===// +// unary elementwise +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: convert_i32_f32 +func.func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { + %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.cast + +// ----- + +// CHECK-LABEL: abs +func.func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.abs + +// ----- + +// CHECK-LABEL: abs_dynamic +func.func @abs_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.abs + +// ----- + +// CHECK-LABEL: ceil +func.func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.ceil + +// ----- + +// CHECK-LABEL: ceil_dynamic +func.func @ceil_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.ceil + +// ----- + +// CHECK-LABEL: complex_abs +func.func @complex_abs(%arg0: tensor<2xcomplex>) -> tensor<2xf32> { + %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK-NOT: tfl + +// ----- + +func.func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { + %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: %0 = tfl.sub %arg0, %arg0 {fused_activation_function = "NONE"} : tensor<2xf32> +// CHECK: %cst = arith.constant dense<0.000000e+00> : tensor +// CHECK: %1 = "tfl.equal"(%0, %cst) : (tensor<2xf32>, tensor) -> tensor<2xi1> +// CHECK: return %1 : tensor<2xi1> + +// ----- + +func.func @is_finite_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.is_finite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: %0 = tfl.sub %arg0, %arg0 {fused_activation_function = "NONE"} : tensor +// CHECK: %cst = arith.constant dense<0.000000e+00> : tensor +// CHECK: %1 = "tfl.equal"(%0, %cst) : (tensor, tensor) -> tensor + +// ----- + +// CHECK-LABEL: cos +func.func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.cos + +// ----- + +// CHECK-LABEL: cos_dynamic +func.func @cos_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.cosine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.cos + +// ----- + +// CHECK-LABEL: logistic +func.func @logistic(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.logistic"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.logistic + +// ----- + +// CHECK-LABEL: exp +func.func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.exp + +// ----- + +// CHECK-LABEL: exp_dynamic +func.func @exp_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.exp + +// ----- + +// CHECK-LABEL: expm1 +func.func @expm1(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: %0 = "tfl.exp"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> +// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK: %1 = tfl.sub(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor) -> tensor<2xf32> + +// ----- + +// CHECK-LABEL: floor +func.func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.floor + +// ----- + +// CHECK-LABEL: floor_dynamic +func.func @floor_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.floor + +// ----- + +// CHECK-LABEL: log +func.func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.log + +// ----- + +// CHECK-LABEL: log_dynamic +func.func @log_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.log + +// ----- + +// CHECK-LABEL: log1p +func.func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2xf32>, tensor) -> tensor<2xf32> +// CHECK: %1 = "tfl.log"(%0) : (tensor<2xf32>) -> tensor<2xf32> + +// ----- + +// CHECK-LABEL: log1p_dynamic +func.func @log1p_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.log_plus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor) -> tensor +// CHECK: %1 = "tfl.log"(%0) : (tensor) -> tensor + +// ----- + +// CHECK-LABEL: neg +func.func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.neg + +// ----- + +// CHECK-LABEL: neg_dynamic +func.func @neg_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.negate"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.neg + +// ----- + +// CHECK-LABEL: sin +func.func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.sin + +// ----- + +// CHECK-LABEL: sin_dynamic +func.func @sin_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.sine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.sin + +// ----- + +// CHECK-LABEL: rsqrt +func.func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.rsqrt + +// ----- + +// CHECK-LABEL: rsqrt_dynamic +func.func @rsqrt_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.rsqrt + +// ----- + +// CHECK-LABEL: @sqrt +func.func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.sqrt + +// ----- + +// CHECK-LABEL: sqrt_dynamic +func.func @sqrt_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.sqrt + +// ----- + +// CHECK-LABEL: tanh +func.func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.tanh + +// ----- + +// CHECK-LABEL: tanh_dynamic +func.func @tanh_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.tanh + +// ----- + +// CHECK-LABEL: bitcast +func.func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: tfl.bitcast + +// ----- + +// CHECK-LABEL: bitcast_dynamic +func.func @bitcast_dynamic(%arg0: tensor) -> tensor { + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK: tfl.bitcast + +// ----- + +// CHECK-LABEL: bitcast_same_widths +func.func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { + %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// CHECK: tfl.bitcast + +// ----- + +//===----------------------------------------------------------------------===// +// logical and bitwise ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: not +func.func @not(%arg0: tensor<5x3x1xi1>) -> tensor<5x3x1xi1> { + %0 = "mhlo.not"(%arg0): (tensor<5x3x1xi1>) -> (tensor<5x3x1xi1>) + func.return %0 : tensor<5x3x1xi1> +} + +// CHECK: %0 = "tfl.logical_not"(%arg0) : (tensor<5x3x1xi1>) -> tensor<5x3x1xi1> + +// ----- + +// CHECK-LABEL: not_i8 +func.func @not_i8(%arg0: tensor<7x9x11xi8>) -> tensor<7x9x11xi8> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi8>) -> (tensor<7x9x11xi8>) + func.return %0 : tensor<7x9x11xi8> +} + +// CHECK: %cst = arith.constant dense<-1> : tensor +// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi8>, tensor) -> tensor<7x9x11xi8> + +// ----- + +// CHECK-LABEL: not_i16 +func.func @not_i16(%arg0: tensor<7x9x11xi16>) -> tensor<7x9x11xi16> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi16>) -> (tensor<7x9x11xi16>) + func.return %0 : tensor<7x9x11xi16> +} + +// CHECK: %cst = arith.constant dense<-1> : tensor +// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi16>, tensor) -> tensor<7x9x11xi16> + +// ----- + +// CHECK-LABEL: not_i32 +func.func @not_i32(%arg0: tensor<7x9x11xi32>) -> tensor<7x9x11xi32> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xi32>) -> (tensor<7x9x11xi32>) + func.return %0 : tensor<7x9x11xi32> +} + +// CHECK: %cst = arith.constant dense<-1> : tensor +// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xi32>, tensor) -> tensor<7x9x11xi32> + +// ----- + +// CHECK-LABEL: not_ui8 +func.func @not_ui8(%arg0: tensor<7x9x11xui8>) -> tensor<7x9x11xui8> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui8>) -> (tensor<7x9x11xui8>) + func.return %0 : tensor<7x9x11xui8> +} + +// CHECK: %cst = arith.constant dense<255> : tensor +// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui8>, tensor) -> tensor<7x9x11xui8> + +// ----- + +// CHECK-LABEL: not_ui16 +func.func @not_ui16(%arg0: tensor<7x9x11xui16>) -> tensor<7x9x11xui16> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui16>) -> (tensor<7x9x11xui16>) + func.return %0 : tensor<7x9x11xui16> +} + +// CHECK: %cst = arith.constant dense<65535> : tensor +// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui16>, tensor) -> tensor<7x9x11xui16> + +// ----- + +// CHECK-LABEL: not_ui32 +func.func @not_ui32(%arg0: tensor<7x9x11xui32>) -> tensor<7x9x11xui32> { + %0 = "mhlo.not"(%arg0): (tensor<7x9x11xui32>) -> (tensor<7x9x11xui32>) + func.return %0 : tensor<7x9x11xui32> +} + +// CHECK: %cst = arith.constant dense<4294967295> : tensor +// CHECK: %0 = "tfl.bitwise_xor"(%arg0, %cst) : (tensor<7x9x11xui32>, tensor) -> tensor<7x9x11xui32> + +// ----- + +//===----------------------------------------------------------------------===// +// binary ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: remainder +func.func @remainder(%arg0: tensor<10x8xi32>, %arg1: tensor<10x8xi32>) -> tensor<10x8xi32> { + %0 = mhlo.remainder %arg0, %arg1 : tensor<10x8xi32> + func.return %0 : tensor<10x8xi32> +} + +// CHECK: %0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<10x8xi32>, tensor<10x8xi32>) -> tensor<10x8xi32> + +// ----- + +// CHECK-LABEL: shift_right_arith +func.func @shift_right_arith(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK: %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + +// ----- + +// CHECK-LABEL: shift_right_logical +func.func @shift_right_logical(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { + %0 = mhlo.shift_right_logical %arg0, %arg1 : tensor<4xi32> + func.return %0 : tensor<4xi32> +} + +// CHECK: %0 = "tfl.right_shift"(%arg0, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo.compare +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: greater_unsupported_compare_type +func.func @greater_unsupported_compare_type(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {compare_type = #mhlo, comparison_direction = #mhlo} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK-NOT: tfl +// CHECK: mhlo.compare + +// ----- + +// CHECK-LABEL: equal +func.func @equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.equal + +// ----- + +// CHECK-LABEL: notequal +func.func @notequal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.not_equal + +// ----- + +// CHECK-LABEL: greater +func.func @greater(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.greater + +// ----- + +// CHECK-LABEL: greater_equal +func.func @greater_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.greater_equal + +// ----- + +// CHECK-LABEL: less +func.func @less(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.less + +// ----- + +// CHECK-LABEL: less_equal +func.func @less_equal(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2xi1> { + %0 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + func.return %0 : tensor<2xi1> +} + +// CHECK: tfl.less_equal + +// ----- + +//===----------------------------------------------------------------------===// +// mhlo binary element-wise ops +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: maximum +func.func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: "tfl.maximum"(%arg0, %arg1) +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: minimum +func.func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: "tfl.minimum"(%arg0, %arg1) +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: mul +func.func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { + %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32> + func.return %0 : tensor<2xi32> +} + +// CHECK: tfl.mul %arg0, %arg0 +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: pow +func.func @pow(%arg0: tensor<4xf32>) -> tensor<4xf32> { + %0 = mhlo.power %arg0, %arg0 : tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: tfl.pow +// CHECK-NOT: mhlo + +// ----- + +// CHECK-LABEL: clamp +func.func @clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "mhlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-NEXT: %0 = "tfl.minimum"(%arg1, %arg2) +// CHECK-NEXT: %1 = "tfl.maximum"(%0, %arg0) +// CHECK-NEXT: return %1 : tensor + + diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index eb866dc64931d0..2b96254e04fc3d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -1144,7 +1144,8 @@ class ComposeUniformQuantizedDotGeneralOp .clone(output_uniform_quantized_type), /*lhs=*/op.getLhs(), /*rhs=*/op.getRhs(), /*dot_dimension_numbers=*/op.getDotDimensionNumbers(), - /*precision_config=*/op.getPrecisionConfigAttr()); + /*precision_config=*/op.getPrecisionConfigAttr(), + /*algorithm=*/op.getAlgorithmAttr()); rewriter.replaceAllUsesWith(op.getResult(), new_dot_general_op.getResult()); @@ -1489,7 +1490,8 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations .clone(output_uniform_quantized_type), /*lhs=*/op.getLhs(), /*rhs=*/op.getRhs(), /*dot_dimension_numbers=*/op.getDotDimensionNumbers(), - /*precision_config=*/op.getPrecisionConfigAttr()); + /*precision_config=*/op.getPrecisionConfigAttr(), + /*algorithm=*/op.getAlgorithmAttr()); rewriter.replaceAllUsesWith(op.getResult(), new_dot_general_op.getResult()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index 2ecfcd5406e2d6..27e655c3aa51d3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -21,7 +21,6 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -125,6 +124,7 @@ cc_library( hdrs = ["conv.h"], deps = [ ":conv_util", + ":op_util_common", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -140,7 +140,10 @@ cc_library( srcs = ["conv_util.cc"], hdrs = ["conv_util.h"], deps = [ + ":op_util_common", + "//tensorflow/core/lib/math:math_util", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@local_xla//xla/mlir_hlo", @@ -152,6 +155,7 @@ cc_library( srcs = ["pad.cc"], hdrs = ["pad.h"], deps = [ + ":op_util_common", ":pad_util", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "@llvm-project//llvm:Support", @@ -168,6 +172,7 @@ cc_library( srcs = ["pad_util.cc"], hdrs = ["pad_util.h"], deps = [ + ":op_util_common", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -190,3 +195,93 @@ cc_library( "@local_xla//xla/mlir_hlo", ], ) + +cc_library( + name = "reduce_window", + srcs = ["reduce_window.cc"], + hdrs = ["reduce_window.h"], + deps = [ + ":op_util_common", + ":reduce_window_util", + ":util", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) + +cc_library( + name = "op_util_common", + srcs = ["op_util_common.cc"], + hdrs = ["op_util_common.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "reduce_window_util", + srcs = ["reduce_window_util.cc"], + hdrs = ["reduce_window_util.h"], + deps = [ + ":op_util_common", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_xla//xla/mlir_hlo", + ], +) + +cc_library( + name = "slice", + srcs = ["slice.cc"], + hdrs = ["slice.h"], + deps = [ + ":op_util_common", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) + +cc_library( + name = "sort", + srcs = ["sort.cc"], + hdrs = ["sort.h"], + deps = [ + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/stablehlo:hlo_matchers", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) + +cc_library( + name = "iota", + srcs = ["iota.cc"], + hdrs = ["iota.h"], + deps = [ + ":op_util_common", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], +) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc index fa875069d66a9b..1ad6e7bfc044e3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.cc @@ -15,18 +15,23 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h" #include +#include +#include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir::odml { @@ -43,27 +48,27 @@ bool IsShapeFullyStatic(ArrayRef shape) { return llvm::all_of(shape, [](int64_t d) { return d >= 0; }); } -bool AreShapesSupported(const ConvData& data) { +bool AreShapesSupported(const ConvView& data) { return IsShapeFullyStatic(data.InputShape()) && IsShapeFullyStatic(data.KernelShape()) && IsShapeFullyStatic(data.OutputShape()); } -bool IsPaddingSupported(const ConvData& data) { +bool IsPaddingSupported(const ConvView& data) { return llvm::all_of(data.Padding(), [](const DimPadding& p) { return p.Hi() == 0 && p.Lo() == 0; }); } -bool IsInputDilationSupported(const ConvData& data) { +bool IsInputDilationSupported(const ConvView& data) { return llvm::all_of(data.InputDilations(), [](int64_t v) { return v == 1; }); } -bool IsBatchGroupSupported(const ConvData& data) { +bool IsBatchGroupSupported(const ConvView& data) { return data.BatchGroupCount() == 1; } -bool IsWindowReversalSupported(const ConvData& data) { +bool IsWindowReversalSupported(const ConvView& data) { return llvm::all_of(data.WindowReversal(), [](bool b) { return !b; }); } @@ -71,12 +76,19 @@ bool IsWindowReversalSupported(const ConvData& data) { // Used externally to setup a ConversionTarget with dynamically legal // mhlo.convolution. Doubles as matching predicate during legalization. bool IsConvLegal(mhlo::ConvolutionOp op) { - const ConvData data(op); + const ConvView data(op); - return !IsBatchGroupSupported(data) || !IsStandardConv(data) || - !IsInputDilationSupported(data) || !AreShapesSupported(data) || - !IsTFLNativeLayout(data) || !IsPaddingSupported(data) || - !IsWindowReversalSupported(data); + const bool supported_conv_type = IsStandardConv(data) || + IsDepthwiseConv(data) || + IsSupportedNonTrivialConv(data); + + const bool is_non_supported_trivial_conv = + (!IsSupportedNonTrivialConv(data) && + (!IsPaddingSupported(data) || !IsInputDilationSupported(data))); + + return !supported_conv_type || !IsBatchGroupSupported(data) || + !AreShapesSupported(data) || !IsTFLNativeLayout(data) || + is_non_supported_trivial_conv || !IsWindowReversalSupported(data); } //===----------------------------------------------------------------------===// @@ -85,7 +97,7 @@ bool IsConvLegal(mhlo::ConvolutionOp op) { // Bias is a zero tensor of shape [output_channels]. arith::ConstantOp BuildEmptyBias(OpBuilder& b, Location loc, - const ConvData& data) { + const ConvView& data) { auto bias_type = RankedTensorType::get( {data.OutputLayout().SpecialDim2(data.OutputShape())}, data.ElementType()); @@ -105,9 +117,10 @@ LogicalResult LegalizeConv2D::matchAndRewrite( mhlo::ConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const { // Parse mhlo.convolution attrs into cc types. - const ConvData data(op); + const ConvView data(op); - if (IsConvLegal(op) || data.InputLayout().Rank() != 4) { + if (IsConvLegal(op) || !IsStandardConv(data) || + data.InputLayout().Rank() != 4) { return failure(); } @@ -151,6 +164,74 @@ LogicalResult LegalizeConv2D::matchAndRewrite( return success(); } +class LegalizeConvDepthwise : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeConvDepthwise::matchAndRewrite( + mhlo::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + // Parse mhlo.convolution attrs into cc types. + const ConvView data(op); + + if (IsConvLegal(op) || !IsDepthwiseConv(data)) { + return failure(); + } + + // + // dilations + //===------- + + const auto& kernel_dilations = data.KernelDilations(); + auto tfl_h_dilation = rewriter.getI32IntegerAttr(kernel_dilations[0]); + auto tfl_w_dilation = rewriter.getI32IntegerAttr(kernel_dilations[1]); + + // + // strides + //===----- + + const auto& window_strides = data.Strides(); + auto tfl_h_stride = rewriter.getI32IntegerAttr(window_strides[0]); + auto tfl_w_stride = rewriter.getI32IntegerAttr(window_strides[1]); + + // + // padding + //===----- + + // Explicit and same padding should be handeled in upstream "prepare" phase. + // Same padding will be fused in downstream "optimize" phase on tfl dialect. + auto tfl_padding = rewriter.getStringAttr("VALID"); + + // + // depth multiplier + //===----- + + const int64_t out_channels = + data.OutputLayout().SpecialDim2(data.OutputShape()); + const int64_t in_channels = data.InputLayout().SpecialDim2(data.InputShape()); + const int32_t depth_multiplier = out_channels / in_channels; + auto depth_multipler_attr = rewriter.getI32IntegerAttr(depth_multiplier); + + // + // build tfl + //===------- + + auto bias = BuildEmptyBias(rewriter, op->getLoc(), data); + + auto tfl_faf_none = rewriter.getStringAttr("NONE"); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), op.getLhs(), op.getRhs(), bias, + tfl_h_dilation, tfl_w_dilation, tfl_faf_none, tfl_padding, tfl_h_stride, + tfl_w_stride, depth_multipler_attr); + + return success(); +} + class LegalizeConv3D : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -163,9 +244,10 @@ LogicalResult LegalizeConv3D::matchAndRewrite( mhlo::ConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const { // Parse mhlo.convolution attrs into cc types. - const ConvData data(op); + const ConvView data(op); - if (IsConvLegal(op) || data.InputLayout().Rank() != 5) { + if (IsConvLegal(op) || !IsStandardConv(data) || + data.InputLayout().Rank() != 5) { return failure(); } @@ -211,11 +293,454 @@ LogicalResult LegalizeConv3D::matchAndRewrite( return success(); } +//===----------------------------------------------------------------------===// +// mhlo.convolution -> TFL::ResizeBilinearOp +//===----------------------------------------------------------------------===// + +// Convert a 2d mhlo.convolution op to a tfl.resize_bilinear +class ConvertNonTrivialConvToResizeBilinearOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult ConvertNonTrivialConvToResizeBilinearOp::matchAndRewrite( + mhlo::ConvolutionOp conv_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + const ConvView data(conv_op); + bool align_corners; + if (!MatchWithResizeBilinearOp(data, align_corners)) { + return rewriter.notifyMatchFailure( + conv_op, "op does not match with resize_bilinear op"); + } + + // The output size attribute is an array of 32bit values. + SmallVector output_shape_i32; + for (int64_t spatial_dim : data.InputLayout().Spatials()) { + output_shape_i32.push_back( + static_cast(data.OutputShape()[spatial_dim])); + } + Value output_sizes_attr = rewriter.create( + conv_op.getLoc(), rewriter.getI32TensorAttr(output_shape_i32)); + // The value of half_pixel_centers couldn't be inferred from the IR and XLA + // only support half_pixel_centers=True as in 01/11/2022. Here + // half_pixel_centers=False is hardcoded. + rewriter.replaceOpWithNewOp( + conv_op, conv_op.getType(), conv_op.getLhs(), output_sizes_attr, + /*align_corners=*/rewriter.getBoolAttr(align_corners), + /*half_pixel_centers=*/rewriter.getBoolAttr(false)); + + return success(); +} + +//===----------------------------------------------------------------------===// +// mhlo.convolution -> TFL::TransposeConv2dOp +//===----------------------------------------------------------------------===// + +// Convert a 2d mhlo.convolution op to a tfl.transpose_conv2d +class ConvertNonTrivialConvToTransposeConvOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult ConvertNonTrivialConvToTransposeConvOp::matchAndRewrite( + mhlo::ConvolutionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + const ConvView data(op); + + // + // Test if the op is a supported non-trivial convolution. + //===----- + + if (!IsSupportedNonTrivialConv(data)) { + return rewriter.notifyMatchFailure(op, "Not a non-trivial convolution."); + } + + // For depthwise and group convolutions, feature_group_count != 1 + if (op.getFeatureGroupCount() != 1) { + // Depthwise or Group convolution is not supported yet. + return rewriter.notifyMatchFailure( + op, "group or depthwise convolution is not supported"); + } + + // + // strides + //===----- + + // TFL::TravsposeConv2D applies strides on LHS. strides == lhs_dilation + auto strides = data.InputDilations(); + auto tfl_h_stride = rewriter.getI32IntegerAttr(strides[0]); + auto tfl_w_stride = rewriter.getI32IntegerAttr(strides[1]); + + // + // padding + //===----- + + std::string padding; + SmallVector padding_array; + for (auto& padding : data.Padding()) { + padding_array.push_back(padding.Lo()); + padding_array.push_back(padding.Hi()); + } + + if (IsTransposeConvPaddingValid(op, /*num_spatial_dims*/ 2, strides, + padding_array)) { + padding = "VALID"; + } else if (IsTransposeConvPaddingSame(op, /*num_spatial_dims*/ 2, strides, + padding_array)) { + padding = "SAME"; + } else { + return rewriter.notifyMatchFailure(op, + "requires padding to be SAME or VALID"); + } + + // + // build tfl op + //===------- + + auto bias = BuildEmptyBias(rewriter, op->getLoc(), data); + auto tfl_faf_none = rewriter.getStringAttr("NONE"); + + // Need to reverse the kernel data inorder to run TFL::TransposeConv2d + // The axis along which to reverse. In this case, we want to mirror the + // kernel's spatial dimensions. + SmallVector kernel_spatial_dims_i32( + data.KernelLayout().Spatials().begin(), + data.KernelLayout().Spatials().end()); + Value axis = rewriter.create( + op.getLoc(), rewriter.getI32TensorAttr(kernel_spatial_dims_i32)); + + // Create the tfl::ReverseV2Op + auto filter = rewriter.create( + op.getLoc(), op.getRhs().getType(), op.getRhs(), axis); + + // Calculate the output size and shape for TFL::TransposeConv2dOp + SmallVector output_shape_i32(data.OutputShape().begin(), + data.OutputShape().end()); + + auto output_sizes = rewriter.create( + op.getLoc(), rewriter.getI32TensorAttr(output_shape_i32)); + + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), /*output_shape=*/output_sizes, + /*filter=*/filter, /*input=*/op.getLhs(), /*bias=*/bias, + /*padding=*/rewriter.getStringAttr(padding), + /*stride_h=*/tfl_h_stride, /*stride_w=*/tfl_w_stride, + /*fused_activation_function=*/tfl_faf_none); + + return success(); +} + +//===----------------------------------------------------------------------===// + +class SliceDepthwiseTransposedConvolution + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::ConvolutionOp op, + PatternRewriter& rewriter) const final; +}; + +// Pattern rewriter to match a depthwise transposed convolution and rewrite it +// to depth-times slices of input and filter to perform the transposed +// convolution on individual slices of tensors and concatenate the results of. +// the convolutions. This is a. workaround because the TFLite runtime doesn't +// support depthwise-transposed-conv op natively. +LogicalResult SliceDepthwiseTransposedConvolution::matchAndRewrite( + mhlo::ConvolutionOp conv_op, PatternRewriter& rewriter) const { + const ConvView data(conv_op); + + // + // Test if the op is a supported non-trivial convolution. + //===----- + if (!IsSupportedNonTrivialConv(data)) { + return rewriter.notifyMatchFailure(conv_op, + "Not a non-trivial convolution."); + } + + // These checks narrow down the support to depthwise transpose conv2d. + mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); + const int64_t input_feature_dimension = dnums.getInputFeatureDimension(); + const int64_t input_channels = + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); + const int64_t feature_group_count = conv_op.getFeatureGroupCount(); + const int64_t kernel_input_feature_dimension = + dnums.getKernelInputFeatureDimension(); + const int64_t kernel_input_channels = + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_input_feature_dimension); + const int64_t kernel_output_feature_dimension = + dnums.getKernelOutputFeatureDimension(); + const int64_t kernel_output_channels = + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_output_feature_dimension); + + // To support a depthwise convolution, we need- + // 1. feature_group_count != 1 (except when input_channels==1) + // 2. feature_group_count == input_channels + // 3. kernel_input_channels == 1 + // 4. kernel_output_channels % kernel_input_channels == 0 + if (feature_group_count == 1) { + return rewriter.notifyMatchFailure(conv_op, "Not a depthwise convolution"); + } + + if (input_channels != feature_group_count) { + return rewriter.notifyMatchFailure( + conv_op, "Not a detphwise transposed convolution"); + } + + if (MatchWithResizeBilinearOp(data)) { + return rewriter.notifyMatchFailure( + conv_op, "Op will be legalized to ResizeBilinearOp"); + } + + if ((kernel_output_channels % feature_group_count != 0) || + (kernel_input_channels != 1)) { + return rewriter.notifyMatchFailure( + conv_op, "Not a supported detphwise transposed convolution"); + } + + // This needs to be checked because the TFLite runtime generated incorrect + // results for depthwise transpose convolutions with non-1 channel + // multiplier. + if ((kernel_output_channels / feature_group_count) != 1) { + return rewriter.notifyMatchFailure( + conv_op, + "Unsupported detphwise transpose convolution with non-1 channel " + "multiplier"); + } + + // Slicing with dynamic offsets (helper method advised) + auto create_slice = [&](mlir::Value tensor, int64_t depth_idx, + int64_t channel_idx, + bool is_kernel = false) -> mlir::Value { + auto tensor_shape = + mlir::cast(tensor.getType()).getShape().vec(); + + // Calculate offsets based on depth_idx, channel_idx and tensor_shape + llvm::SmallVector start_indices(tensor_shape.size(), 0); + auto limit_indices = tensor_shape; + const llvm::SmallVector strides(tensor_shape.size(), 1); + start_indices[channel_idx] = depth_idx; + if (is_kernel) { + // kernel can have a channel_multiplier that needs to be accounted for + limit_indices[channel_idx] = + depth_idx + (kernel_output_channels / feature_group_count); + } else { + limit_indices[channel_idx] = depth_idx + 1; + } + return rewriter.create( + conv_op.getLoc(), tensor, rewriter.getI64TensorAttr(start_indices), + rewriter.getI64TensorAttr(limit_indices), + rewriter.getI64TensorAttr(strides)); + }; + + // Storage for smaller convolution results + llvm::SmallVector conv_results; + + // Iterative Slicing and Convolutions + for (int i = 0; i < feature_group_count; ++i) { + auto sliced_input = + create_slice(conv_op.getLhs(), i, input_feature_dimension); + auto sliced_kernel = create_slice(conv_op.getRhs(), i, + kernel_output_feature_dimension, true); + + // Calculate convolution output_type based on sliced_input and + // sliced_kernel + auto output_type = mlir::cast(conv_op->getResult(0).getType()); + auto new_output_shape = output_type.getShape().vec(); + new_output_shape[dnums.getOutputFeatureDimension()] /= feature_group_count; + auto new_output_type = + RankedTensorType::get(new_output_shape, output_type.getElementType()); + + // Create a Smaller Convolution (Ensure compatibility) + auto conv_result = rewriter.create( + conv_op.getLoc(), new_output_type, sliced_input, sliced_kernel, + conv_op.getWindowStridesAttr(), conv_op.getPaddingAttr(), + conv_op.getLhsDilationAttr(), conv_op.getRhsDilationAttr(), + conv_op.getWindowReversalAttr(), conv_op.getDimensionNumbers(), + /*feature_group_count*/ 1, /*batch_group_count*/ 1, + conv_op.getPrecisionConfigAttr()); + + conv_results.push_back(conv_result); + } + + auto final_output = rewriter.create( + conv_op.getLoc(), conv_results, + rewriter.getI64IntegerAttr(dnums.getOutputFeatureDimension())); + rewriter.replaceOp(conv_op, final_output.getResult()); + return success(); +} + +//===----------------------------------------------------------------------===// + +// Convert a 1-D convolution into a 2-D convolution (which TF supports) so that +// it can be rewritten by the pattern `Convert2DConvOp`. +class Conv1DToConv2D : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::ConvolutionOp op, + PatternRewriter& rewriter) const final; +}; + +std::tuple, Layout> InsertTrivialSpatialDim( + const Layout& layout, ArrayRef shape) { + // Make new Layout with extra spatial dimension. + const int64_t last_spatial = layout.Spatials()[layout.Rank() - 3]; + const int64_t new_dim1 = (layout.SpecialDim1() > last_spatial) + ? layout.SpecialDim1() + 1 + : layout.SpecialDim1(); + const int64_t new_dim2 = (layout.SpecialDim2() > last_spatial) + ? layout.SpecialDim2() + 1 + : layout.SpecialDim2(); + + llvm::SmallVector new_spatials(layout.Spatials()); + const int64_t new_last_spatial = new_spatials.back() + 1; + new_spatials.push_back(new_last_spatial); + + // Get new shape. + llvm::SmallVector new_shape(shape.size() + 1, 1); + new_shape[new_dim1] = layout.SpecialDim1(shape); + new_shape[new_dim2] = layout.SpecialDim2(shape); + for (auto new_spatial : new_spatials) { + if (new_spatial == new_last_spatial) { + continue; + } + new_shape[new_spatial] = shape[new_spatial]; + } + return std::tuple(new_shape, Layout(new_dim1, new_dim2, new_spatials)); +} + +LogicalResult Conv1DToConv2D::matchAndRewrite(mhlo::ConvolutionOp op, + PatternRewriter& rewriter) const { + const ConvView view(op); + + if (view.InputLayout().Rank() != 3) { + return rewriter.notifyMatchFailure(op, "Not 1D conv."); + } + + if (!IsInputDilationSupported(view)) { + return rewriter.notifyMatchFailure(op, "Expects trivial lhs dims."); + } + + if (!AreShapesSupported(view)) { + return rewriter.notifyMatchFailure(op, "Expects static dims."); + } + + if (!IsWindowReversalSupported(view)) { + return rewriter.notifyMatchFailure(op, "Expects window reversal trivial."); + } + + if (!view.InputLayout().AreSpatialsIota() || + !view.KernelLayout().AreSpatialsIota() || + !view.OutputLayout().AreSpatialsIota()) { + return rewriter.notifyMatchFailure(op, + "Expects well formed spatials dims."); + } + + // + // Transpose and reshape the input and kernel + //=----- + + // Add new trivial spatial dimension to input (LHS). + auto [lhs_new_shape, lhs_new_layout] = + InsertTrivialSpatialDim(view.InputLayout(), view.InputShape()); + auto lhs_new_type = op.getLhs().getType().clone(lhs_new_shape); + auto new_lhs = + rewriter.create(op.getLoc(), lhs_new_type, op.getLhs()); + + // Add new trivial spatial dimension to kernel. + auto [rhs_new_shape, rhs_new_layout] = + InsertTrivialSpatialDim(view.KernelLayout(), view.KernelShape()); + auto rhs_new_type = op.getRhs().getType().clone(rhs_new_shape); + auto new_rhs = + rewriter.create(op.getLoc(), rhs_new_type, op.getRhs()); + + // Add new trivial spatial dimension to output (insert reshape later). + auto [out_new_shape, out_new_layout] = + InsertTrivialSpatialDim(view.OutputLayout(), view.OutputShape()); + auto out_new_type = op.getResult().getType().clone(out_new_shape); + + // + // Create 2d equivalents for 1d convolution attributes. + //=----- + + // Window Strides + llvm::SmallVector strides_2d; + strides_2d.push_back(view.Strides()[0]); + strides_2d.push_back(1); + auto strides_2d_attr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI64Type()), strides_2d); + + // Padding + SmallVector padding_2d; + const auto& dim_pad = view.Padding()[0]; + padding_2d.push_back(dim_pad.Lo()); + padding_2d.push_back(dim_pad.Hi()); + padding_2d.push_back(0); + padding_2d.push_back(0); + auto padding_2d_attr = DenseIntElementsAttr::get( + RankedTensorType::get({2, 2}, rewriter.getI64Type()), padding_2d); + + // LHS dilation + SmallVector lhs_dilation_2d(2, 1); + auto lhs_dilation_2d_attr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI64Type()), lhs_dilation_2d); + + // RHS dilation + SmallVector rhs_dilation_2d; + rhs_dilation_2d.push_back(view.KernelDilations()[0]); + rhs_dilation_2d.push_back(1); + auto rhs_dilation_2d_attr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getI64Type()), rhs_dilation_2d); + + auto window_reversal_2d_attr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, rewriter.getIntegerType(1)), + SmallVector({false, false})); + + // New dnums. + auto dnums_2d = mhlo::ConvDimensionNumbersAttr::get( + rewriter.getContext(), lhs_new_layout.SpecialDim1(), + lhs_new_layout.SpecialDim2(), lhs_new_layout.Spatials(), + rhs_new_layout.SpecialDim1(), rhs_new_layout.SpecialDim2(), + rhs_new_layout.Spatials(), out_new_layout.SpecialDim1(), + out_new_layout.SpecialDim2(), out_new_layout.Spatials()); + + // + // Build 2-D convolution with reshaped output. + //=----- + + auto conv2d_op = rewriter.create( + op.getLoc(), out_new_type, new_lhs, new_rhs, strides_2d_attr, + padding_2d_attr, lhs_dilation_2d_attr, rhs_dilation_2d_attr, + window_reversal_2d_attr, dnums_2d, op.getFeatureGroupCount(), + op.getBatchGroupCount(), op.getPrecisionConfigAttr()); + + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + conv2d_op.getResult()); + return success(); +} + } // namespace -void PopulateConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns, - ConversionTarget& target) { - patterns.add(ctx); +void PopulateLegalizeConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add(ctx); target.addDynamicallyLegalOp(IsConvLegal); } + +void PopulatePrepareConvPatterns(MLIRContext* ctx, + RewritePatternSet& patterns) { + patterns.add(ctx); +} } // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h index 6b5b4f591eea04..0f741d9ce0b70b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h @@ -17,24 +17,19 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir::odml { -// Legalizes mhlo.convolutions to the corresponding tfl op. -// -// Only considers convolutions with tfl-native layout and trivial (no) -// padding. It is expected that convolutions will re-layouted in upstream -// prepare pass. Additionally it is expected that padding will be pulled out -// into an explicit mhlo.pad op in said prepare pass. +// Prepares mhlo.convolutions and legalizes to the corresponding tfl op. // // Note: "tfl-native" layouts are as follows: -// 2D : [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] -// 3D : [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] +// 2D : [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// 3D : [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] +// 2D (depthwise) : [b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f] // // Matches: mhlo.convolution -// layout: tfl-native -// padding: trivial (all 0) +// layout: any (will transpose to tfl-native) +// padding: any (will pull into explicit pad_op) // lhs_dilations: trivial (all 1) // rhs_dilations: any // strides: any @@ -54,8 +49,10 @@ namespace mlir::odml { // if rank == 4: tfl.conv_2D // else: // tfl.transpose_conv TODO: b/352954597 - Add support. -void PopulateConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns, - ConversionTarget& target); +void PopulateLegalizeConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +void PopulatePrepareConvPatterns(MLIRContext* ctx, RewritePatternSet& patterns); } // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc index b8fbdc2a3cac66..70c0ab5acc5f1e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.cc @@ -14,93 +14,25 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h" +#include #include +#include +#include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir::odml { -llvm::SmallVector Layout::GetPermForReLayout( - const Layout& to_layout) const { - llvm::SmallVector perm(to_layout.Rank()); - perm[to_layout.SpecialDim1()] = SpecialDim1(); - perm[to_layout.SpecialDim2()] = SpecialDim2(); - for (const auto [to_spatial, from_spatial] : - llvm::zip(to_layout.Spatials(), Spatials())) { - perm[to_spatial] = from_spatial; - } - return perm; -} - -llvm::SmallVector Layout::PermuteShape( - const Layout& to_layout, llvm::ArrayRef shape) const { - llvm::SmallVector new_shape(to_layout.Rank()); - const auto perm = GetPermForReLayout(to_layout); - for (const auto [ind, val] : llvm::enumerate(perm)) { - new_shape[ind] = shape[val]; - } - return new_shape; -} - -bool Layout::HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const { - return SpecialDim1() == special_dim1 && SpecialDim2() == special_dim2; -} - -bool Layout::AreSpatialsIota() const { - llvm::ArrayRef spatials = Spatials(); - return llvm::all_of(llvm::enumerate(spatials), [&](const auto& it) { - return it.index() == 0 || (it.value() == spatials[it.index() - 1] + 1); - }); -} - -llvm::SmallVector ResolveStridesOrDilations( - const int64_t num_spatials, - std::optional opt_attr) { - if (!opt_attr.has_value()) { - return llvm::SmallVector(num_spatials, 1); - } - auto attr = opt_attr.value(); - if (attr.isSplat()) { - return llvm::SmallVector(num_spatials, - attr.getSplatValue()); - } - return llvm::SmallVector(attr.getValues()); -} - -llvm::SmallVector ResolvePadding( - const int64_t num_spatials, - std::optional opt_padding) { - llvm::SmallVector res; - if (!opt_padding.has_value()) { - for (int i = 0; i < num_spatials; ++i) { - res.push_back(DimPadding(0, 0)); - } - return res; - } - auto padding = opt_padding.value(); - if (padding.isSplat()) { - const int64_t val = padding.getSplatValue(); - for (int i = 0; i < num_spatials; ++i) { - res.push_back(DimPadding(val, val)); - } - return res; - } - int64_t prev; - for (const auto [ind, val] : llvm::enumerate(padding.getValues())) { - const int64_t side = ind % 2; - if (side == 1) { - res.push_back(DimPadding(prev, val)); - } - prev = val; - } - return res; -} - llvm::SmallVector ResolveWindowReversal( const int64_t num_spatials, std::optional opt_reversals) { @@ -115,7 +47,7 @@ llvm::SmallVector ResolveWindowReversal( return llvm::SmallVector(reversals.getValues()); } -ConvData::ConvData(mhlo::ConvolutionOp op) +ConvView::ConvView(mhlo::ConvolutionOp op) : input_layout_( Layout{op.getDimensionNumbers().getInputBatchDimension(), op.getDimensionNumbers().getInputFeatureDimension(), @@ -152,4 +84,187 @@ ConvData::ConvData(mhlo::ConvolutionOp op) ResolveWindowReversal(num_spatials, op.getWindowReversal()); } +Value CreatePadOpFromConvPadding(OpBuilder& b, mhlo::ConvolutionOp op) { + const ConvView data(op); + const auto rank = data.InputLayout().Rank(); + auto input_spatials = data.InputLayout().Spatials(); + + llvm::SmallVector hi_padding(rank, 0); + llvm::SmallVector lo_padding(rank, 0); + + for (const auto& [ind, dim_padding] : llvm::enumerate(data.Padding())) { + const size_t cur_input_spatial = input_spatials[ind]; + hi_padding[cur_input_spatial] = dim_padding.Hi(); + lo_padding[cur_input_spatial] = dim_padding.Lo(); + } + + const llvm::SmallVector interior_padding(rank, 0); + + auto padding_attr_type = RankedTensorType::get({rank}, b.getI64Type()); + auto hi_padding_attr = + DenseIntElementsAttr::get(padding_attr_type, hi_padding); + auto lo_padding_attr = + DenseIntElementsAttr::get(padding_attr_type, lo_padding); + auto interior_padding_attr = + DenseIntElementsAttr::get(padding_attr_type, interior_padding); + + auto padding_value_type = RankedTensorType::get({}, data.ElementType()); + auto padding_value_attr = b.getZeroAttr(padding_value_type); + auto padding_value_op = + b.create(op->getLoc(), padding_value_attr); + + auto pad_op = b.create(padding_value_op->getLoc(), op.getLhs(), + padding_value_op, lo_padding_attr, + hi_padding_attr, interior_padding_attr); + + return pad_op; +} + +bool MatchWithResizeBilinearOp(const ConvView& data, bool& align_corners) { + if (data.InputLayout().Rank() != 4 || data.KernelLayout().Rank() != 4 || + data.OutputLayout().Rank() != 4 || + data.InputLayout().Spatials() != data.OutputLayout().Spatials()) { + return false; + } + + if (data.InputDilations().size() != 2 || + !(llvm::all_of(data.KernelDilations(), [](auto d) { return d == 1; })) || + data.Strides().size() != 2 || data.Padding().size() != 2) { + return false; + } + + // This is based on method in compiler/tf2xla/kernels/image_resize_ops.cc + auto can_convert_to_bilinear = + [](bool align_corners, int64_t dilation, int64_t padding, int64_t stride, + int64_t input_spatial, int64_t output_spatial) { + int64_t input_spatial_size = + align_corners ? input_spatial - 1 : input_spatial; + int64_t output_spatial_size = + align_corners ? output_spatial - 1 : output_spatial; + + int64_t gcd = std::gcd(static_cast(input_spatial_size), + static_cast(output_spatial_size)); + + if ((gcd == 0) || (input_spatial_size % gcd != 0) || + (input_spatial_size / gcd != stride) || (dilation - 1 != padding)) { + return false; + } + return true; + }; + + if (data.InputDilations()[0] != 1 && data.InputDilations()[1] == 1) { + if (can_convert_to_bilinear( + /*align_corners=*/true, data.InputDilations()[0], + data.Padding()[0].Lo(), data.Strides()[0], + data.InputShape()[data.InputLayout().Spatials()[0]], + data.OutputShape()[data.OutputLayout().Spatials()[0]])) { + align_corners = true; + return true; + } else if (can_convert_to_bilinear( + /*align_corners=*/false, data.InputDilations()[0], + data.Padding()[0].Lo(), data.Strides()[0], + data.InputShape()[data.InputLayout().Spatials()[0]], + data.OutputShape()[data.OutputLayout().Spatials()[0]])) { + align_corners = false; + return true; + }; + } else if (data.InputDilations()[0] == 1 && data.InputDilations()[1] != 1) { + if (can_convert_to_bilinear( + /*align_corners=*/true, data.InputDilations()[1], + data.Padding()[1].Lo(), data.Strides()[1], + data.InputShape()[data.InputLayout().Spatials()[1]], + data.OutputShape()[data.OutputLayout().Spatials()[1]])) { + align_corners = true; + return true; + } else if (can_convert_to_bilinear( + /*align_corners=*/false, data.InputDilations()[1], + data.Padding()[1].Lo(), data.Strides()[1], + data.InputShape()[data.InputLayout().Spatials()[1]], + data.OutputShape()[data.OutputLayout().Spatials()[1]])) { + align_corners = false; + return true; + }; + } + + return false; +} + +bool IsTransposeConvPaddingValid(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding) { + auto dnums = conv_op.getDimensionNumbers(); + // The newly added spatial dimension requires zero left and right padding. + ArrayRef input_spatial_dims = dnums.getInputSpatialDimensions(); + ArrayRef kernel_spatial_dims = dnums.getKernelSpatialDimensions(); + ArrayRef output_spatial_dims = dnums.getOutputSpatialDimensions(); + + for (size_t i = 0; i < num_spatial_dims; ++i) { + int64_t stride = strides[i]; + int64_t input_size = mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dims[i]); + int64_t kernel_size = mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_spatial_dims[i]); + int64_t output_size = mlir::cast(conv_op.getType()) + .getDimSize(output_spatial_dims[i]); + + // stablehlo.convolution op needs explicit padding to be set to model any + // Transposed-Convolution in JAX/PT. Checking to see if- + // 1. Pre set padding matches to the desired padding + // 2. Output size respects the `VALID` padding scenario + if ((padding[2 * i] == padding[2 * i + 1]) && + (((kernel_size - 1) != padding[2 * i]) || + (output_size != (stride * (input_size - 1)) + kernel_size))) { + // padding[2 * i] == padding[2 * i + 1] means equal padding is applied + // on both sides of a spatial dimension. + // This happens when kernel_dim >= stride + return false; + } else if ((padding[2 * i] != padding[2 * i + 1]) && + (((kernel_size - 1) != padding[2 * i]) || + ((stride - 1) != padding[2 * i + 1]) || + (output_size != (stride * input_size)))) { + return false; + } + } + + return true; +} + +bool IsTransposeConvPaddingSame(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding) { + auto dnums = conv_op.getDimensionNumbers(); + + // The newly added spatial dimension requires zero left and right padding. + ArrayRef input_spatial_dims = dnums.getInputSpatialDimensions(); + ArrayRef output_spatial_dims = dnums.getOutputSpatialDimensions(); + for (size_t i = 0; i < num_spatial_dims; ++i) { + // In some cases the total padding is odd, so we have 1 leftover, which is + // why below we check pad_delta > 1. + int64_t pad_delta = std::abs(padding[2 * i] - padding[2 * i + 1]); + if (pad_delta > 1) { + return false; + } + int64_t stride = strides[i]; + int64_t input_size = mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dims[i]); + int64_t output_size = mlir::cast(conv_op.getType()) + .getDimSize(output_spatial_dims[i]); + // The reason for the below check is as follows: + // When computing the output, we have the following relation between + // o - output dim size, i - input dim size, s - stride, P - total pads + // o = (i-k+1) + (s-1)(i-1) + P + // Where the first term is the kernel applications on the input, + // the second term is the additional applications from the stride + // and P is a term that captures the total padding. After expanding we get + // o = si + k - s + 2 + P + // Here JAX sets P to cancel k-s+2, leading to the expression below + if (output_size != input_size * stride) { + return false; + } + } + return true; +} + } // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h index e0f7a00f731653..ed8b06e036d816 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h @@ -15,10 +15,17 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_UTIL_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_UTIL_H_ +#include +#include +#include + #include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // Helpers for working with mhlo.convolution attrs in the mlir api as @@ -26,91 +33,7 @@ limitations under the License. namespace mlir::odml { -// Generic class that wraps the "layout" of a convolution parameter. -// Both kernel (e.g. [o, 0, 1, i]) and input/output (e.g. [b, 0, 1, f]) -// share the same structure just with different terminology for the -// batch/feature/input_feature/output_feature dims. -class Layout { - public: - llvm::ArrayRef Spatials() const { return spatials_; } - - int64_t NumSpatials() const { return spatials_.size(); } - - int64_t Rank() const { return NumSpatials() + 2; } - - Layout(int64_t special_dim1, int64_t special_dim2, ArrayRef spatials) - : special_dim1_(special_dim1), - special_dim2_(special_dim2), - spatials_(spatials) {} - - // Gets index of first special dim. The batch dim for input and outputs, - // or the output feature dim for the kernel. - int64_t SpecialDim1() const { return special_dim1_; } - - // Conveniance accesor for getting the dimension size of the first - // special dimension from a shape. - int64_t SpecialDim1(llvm::ArrayRef shape) const { - return shape[special_dim1_]; - } - - // Gets index of second special dim. The feature dim for input and outputs, - // or the input feature dim for the kernel. - int64_t SpecialDim2() const { return special_dim2_; } - - // Convenience accesor for getting the dimension size of the second - // special dimension from a shape. - int64_t SpecialDim2(llvm::ArrayRef shape) const { - return shape[special_dim2_]; - } - - // Conveniance method for equality checking special dims. - bool HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const; - - // Determines if the spatial dimensions are all adjacent and in - // ascending order (HWD). - bool AreSpatialsIota() const; - - // Gets a "permutation array" to be used for transposing a tensor - // of "this" layout to the given layout. A permutation array is some - // permutation of [0, 1, i...] for i < rank(layout). Assumes - // "this" and given layout have the same rank. - llvm::SmallVector GetPermForReLayout( - const Layout& to_layout) const; - - // Permutes given shape based on the permutaion implied to take this Layout to - // the given one. - llvm::SmallVector PermuteShape(const Layout& to_layout, - ArrayRef shape) const; - - bool operator==(const Layout& other) const { - return SpecialDim1() == other.SpecialDim1() && - SpecialDim2() == other.SpecialDim2() && - Spatials() == other.Spatials(); - } - - bool operator!=(const Layout& other) const { return !(*this == other); } - - private: - int64_t special_dim1_; - int64_t special_dim2_; - llvm::SmallVector spatials_; -}; - -// Wrapper for the padding attrs along a single dimension. -class DimPadding { - public: - int64_t Hi() const { return hi_; } - - int64_t Lo() const { return lo_; } - - DimPadding(int64_t hi, int64_t lo) : hi_(hi), lo_(lo) {} - - private: - int64_t hi_; - int64_t lo_; -}; - -class ConvData { +class ConvView { public: // int for each spatial dim. Default 1. llvm::ArrayRef Strides() const { return strides_; } @@ -145,7 +68,7 @@ class ConvData { mlir::Type ElementType() const { return element_type_; } - explicit ConvData(mhlo::ConvolutionOp op); + explicit ConvView(mhlo::ConvolutionOp op); private: llvm::SmallVector strides_; @@ -171,7 +94,11 @@ class ConvData { mlir::Type element_type_; }; -inline bool ValidStandardConvOutFeatureDims(const ConvData& data) { +inline bool HasSupportedRank(const ConvView& data) { + return data.InputLayout().Rank() == 4 || data.InputLayout().Rank() == 5; +} + +inline bool HasSupportedOutFeatureDims(const ConvView& data) { const int64_t kernel_out_features = data.KernelLayout().SpecialDim2(data.KernelShape()); const int64_t out_features = @@ -179,7 +106,53 @@ inline bool ValidStandardConvOutFeatureDims(const ConvData& data) { return kernel_out_features == out_features; } -inline bool ValidStandardConvInFeatureDims(const ConvData& data) { +inline bool IsTrivialConv(const ConvView& data) { + return llvm::all_of(data.InputDilations(), [](auto d) { return d == 1; }); +} + +// +// Supported non-trivial conv predicates +//=----- + +bool MatchWithResizeBilinearOp(const ConvView& data, bool& align_corners); + +inline bool MatchWithResizeBilinearOp(const ConvView& data) { + bool align_corners = false; + return MatchWithResizeBilinearOp(data, align_corners); +} + +bool IsTransposeConvPaddingValid(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding); + +bool IsTransposeConvPaddingSame(mhlo::ConvolutionOp conv_op, + size_t num_spatial_dims, + const ArrayRef& strides, + const ArrayRef& padding); + +inline bool IsSupportedNonTrivialConv(const ConvView& data) { + // Only non-trivial 2d convolutions are supported. + const bool valid_rank = data.InputLayout().Rank() == 4; + + // Negative padding is unsupported. + bool has_nagative_padding = llvm::all_of( + data.Padding(), + [](const DimPadding& p) { return p.Hi() < 0 || p.Lo() < 0; }); + + return (valid_rank && !IsTrivialConv(data) && !has_nagative_padding); +} + +inline bool IsSupportedNonTrivialConv(mhlo::ConvolutionOp op) { + const ConvView data(op); + return IsSupportedNonTrivialConv(data); +} + +// +// Standard conv predicates +//=----- + +inline bool HasStandardConvInFeatureDims(const ConvView& data) { // kernel_in_features * feature_groups = input_features by definition. const int64_t input_features = data.InputLayout().SpecialDim2(data.InputShape()); @@ -192,25 +165,43 @@ inline bool ValidStandardConvInFeatureDims(const ConvData& data) { return !trivial_kernel_in_features && (!is_grouped_conv || rank == 4); } -inline bool HasStandardFeatureGroup(const ConvData& data) { - return ValidStandardConvInFeatureDims(data) && - ValidStandardConvOutFeatureDims(data); +inline bool IsStandardConv(const ConvView& data) { + return HasSupportedRank(data) && IsTrivialConv(data) && + HasStandardConvInFeatureDims(data) && HasSupportedOutFeatureDims(data); } // Does this convolution map to a standard conv_2d or conv_3d -// (not depthwise or tranpose conv). -inline bool IsStandardConv(const ConvData& data) { - const bool trivial_lhs_dilate = - llvm::all_of(data.InputDilations(), [](auto d) { return d == 1; }); +// (not depthwise or tranpose conv)? +inline bool IsStandardConv(mhlo::ConvolutionOp op) { + const ConvView data(op); + return IsStandardConv(data); +} - return trivial_lhs_dilate && HasStandardFeatureGroup(data); +// +// Depthwise conv predicates +//=----- + +inline bool IsDepthwiseConv(const ConvView& data) { + const bool valid_rank = data.InputLayout().Rank() == 4; + if (!valid_rank || !HasSupportedOutFeatureDims(data) || + !IsTrivialConv(data)) { + return false; + } + const int64_t in_channel_dim = + data.InputLayout().SpecialDim2(data.InputShape()); + return data.FeatureGroupCount() == in_channel_dim; } -inline bool IsStandardConv(mhlo::ConvolutionOp op) { - const ConvData data(op); - return IsStandardConv(data); +// Does this convolution map to depthwise conv? +inline bool IsDepthwiseConv(mhlo::ConvolutionOp op) { + const ConvView data(op); + return IsDepthwiseConv(data); } +// +// Tfl native layouts +//=----- + inline int64_t DnumRank(mhlo::ConvDimensionNumbersAttr dnums) { return dnums.getInputSpatialDimensions().size() + 2; } @@ -225,7 +216,7 @@ inline Layout GetTFLNativeInputOrOutputLayout( return GetTFLNativeInputOrOutputLayout((DnumRank(dnums))); } -inline Layout GetTFLNativeKernelLayout(int64_t rank) { +inline Layout GetTFLNativeStandardConvKernelLayout(int64_t rank) { if (rank != 5) { auto spatials = llvm::to_vector(llvm::seq(1, rank - 1)); return Layout(rank - 1, 0, spatials); @@ -234,19 +225,38 @@ inline Layout GetTFLNativeKernelLayout(int64_t rank) { return Layout(rank - 2, rank - 1, spatials); } -inline Layout GetTFLNativeKernelLayout(mhlo::ConvDimensionNumbersAttr dnums) { - return GetTFLNativeKernelLayout(DnumRank(dnums)); +inline Layout GetTFLNativeDepthwiseConvKernelLayout() { + return Layout(0, 3, {1, 2}); } -inline bool IsTFLNativeLayout(const ConvData& data) { - const auto rank = data.InputLayout().Rank(); +inline Layout GetTFLNativeStandardConvKernelLayout( + mhlo::ConvDimensionNumbersAttr dnums) { + return GetTFLNativeStandardConvKernelLayout(DnumRank(dnums)); +} + +inline bool IsTFLNativeLayout(const ConvView& data) { + const int64_t rank = data.KernelLayout().Rank(); const auto native_io_layout = GetTFLNativeInputOrOutputLayout(rank); - const auto native_kernel_layout = GetTFLNativeKernelLayout(rank); + + std::optional native_kernel_layout = std::nullopt; + if (IsDepthwiseConv(data)) { + native_kernel_layout = GetTFLNativeDepthwiseConvKernelLayout(); + } else if (IsStandardConv(data) || IsSupportedNonTrivialConv(data)) { + native_kernel_layout = GetTFLNativeStandardConvKernelLayout(rank); + } + if (!native_kernel_layout.has_value()) { + return false; + } + return data.InputLayout() == native_io_layout && - data.KernelLayout() == native_kernel_layout && + data.KernelLayout() == *native_kernel_layout && data.OutputLayout() == native_io_layout; } +// +// ConvDimensionNumbers utils +//=----- + inline mhlo::ConvDimensionNumbersAttr CloneDnumsWithInputLayout( OpBuilder& b, mhlo::ConvDimensionNumbersAttr dnums, const Layout& layout) { return mhlo::ConvDimensionNumbersAttr::get( @@ -278,6 +288,10 @@ inline mhlo::ConvDimensionNumbersAttr CloneDnumsWithOutputLayout( layout.SpecialDim2(), layout.Spatials()); } +// Wraps the lhs of given conv op in an explicit pad op matching the same +// behavior implicit in the paddings attribute. Gets result of new pad op. +Value CreatePadOpFromConvPadding(OpBuilder& b, mhlo::ConvolutionOp op); + } // namespace mlir::odml #endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_CONV_UTIL_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.cc new file mode 100644 index 00000000000000..74aaa81519ea88 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h" + +#include +#include + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { +namespace { + +class LegalizeIota : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::IotaOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +std::tuple +BuildRangeParams(Type e_type, int64_t iota_dim_size, OpBuilder& b) { + if (e_type.isInteger()) { + return std::tuple(BuildScalarDense(e_type, 0), + BuildScalarDense(e_type, iota_dim_size), + BuildScalarDense(e_type, 1)); + } + return std::tuple(BuildScalarDense(e_type, 0.0), + BuildScalarDense(e_type, iota_dim_size), + BuildScalarDense(e_type, 1.0)); +} + +LogicalResult LegalizeIota::matchAndRewrite( + mhlo::IotaOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + auto type = llvm::cast(op.getType()); + if (!type.getElementType().isIntOrFloat()) { + return rewriter.notifyMatchFailure(op, "Must be int or float"); + } + + auto e_type = type.getElementType(); + const int64_t iota_dim_size = type.getDimSize(op.getIotaDimension()); + + auto [start, limit, delta] = + BuildRangeParams(e_type, iota_dim_size, rewriter); + + auto start_op = rewriter.create(op->getLoc(), start); + auto limit_op = rewriter.create(op->getLoc(), limit); + auto delta_op = rewriter.create(op->getLoc(), delta); + + auto range_type = RankedTensorType::get({iota_dim_size}, e_type); + auto range_op = rewriter.create(op->getLoc(), range_type, + start_op, limit_op, delta_op); + + if (type.getRank() == 1) { + rewriter.replaceOp(op, range_op); + return success(); + } + + // mhlo.iota allows filling ND tensors iota-style. Reshape and broadcast + // tfl 1D range output. + + llvm::SmallVector reshape_shape(type.getRank(), 1); + reshape_shape[op.getIotaDimension()] = iota_dim_size; + Value reshape_shape_cst = rewriter.create( + op->getLoc(), rewriter.getI64TensorAttr(reshape_shape)); + reshape_shape_cst = rewriter.create( + op->getLoc(), + llvm::cast(reshape_shape_cst.getType()) + .clone(rewriter.getI32Type()), + reshape_shape_cst); + + auto reshape_type = RankedTensorType::get(reshape_shape, e_type); + auto reshape_op = rewriter.create( + op->getLoc(), reshape_type, range_op, reshape_shape_cst); + + auto broad_cast_shape_cst = rewriter.create( + op->getLoc(), rewriter.getI64TensorAttr(type.getShape())); + + rewriter.replaceOpWithNewOp(op, type, reshape_op, + broad_cast_shape_cst); + + return success(); +} + +} // namespace + +void PopulateIotaPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add(ctx); + target.addIllegalOp(); +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h new file mode 100644 index 00000000000000..a53bdeda2a2097 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ + +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +void PopulateIotaPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_IOTA_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc new file mode 100644 index 00000000000000..3d67bbfd123b33 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.cc @@ -0,0 +1,111 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::odml { + +llvm::SmallVector Layout::GetPermForReLayout( + const Layout& to_layout) const { + llvm::SmallVector perm(to_layout.Rank()); + perm[to_layout.SpecialDim1()] = SpecialDim1(); + perm[to_layout.SpecialDim2()] = SpecialDim2(); + for (const auto [to_spatial, from_spatial] : + llvm::zip(to_layout.Spatials(), Spatials())) { + perm[to_spatial] = from_spatial; + } + return perm; +} + +llvm::SmallVector Layout::PermuteShape( + const Layout& to_layout, llvm::ArrayRef shape) const { + llvm::SmallVector new_shape(to_layout.Rank()); + const auto perm = GetPermForReLayout(to_layout); + for (const auto [ind, val] : llvm::enumerate(perm)) { + new_shape[ind] = shape[val]; + } + return new_shape; +} + +bool Layout::HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const { + return SpecialDim1() == special_dim1 && SpecialDim2() == special_dim2; +} + +bool Layout::AreSpatialsIota() const { + llvm::ArrayRef spatials = Spatials(); + return llvm::all_of(llvm::enumerate(spatials), [&](const auto& it) { + return it.index() == 0 || (it.value() == spatials[it.index() - 1] + 1); + }); +} + +llvm::SmallVector ResolveStridesOrDilations( + int64_t rank, std::optional opt_attr) { + if (!opt_attr.has_value()) { + return llvm::SmallVector(rank, 1); + } + auto attr = opt_attr.value(); + if (attr.isSplat()) { + return llvm::SmallVector(rank, attr.getSplatValue()); + } + return llvm::SmallVector(attr.getValues()); +} + +llvm::SmallVector ResolvePadding( + int64_t rank, std::optional opt_padding) { + llvm::SmallVector res; + if (!opt_padding.has_value()) { + for (int i = 0; i < rank; ++i) { + res.push_back(DimPadding(0, 0)); + } + return res; + } + auto padding = opt_padding.value(); + if (padding.isSplat()) { + const int64_t val = padding.getSplatValue(); + for (int i = 0; i < rank; ++i) { + res.push_back(DimPadding(val, val)); + } + return res; + } + int64_t prev; + for (const auto [ind, val] : llvm::enumerate(padding.getValues())) { + const int64_t side = ind % 2; + if (side == 1) { + res.push_back(DimPadding(prev, val)); + } + prev = val; + } + return res; +} + +bool IsSamePaddingOnDim(int64_t in, int64_t dilate, int64_t stride, int64_t k, + const DimPadding& pad) { + const int64_t pad_diff = pad.Hi() - pad.Lo(); + if (pad_diff > 1 || pad_diff < 0) { + return false; + } + const int64_t pad_total = pad.Lo() + pad.Hi(); + const int64_t out = (in + stride - 1) / stride; + const int effective_filter = (k - 1) * dilate + 1; + return ((out - 1) * stride + effective_filter) == in + pad_total; +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h new file mode 100644 index 00000000000000..3c2c8ae5ced600 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h @@ -0,0 +1,146 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir::odml { + +// Class that encodes the "layout" of a tensor. Layouts, generically +// are some naming of the dimensions of a tensor. In all cases, 2 dimensions +// are "special" (e.g. batch / feature) and the rest are referred to as "spatial +// dims". When the special dims are batch and feature, batch is special dim 1 +// and feature is special dim 2. When special dims are input and output features +// (conv filter), input features is special dim 1 and output features is special +// dim 2. +class Layout { + public: + llvm::ArrayRef Spatials() const { return spatials_; } + + int64_t NumSpatials() const { return spatials_.size(); } + + int64_t Rank() const { return NumSpatials() + 2; } + + Layout(int64_t special_dim1, int64_t special_dim2, ArrayRef spatials) + : special_dim1_(special_dim1), + special_dim2_(special_dim2), + spatials_(spatials) {} + + // TODO: b/351437662 - Consider just using 2 arrays for the case where + // there are more than 2 special dims. + int64_t SpecialDim1() const { return special_dim1_; } + + // Conveniance accesor for getting the dimension size of the first + // special dimension from a shape. + int64_t SpecialDim1(llvm::ArrayRef shape) const { + return shape[special_dim1_]; + } + + int64_t SpecialDim2() const { return special_dim2_; } + + // Convenience accesor for getting the dimension size of the second + // special dimension from a shape. + int64_t SpecialDim2(llvm::ArrayRef shape) const { + return shape[special_dim2_]; + } + + // Conveniance method for equality checking special dims. + bool HasSpecialDims(int64_t special_dim1, int64_t special_dim2) const; + + // Determines if the spatial dimensions are all adjacent and in + // ascending order. + bool AreSpatialsIota() const; + + // Gets a "permutation array" to be used for transposing a tensor + // of "this" layout to the given layout. A permutation array is some + // permutation of [0, 1, i...] for i < rank(layout). Assumes + // "this" and given layout have the same rank. + llvm::SmallVector GetPermForReLayout( + const Layout& to_layout) const; + + // Permutes given shape based on the permutaion implied to take this Layout to + // the given one. + llvm::SmallVector PermuteShape(const Layout& to_layout, + ArrayRef shape) const; + + bool operator==(const Layout& other) const { + return SpecialDim1() == other.SpecialDim1() && + SpecialDim2() == other.SpecialDim2() && + Spatials() == other.Spatials(); + } + + bool operator!=(const Layout& other) const { return !(*this == other); } + + private: + int64_t special_dim1_; + int64_t special_dim2_; + llvm::SmallVector spatials_; +}; + +// Wrapper for the padding attrs along a single dimension. +class DimPadding { + public: + int64_t Hi() const { return hi_; } + + int64_t Lo() const { return lo_; } + + bool Trivial() const { return Hi() == 0 && Lo() == 0; } + + DimPadding(int64_t lo, int64_t hi) : lo_(lo), hi_(hi) {} + + private: + int64_t lo_; + int64_t hi_; +}; + +inline llvm::SmallVector UnrollI64Splat(DenseElementsAttr data) { + if (!data.isSplat()) { + return llvm::SmallVector(data.getValues()); + } + return llvm::SmallVector(data.getType().getNumElements(), + data.getSplatValue()); +} + +// Resolves optional strides or dilations attributes. If not present, +// will return trivial 1's vector. +llvm::SmallVector ResolveStridesOrDilations( + int64_t rank, std::optional opt_attr); + +// Resolves optional paddings attributes. If not present, will return +// trivial [0, 0] paddings on each dim. +llvm::SmallVector ResolvePadding( + int64_t rank, std::optional opt_padding); + +// Does the padding correspond to "SAME" on given dimension configuration. +// Assumes given dimension configuration is well formed. +bool IsSamePaddingOnDim(int64_t in, int64_t dilate, int64_t stride, int64_t k, + const DimPadding& pad); + +template +inline DenseElementsAttr BuildScalarDense(Type e_type, T val) { + auto type = RankedTensorType::get({}, e_type); + return DenseElementsAttr::get(type, val); +} + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_OP_UTIL_COMMON_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc index ac5e3ffa70d240..c25f27acebe9de 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -49,8 +50,8 @@ bool IsPadValCstZero(mhlo::PadOp op) { } DenseIntElementsAttr BuildTFLPaddingAttr(OpBuilder& b, mhlo::PadOp op) { - auto lows = UnrollSplat(op.getEdgePaddingLow()); - auto highs = UnrollSplat(op.getEdgePaddingHigh()); + auto lows = UnrollI64Splat(op.getEdgePaddingLow()); + auto highs = UnrollI64Splat(op.getEdgePaddingHigh()); llvm::SmallVector res; for (auto [l, h] : llvm::zip(lows, highs)) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc index 859cdfe90eb46f..cb004d3b44daf9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir::odml { @@ -28,16 +29,8 @@ ShapedType GetPaddingAttrType(mhlo::PadOp op) { return op.getEdgePaddingLow().getType(); } -llvm::SmallVector UnrollSplat(DenseElementsAttr data) { - if (!data.isSplat()) { - return llvm::SmallVector(data.getValues()); - } - return llvm::SmallVector(data.getType().getNumElements(), - data.getSplatValue()); -} - DenseIntElementsAttr SliceStartFromNegPadLows(mhlo::PadOp op) { - auto vals = UnrollSplat(op.getEdgePaddingLow()); + auto vals = UnrollI64Splat(op.getEdgePaddingLow()); auto starts = llvm::map_range( vals, [](auto v) -> int64_t { return (v >= 0) ? 0 : -1 * v; }); return DenseIntElementsAttr::get(GetPaddingAttrType(op), @@ -45,7 +38,7 @@ DenseIntElementsAttr SliceStartFromNegPadLows(mhlo::PadOp op) { } DenseIntElementsAttr SliceEndFromNegPadHighs(mhlo::PadOp op) { - auto vals = UnrollSplat(op.getEdgePaddingHigh()); + auto vals = UnrollI64Splat(op.getEdgePaddingHigh()); auto zip = llvm::zip(vals, op.getOperand().getType().getShape()); auto ends = llvm::map_range(zip, [](auto it) -> int64_t { return (std::get<0>(it) >= 0) ? std::get<1>(it) @@ -56,7 +49,7 @@ DenseIntElementsAttr SliceEndFromNegPadHighs(mhlo::PadOp op) { } DenseIntElementsAttr ReplaceNegsWithZero(DenseElementsAttr data) { - auto vals = UnrollSplat(data); + auto vals = UnrollI64Splat(data); auto res = llvm::map_range(vals, [](auto v) -> int64_t { return (v < 0) ? 0 : v; }); return DenseIntElementsAttr::get(data.getType(), llvm::to_vector(res)); @@ -64,8 +57,8 @@ DenseIntElementsAttr ReplaceNegsWithZero(DenseElementsAttr data) { bool AnyNegativePads(mhlo::PadOp op) { auto is_neg = [](int64_t v) { return v < 0; }; - auto lows_data = UnrollSplat(op.getEdgePaddingLow()); - auto highs_data = UnrollSplat(op.getEdgePaddingHigh()); + auto lows_data = UnrollI64Splat(op.getEdgePaddingLow()); + auto highs_data = UnrollI64Splat(op.getEdgePaddingHigh()); return llvm::any_of(lows_data, is_neg) || llvm::any_of(highs_data, is_neg); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h index aa0428f1040dd2..5041903941bbbd 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h @@ -21,8 +21,6 @@ limitations under the License. namespace mlir::odml { -llvm::SmallVector UnrollSplat(DenseElementsAttr data); - // Gets elements corresponding to slice starts from negative padding // values. DenseIntElementsAttr SliceStartFromNegPadLows(mhlo::PadOp op); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc new file mode 100644 index 00000000000000..a00ee33c45a8ca --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.cc @@ -0,0 +1,742 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h" + +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { +namespace { + +// filters, strides, padding, faf. +using TFLPoolAttrsT = std::tuple; + +bool AreDilationsSupported(const ReduceWindowView& op) { + auto is_one = [](int64_t v) { return v == 1; }; + return llvm::all_of(op.BaseDilations(), is_one) && + llvm::all_of(op.WindowDilations(), is_one); +} + +bool IsRankSupported(const ReduceWindowView& op) { return op.Rank() == 4; } + +std::optional> GetViewIfAttrsSupported( + mhlo::ReduceWindowOp op) { + const ReduceWindowView view(op); + + if (!IsRankSupported(view)) { + return std::nullopt; + } + + if (!AreDilationsSupported(view)) { + return std::nullopt; + } + + auto opt_layout = view.GuessLayout(); + if (!opt_layout.has_value()) { + return std::nullopt; + } + auto layout = opt_layout.value(); + + const int64_t batch = layout.SpecialDim1(); + if (!view.Paddings()[batch].Trivial()) { + return std::nullopt; + } + + const int64_t chan = layout.SpecialDim2(); + if (!view.Paddings()[chan].Trivial()) { + return std::nullopt; + } + + return std::tuple(view, layout); +} + +std::optional IsReduceWindowLegal(mhlo::ReduceWindowOp op) { + return std::nullopt; +} + +std::optional IsDivideLegal(mhlo::DivOp op) { return std::nullopt; } + +Layout TFLNativePoolingLayout(int64_t rank) { + return Layout(0, rank - 1, llvm::to_vector(llvm::seq(1, rank - 1))); +} + +bool IsCstFloatZero(Value val) { + DenseFPElementsAttr initial_value; + return matchPattern(val, m_Constant(&initial_value)) && + initial_value.getNumElements() == 1 && + initial_value.getValues()[0].isZero(); +} + +bool IsCstIntZero(Value val) { + DenseIntElementsAttr initial_value; + return matchPattern(val, m_Constant(&initial_value)) && + initial_value.getNumElements() == 1 && + initial_value.getValues()[0].isZero(); +} + +llvm::SmallVector Permute(llvm::ArrayRef data, + llvm::ArrayRef perm) { + llvm::SmallVector res(data.size()); + for (int i = 0; i < data.size(); ++i) { + res[i] = data[perm[i]]; + } + return res; +} + +Value TransposeTensor(OpBuilder& b, Value tensor, + llvm::SmallVector perm) { + const int64_t perm_size = perm.size(); + auto perm_attr_type = RankedTensorType::get({perm_size}, b.getI64Type()); + auto perm_attr = DenseIntElementsAttr::get(perm_attr_type, perm); + return b.create(tensor.getLoc(), tensor, perm_attr); +} + +DenseIntElementsAttr BuildDenseI64(OpBuilder& b, ArrayRef shape, + ArrayRef data) { + return DenseIntElementsAttr::get(RankedTensorType::get(shape, b.getI64Type()), + data); +} + +DenseIntElementsAttr BuildDenseI64(OpBuilder& b, ArrayRef data) { + const int64_t dim = data.size(); + return BuildDenseI64(b, {dim}, data); +} + +std::optional> GetInputAndInitIfValid( + mhlo::ReduceWindowOp op) { + if (op->getNumResults() != 1) { + return std::nullopt; + } + if (op.getInputs().size() > 1) { + return std::nullopt; + } + if (op.getInitValues().size() > 1) { + return std::nullopt; + } + auto init_val = op.getInitValues().front(); + if (llvm::dyn_cast(init_val.getType()).getNumElements() != 1) { + return std::nullopt; + } + return std::tuple(op.getInputs().front(), op.getInitValues().front()); +} + +std::optional GetTFLPadding(ArrayRef paddings, + ArrayRef window_strides, + ArrayRef in_shape, + ArrayRef window_dims) { + const int64_t rank = paddings.size(); + std::string tfl_padding = "VALID"; + for (int i = 1; i < rank - 1; ++i) { + const auto& dim_pad = paddings[i]; + if (dim_pad.Trivial()) { + continue; + } + if (!IsSamePaddingOnDim(in_shape[i], 1, window_strides[i], window_dims[i], + dim_pad)) { + return std::nullopt; + } + tfl_padding = "SAME"; + } + return tfl_padding; +} + +TFLPoolAttrsT BuildTFLPoolAttrs(OpBuilder& b, const ReduceWindowView& view, + StringRef padding) { + const int32_t filter_h = view.WindowDims()[1]; + auto filter_h_attr = b.getI32IntegerAttr(filter_h); + + const int32_t filter_w = view.WindowDims()[2]; + auto filter_w_attr = b.getI32IntegerAttr(filter_w); + + const int32_t stride_h = view.WindowStrides()[1]; + auto stride_h_attr = b.getI32IntegerAttr(stride_h); + + const int32_t stride_w = view.WindowStrides()[2]; + auto stride_w_attr = b.getI32IntegerAttr(stride_w); + + auto padding_attr = b.getStringAttr(padding); + auto faf_attr = b.getStringAttr("NONE"); + + return std::tuple(filter_h_attr, filter_w_attr, stride_h_attr, stride_w_attr, + padding_attr, faf_attr); +} + +//===------------------------------------------------------------------------=== +// relayout reduce_window to channel last +//===------------------------------------------------------------------------=== + +class RelayoutReduceWindow : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::ReduceWindowOp op, + PatternRewriter& rewriter) const final; +}; + +LogicalResult RelayoutReduceWindow::matchAndRewrite( + mhlo::ReduceWindowOp op, PatternRewriter& rewriter) const { + // + // check and parse attributes + //=----- + + auto opt_view = GetViewIfAttrsSupported(op); + if (!opt_view.has_value()) { + return rewriter.notifyMatchFailure( + op, "Reduce window attributes not supported."); + } + const auto [view, layout] = opt_view.value(); + + // + // get inputs and inits if there are only one + //=----- + + auto opt_input_and_init = GetInputAndInitIfValid(op); + if (!opt_input_and_init.has_value()) { + return rewriter.notifyMatchFailure( + op, "Reduce window has wrong number of inputs or init values."); + } + auto [input, init_val] = opt_input_and_init.value(); + + // + // figure out permutations for layout change + //=----- + + const auto target_layout = TFLNativePoolingLayout(view.Rank()); + if (layout == target_layout) { + return rewriter.notifyMatchFailure( + op, "Reduce window does not need layout change"); + } + + llvm::SmallVector perm_for_inputs = + layout.GetPermForReLayout(target_layout); + + // + // permute layout sensitive attrs + //=----- + + // permute paddings + auto paddings = view.Paddings(); + llvm::SmallVector new_paddings(paddings.size() * 2); + for (int i = 0; i < new_paddings.size() / 2; ++i) { + const auto& dim_pad = paddings[perm_for_inputs[i]]; + new_paddings[2 * i] = dim_pad.Lo(); + new_paddings[2 * i + 1] = dim_pad.Hi(); + } + const int64_t new_paddings_size = paddings.size(); + auto new_paddings_type = + RankedTensorType::get({new_paddings_size, 2}, rewriter.getI64Type()); + auto new_paddings_attr = + DenseIntElementsAttr::get(new_paddings_type, new_paddings); + + // permute window dims + llvm::SmallVector new_window_dims = + Permute(view.WindowDims(), perm_for_inputs); + auto new_window_dims_attr = BuildDenseI64(rewriter, new_window_dims); + + // permute window strides + llvm::SmallVector new_window_strides = + Permute(view.WindowStrides(), perm_for_inputs); + auto new_window_strides_attr = BuildDenseI64(rewriter, new_window_strides); + + // + // permute params and build new op + //=----- + + // figure out permuted result type + llvm::SmallVector perm_for_outputs = + target_layout.GetPermForReLayout(layout); + auto cur_out_type = llvm::dyn_cast(op.getResult(0).getType()); + llvm::SmallVector new_rw_out_shape = + layout.PermuteShape(target_layout, cur_out_type.getShape()); + auto new_out_type = cur_out_type.clone(new_rw_out_shape); + + // transpose input and build new reduce_window + auto new_input = TransposeTensor(rewriter, input, perm_for_inputs); + auto new_rw = rewriter.create( + op.getLoc(), new_out_type, new_input, init_val, new_window_dims_attr, + new_window_strides_attr, BuildDenseI64(rewriter, view.BaseDilations()), + BuildDenseI64(rewriter, view.WindowDilations()), new_paddings_attr); + IRMapping ir_map; + op.getBody().cloneInto(&new_rw.getBody(), ir_map); + + // transpose output and update ir + auto new_output = + TransposeTensor(rewriter, new_rw.getResult(0), perm_for_outputs); + rewriter.replaceOp(op, new_output); + + return success(); +} + +//===------------------------------------------------------------------------=== +// mhlo.reduce_window -> tfl.cum_sum +//===------------------------------------------------------------------------=== + +class LegalizeCumSum : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ReduceWindowOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeCumSum::matchAndRewrite( + mhlo::ReduceWindowOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + // + // check singular params and trivial attrs + //=----- + + auto opt_input_init = GetInputAndInitIfValid(op); + if (!opt_input_init.has_value()) { + return rewriter.notifyMatchFailure(op, + "Must have 1 input, init and result."); + } + auto [input, init] = opt_input_init.value(); + + if (failed(MatchBinaryReduceFunction(op.getBody()))) { + return rewriter.notifyMatchFailure(op, "Requires scalar add in region."); + } + + if (!IsCstFloatZero(init) && !IsCstIntZero(init)) { + return rewriter.notifyMatchFailure(op, "Requires 0 for init value."); + } + + const ReduceWindowView view(op); + + auto trivial = [](int64_t v) { return v == 1; }; + const bool trivial_window_dilate = + llvm::all_of(view.WindowDilations(), trivial); + const bool trivial_base_dilate = llvm::all_of(view.BaseDilations(), trivial); + const bool trivial_stride = llvm::all_of(view.WindowStrides(), trivial); + if (!trivial_window_dilate || !trivial_stride || !trivial_base_dilate) { + return rewriter.notifyMatchFailure( + op, "Requires trivial strides and dilations attributes."); + } + + // + // figure out the implicit axis of reduction + //=----- + + auto input_type = llvm::cast(input.getType()); + if (view.WindowDims().size() != input_type.getRank()) { + return rewriter.notifyMatchFailure(op, "Splat window dims not supported."); + } + int64_t axis = -1; + for (auto [ind, val] : llvm::enumerate(view.WindowDims())) { + if (val == 1) { + continue; + } + + if (axis != -1) { + return rewriter.notifyMatchFailure(op, "Multiple non 1 dimensions."); + } + + if (val != input_type.getShape()[ind]) { + return rewriter.notifyMatchFailure( + op, "Axis dimension requires size be same as input shape's."); + } + axis = ind; + } + + if (axis == -1) { + return rewriter.notifyMatchFailure(op, "Could not identify axis."); + } + + const int64_t axis_size = input_type.getShape()[axis]; + + // + // validate padding is [N-1, 0] on axis and zero elsewhere + //=----- + + for (const auto& [ind, dim_pad] : llvm::enumerate(view.Paddings())) { + if (dim_pad.Hi() != 0) { + return rewriter.notifyMatchFailure(op, "Has non trivial high padding."); + } + + if (ind != axis) { + if (!dim_pad.Trivial()) { + return rewriter.notifyMatchFailure( + op, "Has non trivial padding on non axis dim."); + } + } else { + if (dim_pad.Lo() != axis_size - 1) { + return rewriter.notifyMatchFailure( + op, "Requires low padding on axis dim to be N - 1."); + } + } + } + + // + // build axis constant and tfl op + //=----- + + auto axis_cst_attr = DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI32Type()), + static_cast(axis)); + auto axis_cst = + rewriter.create(op->getLoc(), axis_cst_attr); + + auto tfl_exclusive_attr = rewriter.getBoolAttr(false); + auto tfl_reverse_attr = rewriter.getBoolAttr(false); + + rewriter.replaceOpWithNewOp(op, op->getResultTypes()[0], input, + axis_cst, tfl_exclusive_attr, + tfl_reverse_attr); + + return success(); +} + +//===------------------------------------------------------------------------=== +// mhlo.reduce_window -> tfl.max_pool +//===------------------------------------------------------------------------=== + +bool isFloatMinusInfinity(Value value) { + DenseFPElementsAttr float_value; + if (!matchPattern(value, m_Constant(&float_value))) { + return false; + } + if (float_value.getNumElements() != 1) { + return false; + } + APFloat element = float_value.getValues()[0]; + return element.isInfinity() && element.isNegative(); +} + +class LegalizeMaxPool : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::ReduceWindowOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeMaxPool::matchAndRewrite( + mhlo::ReduceWindowOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + // + // parse and validate lhs reduce window + //=----- + + const auto opt_view = GetViewIfAttrsSupported(op); + if (!opt_view.has_value()) { + return rewriter.notifyMatchFailure(op, "Reduce window is not valid."); + } + const auto [view, layout] = opt_view.value(); + if (layout != TFLNativePoolingLayout(layout.Rank())) { + return rewriter.notifyMatchFailure(op, "Not tfl standard layout."); + } + + // Check that the reduce-window is a max-reduce-window. + if (failed(MatchBinaryReduceFunction(op.getBody()))) { + return rewriter.notifyMatchFailure(op, "Must be a max pool."); + } + + auto type = mlir::dyn_cast(op.getResult(0).getType()); + if (!mlir::isa(type.getElementType())) { + return rewriter.notifyMatchFailure(op, "Not a floating point pool."); + } + + // + // validate inputs and init + //=----- + + auto opt_inputs_and_init = GetInputAndInitIfValid(op); + if (!opt_inputs_and_init.has_value()) { + return rewriter.notifyMatchFailure(op, "Too many inputs or inits."); + } + auto [input, init] = opt_inputs_and_init.value(); + auto input_type = llvm::dyn_cast(input.getType()); + + if (!isFloatMinusInfinity(init)) { + return rewriter.notifyMatchFailure(op, "Init not minus infinity."); + } + + // + // build tfl + //=----- + + auto opt_tfl_padding = + GetTFLPadding(view.Paddings(), view.WindowStrides(), + input_type.getShape(), view.WindowDims()); + if (!opt_tfl_padding.has_value()) { + return rewriter.notifyMatchFailure(op, "Padding not SAME or VALID."); + } + const auto& tfl_padding = opt_tfl_padding.value(); + + auto [fh, fw, sh, sw, p, faf] = + BuildTFLPoolAttrs(rewriter, view, tfl_padding); + rewriter.replaceOpWithNewOp(op, type, input, p, sw, sh, fw, + fh, faf); + + return success(); +} + +//===------------------------------------------------------------------------=== +// mhlo.div(mhlo.reduce_window, cst | mhlo.reduce_window) -> tfl.avg_pool +//===------------------------------------------------------------------------=== + +void ReplaceWithAvgPool(mhlo::DivOp op, Value rw_lhs_input, + const ReduceWindowView& lhs_view, + llvm::StringRef padding, PatternRewriter& rewriter, + mhlo::TransposeOp opt_final_tpose) { + Type out_type = + opt_final_tpose ? opt_final_tpose.getOperand().getType() : op.getType(); + + auto [fh, fw, sh, sw, p, faf] = + BuildTFLPoolAttrs(rewriter, lhs_view, padding); + Value final_op = rewriter.create( + op->getLoc(), out_type, rw_lhs_input, fh, fw, p, sh, sw, faf); + + if (opt_final_tpose) { + final_op = rewriter + .create(final_op.getLoc(), final_op, + opt_final_tpose.getPermutation()) + .getResult(); + } + + rewriter.replaceOp(op, final_op); +} + +// Walks up the op and ignore all precedding ops of type Tys. +// Returns the first producer op whose type is not in Tys. +template +Value RecursivelyWalkUp(Value op) { + while (llvm::isa_and_nonnull(op.getDefiningOp())) { + Operation* producer = op.getDefiningOp(); + op = producer->getOperand(/*idx=*/0); + } + + return op; +} + +class LegalizeAvgPool : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite( + mhlo::DivOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeAvgPool::matchAndRewrite( + mhlo::DivOp div_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + // + // parse and validate lhs reduce window + //=----- + + auto div_lhs = div_op.getLhs(); + // If div's input is transposed, save it to chain on the new pool op. + mhlo::TransposeOp opt_final_tpose; + if (auto div_lhs_op = div_lhs.getDefiningOp()) { + opt_final_tpose = llvm::dyn_cast_or_null(div_lhs_op); + } + + auto rw_lhs_val = RecursivelyWalkUp(div_lhs); + auto rw_lhs = + llvm::dyn_cast_or_null(rw_lhs_val.getDefiningOp()); + if (!rw_lhs) { + return rewriter.notifyMatchFailure( + div_op, "Could not match lhs of div on reduce window."); + } + + const auto opt_rw_lhs_view = GetViewIfAttrsSupported(rw_lhs); + if (!opt_rw_lhs_view.has_value()) { + return rewriter.notifyMatchFailure(div_op, "Lhs rw is not valid."); + } + const auto [rw_lhs_view, rw_lhs_layout] = opt_rw_lhs_view.value(); + if (rw_lhs_layout != TFLNativePoolingLayout(rw_lhs_layout.Rank())) { + return rewriter.notifyMatchFailure( + div_op, "Lhs reduce window not tfl standard layout."); + } + + // Check that the reduce-window is a sum-reduce-window. + if (failed(MatchBinaryReduceFunction(rw_lhs.getBody()))) { + return rewriter.notifyMatchFailure(div_op, + "Failed to match rw lhs binary func."); + } + + // + // validate inputs and init val + //=----- + + auto opt_rw_lhs_input_and_init = GetInputAndInitIfValid(rw_lhs); + if (!opt_rw_lhs_input_and_init.has_value()) { + return rewriter.notifyMatchFailure( + div_op, "Lhs reduce window has wrong number of inputs or init values."); + } + auto [rw_lhs_input, rw_lhs_init_val] = opt_rw_lhs_input_and_init.value(); + auto rw_lhs_input_type = llvm::dyn_cast(rw_lhs_input.getType()); + + auto rw_lhs_type = + mlir::dyn_cast(rw_lhs.getResult(0).getType()); + if (!mlir::isa(rw_lhs_type.getElementType())) { + return rewriter.notifyMatchFailure(div_op, + "Reduce window lhs most be float type."); + } + + // If the init value isn't zero then it can't be an average pool. + if (!IsCstFloatZero(rw_lhs_init_val)) { + return rewriter.notifyMatchFailure( + div_op, "Reduce window lhs init value is not zero."); + } + + // + // case 1: rhs is splat const with val == window_size + //=----- + + auto opt_tfl_padding = + GetTFLPadding(rw_lhs_view.Paddings(), rw_lhs_view.WindowStrides(), + rw_lhs_input_type.getShape(), rw_lhs_view.WindowDims()); + if (!opt_tfl_padding.has_value()) { + return rewriter.notifyMatchFailure(div_op, + "Padding must be VALID or SAME."); + } + const auto& tfl_padding = opt_tfl_padding.value(); + + { + DenseFPElementsAttr divisor; + auto div_rhs = RecursivelyWalkUp( + div_op.getRhs()); + if (matchPattern(div_rhs, m_Constant(&divisor))) { + if (!divisor.isSplat()) { + return failure(); + } + + if (!divisor.getSplatValue().isExactlyValue( + rw_lhs_view.WindowSize())) { + return rewriter.notifyMatchFailure( + div_op, "Rhs splat const is not equal to window size."); + } + + if (tfl_padding != "VALID") { + return rewriter.notifyMatchFailure(div_op, + "Matching on rhs splat const where " + "rw lhs has non-trivial padding."); + } + + ReplaceWithAvgPool(div_op, rw_lhs_input, rw_lhs_view, tfl_padding, + rewriter, opt_final_tpose); + return success(); + } + } + + // + // case 2: rhs is another reduce window over 1's with same config as lhs + //=----- + + { + Value divisor = RecursivelyWalkUp(div_op.getRhs()); + auto rw_rhs = + dyn_cast_or_null(divisor.getDefiningOp()); + if (!rw_rhs) { + return rewriter.notifyMatchFailure( + div_op, "Rhs of div op is not a reduce window."); + } + + const auto opt_rw_rhs_view = GetViewIfAttrsSupported(rw_rhs); + if (!opt_rw_rhs_view.has_value()) { + return rewriter.notifyMatchFailure(div_op, "Rhs rw is not valid."); + } + const auto [rw_rhs_view, rw_rhs_layout] = opt_rw_rhs_view.value(); + if (rw_rhs_layout != TFLNativePoolingLayout(rw_rhs_layout.Rank())) { + return rewriter.notifyMatchFailure( + div_op, "Rhs reduce window not tfl standard layout."); + } + + // Check that RHS is a sum-reduce-window. + if (failed(MatchBinaryReduceFunction(rw_rhs.getBody()))) { + return rewriter.notifyMatchFailure( + div_op, "Rhs rw body function is not an add op."); + } + + auto opt_rw_rhs_input_and_init = GetInputAndInitIfValid(rw_rhs); + if (!opt_rw_rhs_input_and_init.has_value()) { + return rewriter.notifyMatchFailure( + div_op, + "Rhs reduce window has wrong number of inputs or init values."); + } + auto [rw_rhs_input, rw_rhs_init_val] = opt_rw_rhs_input_and_init.value(); + + if (!IsCstFloatZero(rw_rhs_init_val)) { + return rewriter.notifyMatchFailure(div_op, + "Rhs rw init vals is not zero."); + } + + rw_rhs_input = RecursivelyWalkUp( + rw_rhs_input); + DenseFPElementsAttr rhs_input_data; + if (!matchPattern(rw_rhs_input, m_Constant(&rhs_input_data)) || + !rhs_input_data.isSplat() || + !rhs_input_data.getSplatValue().isExactlyValue(1.0)) { + return rewriter.notifyMatchFailure(div_op, + "Rw rhs input is not splat of 1.0."); + } + + // Check that the two reduce window have the same window configuration. + if (rw_lhs.getWindowDimensions() != rw_rhs.getWindowDimensions() || + rw_lhs.getWindowStrides() != rw_rhs.getWindowStrides() || + rw_lhs.getPadding() != rw_rhs.getPadding()) { + return rewriter.notifyMatchFailure( + div_op, "Lhs rw and Rhs rw do not have the same config."); + } + + ReplaceWithAvgPool(div_op, rw_lhs_input, rw_lhs_view, tfl_padding, rewriter, + opt_final_tpose); + return success(); + } + + return failure(); +} + +} // namespace + +void PopulateLegalizeReduceWindowPatterns(MLIRContext* ctx, + RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add(ctx); + target.addDynamicallyLegalOp(IsReduceWindowLegal); + target.addDynamicallyLegalOp(IsDivideLegal); +} + +void PopulatePrepareReduceWindowPatterns(MLIRContext* ctx, + RewritePatternSet& patterns) { + patterns.add(ctx); +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h new file mode 100644 index 00000000000000..ccc9c27f6955cd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +// Patterns to legalize mhlo.reduce_window to TFL. +// +// Maps the following representations of AvgPool in MHLO into a tfl.avg_pool +// operation when they cleanly map to 2D or 3D average pool with VALID or SAME +// padding: +// * div(reduce_sum_window(x), constant(sizeof(window))) +// * div(reduce_sum_window(x), reduce_sum_window(constant(1))) +// +// Emits: tfl.average_pool2d +void PopulateLegalizeReduceWindowPatterns(MLIRContext* ctx, + RewritePatternSet& patterns, + ConversionTarget& target); + +// Patterns to prepare mhlo.reduce_window for legalization. +// Transposes reduce_windows to be NHWC. +// +// Emits: tfl.transpose +void PopulatePrepareReduceWindowPatterns(MLIRContext* ctx, + RewritePatternSet& patterns); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc new file mode 100644 index 00000000000000..67e4db62566452 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.cc @@ -0,0 +1,72 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h" + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { + +ReduceWindowView::ReduceWindowView(mhlo::ReduceWindowOp op) { + rank_ = op.getWindowDimensions().size(); + window_dims_ = + SmallVector(op.getWindowDimensions().getValues()); + window_strides_ = ResolveStridesOrDilations(rank_, op.getWindowStrides()); + window_dilations_ = ResolveStridesOrDilations(rank_, op.getWindowDilations()); + base_dilations_ = ResolveStridesOrDilations(rank_, op.getBaseDilations()); + paddings_ = ResolvePadding(rank_, op.getPadding()); + window_size_ = 1; + for (auto d : window_dims_) { + window_size_ *= d; + } +} + +std::optional ReduceWindowView::GuessLayout() const { + auto zip_dims_strides = llvm::zip(WindowDims(), WindowStrides()); + auto simple_window_dims = + llvm::to_vector(llvm::map_range(zip_dims_strides, [](auto it) { + return std::get<0>(it) == 1 && std::get<1>(it) == 1; + })); + + if (llvm::count(simple_window_dims, 1) < 2) { + return std::nullopt; + } + + const bool is_channel_last = + simple_window_dims[0] && simple_window_dims[Rank() - 1]; + if (is_channel_last) { + return Layout(0, Rank() - 1, + llvm::to_vector(llvm::seq(1, Rank() - 1))); + } + + const bool is_channel_first = simple_window_dims[0] && simple_window_dims[1]; + if (is_channel_first) { + return Layout(0, 1, llvm::to_vector(llvm::seq(2, Rank()))); + } + + // In theory, we can support any layout with at least 2 1's in + // `simple_window_dims` by permuting layouts such that the 1's are + // the first and last position. Unclear if such a case ever comes up. + return std::nullopt; +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h new file mode 100644 index 00000000000000..512389bdb4ec04 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window_util.h @@ -0,0 +1,62 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +// Helpers for working with mhlo.reduce_window attrs in the mlir api as +// native cc types. + +namespace mlir::odml { + +class ReduceWindowView { + public: + explicit ReduceWindowView(mhlo::ReduceWindowOp op); + + llvm::ArrayRef WindowDims() const { return window_dims_; } + int64_t WindowSize() const { return window_size_; } + llvm::ArrayRef WindowStrides() const { return window_strides_; } + llvm::ArrayRef Paddings() const { return paddings_; } + llvm::ArrayRef WindowDilations() const { return window_dilations_; } + llvm::ArrayRef BaseDilations() const { return base_dilations_; } + int64_t Rank() const { return rank_; } + + std::optional GuessLayout() const; + + private: + int64_t rank_; + + llvm::SmallVector window_dims_; + llvm::SmallVector window_strides_; + llvm::SmallVector window_dilations_; + + llvm::SmallVector paddings_; + + llvm::SmallVector base_dilations_; + + int64_t window_size_; +}; + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_REDUCE_WINDOW_UTIL_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc new file mode 100644 index 00000000000000..07f9f0368ad665 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.cc @@ -0,0 +1,241 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h" + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/op_util_common.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { +namespace { + +// mhlo encodes ND indice arguments as a variadiac of scalars. Pack them +// into a single tensor for use in TFL. +Value PackScalarIndices(mlir::ValueRange indices, OpBuilder& b) { + auto e_type = + llvm::cast(indices.front().getType()).getElementType(); + const int64_t num_indices = indices.size(); + auto packed_indices_type = RankedTensorType::get({num_indices}, e_type); + + auto values_count_attr = b.getI32IntegerAttr(num_indices); + auto pack_axis_attr = b.getI32IntegerAttr(0); + + return b.create(indices.back().getLoc(), packed_indices_type, + indices, values_count_attr, pack_axis_attr); +} + +//===----------------------------------------------------------------------===// +// mhlo.slice +//===----------------------------------------------------------------------===// + +// Cast the value to i32. +Value BuildTFLCastOp(OpBuilder& b, Value value) { + return b.create( + value.getLoc(), + RankedTensorType::get(llvm::cast(value.getType()).getShape(), + b.getI32Type()), + value); +} + +class LegalizeSliceOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::SliceOp slice_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final { + auto begin = rewriter.create(slice_op.getLoc(), + slice_op.getStartIndices()); + auto end = rewriter.create(slice_op.getLoc(), + slice_op.getLimitIndices()); + auto strides = rewriter.create(slice_op.getLoc(), + slice_op.getStrides()); + auto zero = rewriter.getIntegerAttr(rewriter.getI32Type(), 0); + auto no_offset = rewriter.getBoolAttr(false); + + rewriter.replaceOpWithNewOp( + slice_op, slice_op.getType(), slice_op.getOperand(), + BuildTFLCastOp(rewriter, begin), BuildTFLCastOp(rewriter, end), + BuildTFLCastOp(rewriter, strides), zero, zero, zero, zero, zero, + no_offset); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// mhlo.dynamic_slice +//===----------------------------------------------------------------------===// + +class CastSliceIndicesToSignless + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::DynamicSliceOp op, + PatternRewriter& rewriter) const final; +}; + +LogicalResult CastSliceIndicesToSignless::matchAndRewrite( + mhlo::DynamicSliceOp op, PatternRewriter& rewriter) const { + // All start inds have the same element type. + auto start_type = + llvm::cast(op.getStartIndices().front().getType()); + auto start_e_type = start_type.getElementType(); + + if (start_e_type.isSignlessIntOrFloat()) { + return rewriter.notifyMatchFailure(op, "Already signless."); + } + auto new_start_e_type = + rewriter.getIntegerType(start_e_type.getIntOrFloatBitWidth()); + + llvm::SmallVector casted_start_inds; + for (auto start_ind_opr : op.getStartIndices()) { + auto casted_start_ind_opr = rewriter.create( + start_ind_opr.getLoc(), start_ind_opr, new_start_e_type); + casted_start_inds.push_back(casted_start_ind_opr.getResult()); + } + + rewriter.replaceOpWithNewOp( + op, op.getOperand(), casted_start_inds, op.getSliceSizes()); + + return success(); +} + +bool IsDynamicSliceLegal(mhlo::DynamicSliceOp op) { + return !llvm::cast(op.getOperand().getType()).hasStaticShape(); +} + +class LegalizeDynamicSliceOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DynamicSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeDynamicSliceOp::matchAndRewrite( + mhlo::DynamicSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + auto start_type = + llvm::cast(op.getStartIndices().front().getType()); + auto start_e_type = start_type.getElementType(); + if (!start_e_type.isSignlessIntOrFloat()) { + return rewriter.notifyMatchFailure( + op, "Must be signless integer for start indices."); + } + + auto input_type = llvm::cast(op.getOperand().getType()); + if (!input_type.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, "Input must be statically shaped."); + } + + // + // clamp start indices between zero and shape(operand) - slice_sizes + //=----- + + Value clamp_left_cst = rewriter.create( + op->getLoc(), rewriter.getZeroAttr(start_type)); + + llvm::SmallVector new_start_indices; + const auto stride_sizes = UnrollI64Splat(op.getSliceSizes()); + + for (auto [dim_size, start_ind_opr, stride_size] : + llvm::zip(input_type.getShape(), op.getStartIndices(), stride_sizes)) { + const int64_t clamp_right_val = dim_size - stride_size; + auto clamp_right_cst = rewriter.create( + op->getLoc(), + DenseElementsAttr::get(start_type, rewriter.getIntegerAttr( + start_e_type, clamp_right_val))); + + Value new_start_ind = rewriter.create( + op->getLoc(), start_type, clamp_left_cst, start_ind_opr); + new_start_ind = rewriter.create( + op->getLoc(), start_type, clamp_right_cst, new_start_ind); + + new_start_indices.push_back(new_start_ind); + } + + // + // build tfl + //=----- + + auto packed_indices = PackScalarIndices(new_start_indices, rewriter); + + auto slice_sizes_cst = + rewriter.create(op->getLoc(), op.getSliceSizes()); + + rewriter.replaceOpWithNewOp(op, op.getType(), op.getOperand(), + packed_indices, slice_sizes_cst); + + return success(); +} + +//===----------------------------------------------------------------------===// +// mhlo.dynamic_update_slice +//===----------------------------------------------------------------------===// + +class LegalizeDynamicUpdateSliceOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeDynamicUpdateSliceOp::matchAndRewrite( + mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + auto packed_indices = PackScalarIndices(op.getStartIndices(), rewriter); + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getOperand(), op.getUpdate(), packed_indices); + return success(); +}; + +} // namespace + +void PopulateLegalizeSlicePatterns(MLIRContext* ctx, + RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add(ctx); + + target.addIllegalOp(); + target.addDynamicallyLegalOp(IsDynamicSliceLegal); +} + +void PopulatePrepareSlicePatterns(MLIRContext* ctx, + RewritePatternSet& patterns) { + patterns.add(ctx); +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h new file mode 100644 index 00000000000000..024cbb4a2fe327 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h @@ -0,0 +1,34 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_ + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { + +// Patterns to legalize mhlo.slice to TFL. +void PopulateLegalizeSlicePatterns(MLIRContext* ctx, + RewritePatternSet& patterns, + ConversionTarget& target); + +void PopulatePrepareSlicePatterns(MLIRContext* ctx, + RewritePatternSet& patterns); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SLICE_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc new file mode 100644 index 00000000000000..43477cc0dcdfab --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.cc @@ -0,0 +1,146 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h" + +#include + +#include "llvm/ADT/ilist.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir::odml { +namespace { + +using OpListType = llvm::iplist; + +template +bool MatchTopKComparator(Region& comparator) { + if (!comparator.hasOneBlock()) return false; + Block& comparator_blk = comparator.front(); + + OpListType& operations = comparator_blk.getOperations(); + if (operations.size() != 2) return false; + + auto compare_op = + llvm::dyn_cast_or_null(&operations.front()); + auto return_op = llvm::dyn_cast_or_null(&operations.back()); + if (!compare_op || !return_op) return false; + + if (compare_op.getComparisonDirection() != mhlo::ComparisonDirection::GT) { + return false; + } + + if (compare_op.getOperands()[0] != comparator_blk.getArgument(0) || + compare_op.getOperands()[1] != comparator_blk.getArgument(1)) { + return false; + } + + return return_op.getOperands().front() == compare_op.getResult(); +} + +bool IsSortOpNotTopK(mhlo::SortOp op) { + if (op->getNumOperands() != 2) { + return true; + } + + auto keys_opr = op.getInputs().front(); + auto keys_type = llvm::cast(keys_opr.getType()); + + if (!keys_type.hasStaticShape() || + !keys_type.getElementType().isIntOrFloat()) { + return true; + } + + auto indices_opr = op.getInputs().back(); + auto indices_type = llvm::cast(indices_opr.getType()); + + if (!indices_type.hasStaticShape() || + !indices_type.getElementType().isInteger(32)) { + return true; + } + + const int64_t sort_dim = op.getDimension(); + const auto k = indices_type.getDimSize(sort_dim); + const auto rank = keys_type.getRank(); + + if (sort_dim != rank - 1 || k < 1) { + return true; + } + + OpBuilder b(op->getContext()); + if (!MatchIota(b.getI64TensorAttr({sort_dim}), indices_opr)) { + return true; + } + + if (!MatchTopKComparator(op.getComparator())) { + return true; + } + + return false; +} + +class LegalizeSortOp : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::SortOp sort_op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const final; +}; + +LogicalResult LegalizeSortOp::matchAndRewrite( + mhlo::SortOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + if (IsSortOpNotTopK(op)) { + return failure(); + } + + auto keys = op.getInputs().front(); + auto indices = op.getInputs().back(); + auto indices_type = llvm::cast(indices.getType()); + + const int32_t k = indices_type.getShape().back(); + auto k_cst_attr = DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getI32Type()), k); + auto k_cst = rewriter.create(op->getLoc(), k_cst_attr); + + rewriter.replaceOpWithNewOp(op, keys.getType(), + indices.getType(), keys, k_cst); + + return success(); +} + +} // namespace + +void PopulateSortPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target) { + patterns.add(ctx); + target.addDynamicallyLegalOp(IsSortOpNotTopK); +} + +} // namespace mlir::odml diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h new file mode 100644 index 00000000000000..9bbb1f3fde06ab --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h @@ -0,0 +1,28 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ + +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project + +namespace mlir::odml { + +void PopulateSortPatterns(MLIRContext* ctx, RewritePatternSet& patterns, + ConversionTarget& target); + +} // namespace mlir::odml + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_HLO_CONVERSIONS_SORT_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td index fe988ba9b20265..49d38d78cb6f2a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td @@ -380,7 +380,7 @@ def ConvertDotGeneralOp : NativeCodeCall<"ConvertDotGeneralOp($_builder, " def : Pat<(MHLO_DotGeneralOp:$old_value RankedTensorOf<[TF_ElementType]>:$lhs, RankedTensorOf<[TF_ElementType]>:$rhs, - $dot_dimension_numbers, $precision_config), + $dot_dimension_numbers, $precision_config, $algorithm), (ConvertDotGeneralOp $old_value)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc index d9c23dfa12b8ae..062222f72b3b9a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc @@ -58,7 +58,7 @@ LogicalResult ConvertDotToDotGeneral(mhlo::DotOp op, /*rhsBatchingDimensions=*/{}, /*lhsContractingDimensions=*/{lhs_type.getRank() - 1}, /*rhsContractingDimensions=*/{0}), - op.getPrecisionConfigAttr()); + op.getPrecisionConfigAttr(), mhlo::DotAlgorithmAttr{}); return success(); } @@ -161,7 +161,7 @@ LogicalResult RemoveReshapeAroundDotGeneral(mhlo::ReshapeOp reshape_after, range(batch_dims_count + shape_y1.size(), contracting_dims_count), /*rhsContractingDimensions=*/ range(batch_dims_count, contracting_dims_count)), - dot.getPrecisionConfigAttr()); + dot.getPrecisionConfigAttr(), dot.getAlgorithmAttr()); return success(); } @@ -273,7 +273,8 @@ LogicalResult LiftDotConcatLHS(mhlo::ConcatenateOp concat, rewriter.getI64IntegerAttr(new_concat_dim)); rewriter.replaceOpWithNewOp( concat, concat.getType(), new_concat, first_dot.getRhs(), - first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr()); + first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr(), + first_dot.getAlgorithmAttr()); return success(); } @@ -374,7 +375,8 @@ LogicalResult LiftDotConcatLHSAndRHS(mhlo::ConcatenateOp concat, all_dot_rhs, rewriter.getI64IntegerAttr(rhs_batch_dim)); rewriter.replaceOpWithNewOp( concat, concat.getType(), lhs_new_concat, rhs_new_concat, - first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr()); + first_dot.getDotDimensionNumbers(), first_dot.getPrecisionConfigAttr(), + first_dot.getAlgorithmAttr()); return success(); } @@ -611,10 +613,134 @@ LogicalResult ConvertReshapeDotRhsToBatchedDot(mhlo::DotGeneralOp dot, /*rhsBatchingDimensions=*/{0}, /*lhsContractingDimensions=*/dim_nums.getLhsContractingDimensions(), /*rhsContractingDimensions=*/new_rhs_contracting_dims), - dot.getPrecisionConfigAttr()); + dot.getPrecisionConfigAttr(), dot.getAlgorithmAttr()); return success(); } +//===----------------------------------------------------------------------===// +// BroadcastInDimsOp +//===----------------------------------------------------------------------===// + +// Minimizing unit dimensions in reshape(broadcast(X)). +// +// There are situations where X, or broadcast(X) have some number of `1` (unit) +// sized dimensions which are not meaningful to the computation. E.g. +// +// ``` +// x = [1x1x1x3] +// b = broadast(x) : [1x2x1x3] +// r = reshape(b) : [2x3] +// ``` +// +// Provided the relative broadcast dims are preserved, removing any number +// of unit dims from the input or output shape of a broadcast has no effect on +// the semantic of the computation. +// +// Assume a reshape(broadcast(x)) where the shape of the broadcast and reshape +// have the same non-unit dims in the same order. In this case we can +// change the broadcast shape into the reshape shape simply by adding or +// removing unit-dims, and the reshape can be replaced with the broadcast. +// +// When removing unit dims from the broadcast in this way, we may also need +// to remove the corresponding unit dim from the input shape. This pattern takes +// the approach of removing all unit dims for the broadcast input +// rather than explicitly checking each. +// +// The result on the above example: +// +// ``` +// x = [1x1x1x3] +// r = reshape(x) : [3] +// b = broadast(r) : [2x3] +// ``` +// +// Note that the ability of removing unit dims from the input or output shape of +// a broascast is not contingent on matching and replacing a reshaped output. We +// require however for this pattern to not increase the net number of reshapes. +// Additionally, we want to minimize the rank of broadcasts so only considered +// are cases where rank(reshape) < rank(broadcast). +class SimplifyBroadcastInDimsReshape + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mhlo::BroadcastInDimOp op, + PatternRewriter &rewriter) const override { + if (!op->hasOneUse()) { + return rewriter.notifyMatchFailure(op, "has more than one use."); + } + + auto reshape = mlir::dyn_cast(*op->getUsers().begin()); + if (!reshape) { + return rewriter.notifyMatchFailure(op, "user not reshape."); + } + + auto broadcast_type = mlir::cast(op.getType()); + auto broadcast_input_type = + mlir::cast(op.getOperand().getType()); + auto reshape_type = mlir::cast(reshape.getType()); + + // Reshape must be squeezing unit dimensions. + if (!(reshape_type.getRank() < broadcast_type.getRank())) { + return rewriter.notifyMatchFailure(op, "reshape doesn't reduce rank."); + } + + // Reshape and broadcast must have the same non-unit dims in the + // same order. + llvm::SmallVector broadcast_dim_to_reshape_dim( + broadcast_type.getRank()); + int64_t reshape_dim_idx = -1; + for (auto [idx, dim] : llvm::enumerate(broadcast_type.getShape())) { + if (dim == 1) { + continue; + } + + int64_t reshape_dim_size = 1; + while (reshape_dim_idx < reshape_type.getRank() - 1) { + reshape_dim_size = reshape_type.getDimSize(++reshape_dim_idx); + if (reshape_dim_size != 1) { + break; + } + } + + if (dim != reshape_dim_size) { + return rewriter.notifyMatchFailure( + op, "reshape and broadcast have different non-unit dim sizes."); + } + + // Maps index of non-unit broadcast dims to corresponding reshape dim. + broadcast_dim_to_reshape_dim[idx] = reshape_dim_idx; + } + // Unchecked reshape dim sizes are guaranteed to be unit at this point. + + llvm::SmallVector current_broadcast_dims( + op.getBroadcastDimensions().getValues()); + llvm::SmallVector new_broadcast_dims; + llvm::SmallVector new_broadcast_input_shape; + + for (auto [idx, dim] : llvm::enumerate(broadcast_input_type.getShape())) { + if (dim == 1) { + continue; + } + // If dim != 1 then it must be broadcasted to a non-unit dimension + // and must have a corresponding reshape dimension in our vectors. + new_broadcast_dims.push_back( + broadcast_dim_to_reshape_dim[current_broadcast_dims[idx]]); + new_broadcast_input_shape.push_back(dim); + } + + auto new_broadcast_input_type = RankedTensorType::get( + new_broadcast_input_shape, broadcast_type.getElementType()); + auto new_broadcast_input = rewriter.create( + op->getLoc(), new_broadcast_input_type, op.getOperand()); + auto new_broadcast_dims_attr = + rewriter.getI64TensorAttr(new_broadcast_dims); + + rewriter.replaceOpWithNewOp( + reshape, reshape_type, new_broadcast_input, new_broadcast_dims_attr); + + return success(); + } +}; + class OptimizePass : public PassWrapper> { public: @@ -632,6 +758,7 @@ class OptimizePass patterns.add(FuseSliceConcat); patterns.add(ConvertReshapeDotRhsToBatchedDot); patterns.add(MergeConsecutivePad); + patterns.add(&getContext()); if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td index c789b3bde293c6..3eb051d38d8917 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.td @@ -151,3 +151,19 @@ def PrepareHloPass ]; } +def LiftCallSiteLocCallerPass : Pass<"lift-callsite-loc-caller", "ModuleOp"> { + let summary = "Lifts CallSites in pytorch generated stablehlo."; + let description = [{ + Lifts CallSites in pytorch generated stablehlo to make the Loc's consitent + after inlining. + }]; + let dependentDialects = ["func::FuncDialect"]; +} + +def BuildStableHLOCompositePass : Pass<"build-stablehlo-composite", "ModuleOp"> { + let summary = "Build stablehlo.composite from inlined stablehlo.custom_call mark_tensor ops."; + let description = [{ + Build stablehlo.composite from inlined stablehlo.custom_call mark_tensor ops. + }]; + let dependentDialects = ["func::FuncDialect", "stablehlo::StablehloDialect"]; +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc index 6b6741c2742698..9ff5a6ef58f1b1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -26,8 +25,11 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/conv_util.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad_util.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep namespace mlir { @@ -52,6 +54,10 @@ void PrepareHloPass::runOnOperation() { RewritePatternSet patterns(context); populateWithGenerated(patterns); + PopulatePrepareConvPatterns(context, patterns); + PopulatePrepareReduceWindowPatterns(context, patterns); + PopulatePrepareSlicePatterns(context, patterns); + if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) { signalPassFailure(); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td index 1cc8398afd749a..9b6f6efbfcf4f6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td @@ -29,8 +29,9 @@ include "mlir/Dialect/Arith/IR/ArithOps.td" // to be of a specific configuration: // // TFL Native Standard Conv Layouts: -// 2D : [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] -// 3D : [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] +// 2D : [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f] +// 3D : [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f] +// 2D (depthwise) : [b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f] // // The following patterns transpose the inputs and output of mhlo.convolution // ops until they are complicit with the TFL required layout. @@ -63,8 +64,24 @@ def PermuteShape : NativeCodeCall< def IsStandardConv : Constraint())">>; +def IsDepthwiseConv : Constraint())">>; + +def IsSupportedNonTrivialConv : Constraint())">>; + +def IsSupportedConv : Constraint>; + +def IsSupportedStandardOrNonTrivialConv : Constraint>; + +def IsStandardOrDepthwiseConv : Constraint>; + +// // Re-layout input (lhs) to [b, spatials..., f] -//===------------------------------------------ +//===--------------------------------------------------------------------------- def IsInputNotTFLNativeLayout : Constraint; + (IsSupportedConv $conv)], + [], + (addBenefit 1)>; -// Re-layout kernel to [o, spatials..., i] (2d) or [spatials..., i, o] (3d) -//===---------------------------------------------------------------------- +// +// Re-layout kernel +//===--------------------------------------------------------------------------- -def IsKernelNotTFLNativeLayout : Constraint>; +def KernelHasIotaSpatials : Constraint>; def KernelLayout : NativeCodeCall< "Layout($0.getKernelInputFeatureDimension()," "$0.getKernelOutputFeatureDimension()," "$0.getKernelSpatialDimensions())">; -def TFLNativeKernelLayout : NativeCodeCall< - "GetTFLNativeKernelLayout($0)">; +// +// standard conv kernel = [o, spatials..., i]. +//=----- -def KernelHasIotaSpatials : Constraint>; +def IsKernelNotTFLNativeStandardConvLayout : Constraint>; + +def TFLNativeStandardConvKernelLayout : NativeCodeCall< + "GetTFLNativeStandardConvKernelLayout($0)">; // Copy dnums with the kernel layout set to [o, spatials..., i]. -def CloneDnumsWithTFLNativeKernelLayout : NativeCodeCall< +def CloneDnumsWithTFLNativeStandardConvKernelLayout : NativeCodeCall< "CloneDnumsWithKernelLayout(" "$_builder," "$0," - "GetTFLNativeKernelLayout($0))">; + "GetTFLNativeStandardConvKernelLayout($0))">; + def ReLayoutConvKernel : Pat<(MHLO_ConvolutionOp:$conv $input, @@ -179,25 +204,88 @@ def ReLayoutConvKernel : Pat<(MHLO_ConvolutionOp:$conv $kernel, (PermForReLayout (KernelLayout $dnums), - (TFLNativeKernelLayout $dnums)) + (TFLNativeStandardConvKernelLayout $dnums)) ), $strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, - (CloneDnumsWithTFLNativeKernelLayout $dnums), + (CloneDnumsWithTFLNativeStandardConvKernelLayout $dnums), $feature_groups, $batch_groups, $precision_config ), [(AreDnumsFullyDefined $conv), (KernelHasIotaSpatials $dnums), - (IsKernelNotTFLNativeLayout $dnums), - (IsStandardConv $conv)]>; + (IsKernelNotTFLNativeStandardConvLayout $dnums), + (IsSupportedStandardOrNonTrivialConv $conv)], + [], + (addBenefit 1)>; + +// +// depthwise conv kernel = [i, spatials..., o]. +//=----- + +def IsKernelNotTFLNativeDepthwiseLayout : Constraint>; + +def TFLNativeDepthwiseConvKernelLayout : NativeCodeCall< + "GetTFLNativeDepthwiseConvKernelLayout()">; + +def CloneDnumsWithTFLNativeDepthwiseConvKernelLayout : NativeCodeCall< + "CloneDnumsWithKernelLayout(" + "$_builder," + "$0," + "GetTFLNativeDepthwiseConvKernelLayout())">; + + +def ReLayoutConvKernelDepthwise : Pat<(MHLO_ConvolutionOp:$conv + $input, + $kernel, + $strides, + $padding, + $lhs_dilation, + $rhs_dilation, + $window_reversal, + $dnums, + $feature_groups, + $batch_groups, + $precision_config + ), + (MHLO_ConvolutionOp + $input, + (MHLO_TransposeOp + $kernel, + (PermForReLayout + (KernelLayout $dnums), + (TFLNativeDepthwiseConvKernelLayout)) + ), + $strides, + $padding, + $lhs_dilation, + $rhs_dilation, + $window_reversal, + (CloneDnumsWithTFLNativeDepthwiseConvKernelLayout $dnums), + $feature_groups, + $batch_groups, + $precision_config + ), + [(AreDnumsFullyDefined $conv), + (KernelHasIotaSpatials $dnums), + (IsKernelNotTFLNativeDepthwiseLayout $dnums), + (IsDepthwiseConv $conv)], + [], + (addBenefit 1)>; + +// // Re-layout output to [b, spatials..., f] -//===------------------------------------- +//===--------------------------------------------------------------------------- def IsOutputNotTFLNativeLayout : Constraint; + (IsSupportedConv $conv)]>; + + +// Pull out non-trivial padding into separate explicit pad_op. +// +// This has the benifit of allowing for a single point of control +// for turning negative padding into slices. TFL convs can fuse +// "SAME" padding back in post-legalization. Note when lhs dilations +// are non-trivial, the mhlo.convolution has the semantics of a deconvolution. +// In this case padding is interpreted differently and so we leave it in the op. +//===--------------------------------------------------------------------------- + +// Given DenseElements (i64), check if they are all equal to "val". +class AreI64ElementsAll : Constraint()," + "[](auto v) { return v == "# val #"; })">>; + +class AreI64ElementsNotAll : + Constraint.predicate>>; + +// Gets a tuple of DenseElements (i64) given result from mhlo.convolution. +def GetExplicitPaddingArgs : NativeCodeCall< + "GetExplicitPaddingArgs($_builder," + "$0.getDefiningOp())">; + +// Gets element type from Value. +def GetElementType : NativeCodeCall< + "$0.getType().cast().getElementType()">; + +// Given element type, get a DenseElements with scalar shape and 0 value. +def GetZeroScalarAttrFromType : NativeCodeCall< + "$_builder.getZeroAttr(" + "RankedTensorType::get({}, $0))">; + +// Given padding attr, get new padding attr for trivial (no) padding. +def GetTrivialPaddingAttr : NativeCodeCall< + "$_builder.getZeroAttr($0.getType())">; + +// Given mhlo.convolution result, build an explicit mhlo.pad op +// which is semantically equivalant. +def ExplicitlyPadInput : NativeCodeCall< + "CreatePadOpFromConvPadding($_builder," + "$0.getDefiningOp())">; + +def UnfuseConvWithExplicitPadding : Pat<(MHLO_ConvolutionOp:$conv + $input, + $kernel, + $strides, + $padding, + $lhs_dilation, + $rhs_dilation, + $window_reversal, + $dnums, + $feature_groups, + $batch_groups, + $precision_config + ), + (MHLO_ConvolutionOp + (ExplicitlyPadInput $conv), + $kernel, + $strides, + (GetTrivialPaddingAttr $padding), + $lhs_dilation, + $rhs_dilation, + $window_reversal, + $dnums, + $feature_groups, + $batch_groups, + $precision_config + ), + [(AreDnumsFullyDefined $conv), + (KernelHasIotaSpatials $dnums), + (IsStandardOrDepthwiseConv $conv), + (AreI64ElementsNotAll<0> $padding)]>; + //===------------------------------------------------------------------------=== diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc index 8e22b343d7f3d0..e8a2bc870e960d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc @@ -148,6 +148,7 @@ class TflToStablehloPass case flexbuffers::FBT_VECTOR_INT: { const auto& vector = value.AsTypedVector(); std::vector vec; + vec.reserve(vector.size()); for (size_t i = 0; i < vector.size(); i++) { vec.push_back(vector[i].AsInt64()); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index 05b1982a85e4b9..6cd284a73dd576 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -14,12 +14,16 @@ limitations under the License. ==============================================================================*/ // The kept headers are provided for the included file `passes.h.inc`. +#include #include #include #include +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -35,8 +39,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/iota.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/pad.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce_window.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/slice.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/sort.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" // IWYU pragma: keep @@ -47,9 +55,137 @@ namespace mlir { namespace odml { namespace { +// Returns the shape of the given value in a Constant Op. +arith::ConstantOp ShapeToConst(PatternRewriter& rewriter, Value value) { + ArrayRef shape = mlir::cast(value.getType()).getShape(); + auto attr_type = RankedTensorType::get({static_cast(shape.size())}, + rewriter.getIntegerType(64)); + auto attr = DenseElementsAttr::get(attr_type, shape); + return rewriter.create(value.getLoc(), attr_type, attr); +} + +bool IsSign(APInt a, APInt sign) { + if (a.isZero()) return a == sign; + if (a.isNegative()) return sign == -1; + return sign == 1; +} + +bool IsSign(APFloat a, APFloat sign) { + if (a.isNaN() || a.isZero()) return a == sign; + if (a.isNegative()) return sign.isExactlyValue(-1.0); + return sign.isExactlyValue(1.0); +} + +bool IsDenseSplatIntAttr(ElementsAttr float_or_int) { + return mlir::isa(float_or_int) && + mlir::isa(float_or_int); +} + +bool IsDenseSplatFloatAttr(ElementsAttr float_or_int) { + return mlir::isa(float_or_int) && + mlir::isa(float_or_int); +} + +bool ValueEquals(ElementsAttr float_or_int, double rhs) { + if (IsDenseSplatFloatAttr(float_or_int)) { + return mlir::cast(float_or_int) + .getSplatValue() + .isExactlyValue(rhs); + } else if (IsDenseSplatIntAttr(float_or_int)) { + return mlir::cast(float_or_int).getSplatValue() == + static_cast(rhs); + } + return false; +} + +// Returns whether the splat constant is the sign of the int or float Tensor. +bool TensorIsSign(PatternRewriter& rewriter, ElementsAttr float_or_int, + ElementsAttr sgn_cst) { + auto sgn_splat = llvm::dyn_cast(sgn_cst); + if (!sgn_splat) return false; + + auto splat = dyn_cast(float_or_int); + if (auto float_spl = llvm::dyn_cast_if_present(splat), + sgn_cst_spl = llvm::dyn_cast_if_present(sgn_splat); + float_spl && sgn_cst_spl) { + return IsSign(float_spl.getValue(), sgn_cst_spl.getValue()); + } + if (auto int_spl = llvm::dyn_cast_if_present(splat), + sgn_cst_spl = llvm::dyn_cast_if_present(sgn_splat); + int_spl && sgn_cst_spl) { + return IsSign(int_spl.getValue(), sgn_cst_spl.getValue()); + } + if (mlir::isa(float_or_int)) { + auto sgn_splat_value = sgn_splat.getSplatValue(); + return llvm::all_of(float_or_int.getValues(), [&](APFloat value) { + return IsSign(value, sgn_splat_value); + }); + } + if (mlir::isa(float_or_int)) { + auto sgn_splat_value = sgn_splat.getSplatValue(); + return llvm::all_of(float_or_int.getValues(), [&](APInt value) { + return IsSign(value, sgn_splat_value); + }); + } + return false; +} + +bool SameTypeOrDefaultCompare(mhlo::ComparisonTypeAttr comparison_type_attr, + ElementsAttr cst) { + if (!comparison_type_attr) return true; + auto comparison_type_attr_value = comparison_type_attr.getValue(); + if (comparison_type_attr_value == mhlo::ComparisonType::FLOAT && + IsDenseSplatFloatAttr(cst)) { + return true; + } + if ((comparison_type_attr_value == mhlo::ComparisonType::SIGNED || + comparison_type_attr_value == mhlo::ComparisonType::UNSIGNED) && + IsDenseSplatIntAttr(cst)) { + return true; + } + return false; +} + +bool ValueIsReciprocal(ElementsAttr float_or_int, ElementsAttr rhs) { + if (IsDenseSplatFloatAttr(float_or_int) && + IsDenseSplatFloatAttr(float_or_int)) { + return (mlir::cast(float_or_int) + .getSplatValue() * + mlir::cast(rhs).getSplatValue()) + .isExactlyValue(1.0); + } else if (IsDenseSplatIntAttr(float_or_int) && + IsDenseSplatIntAttr(float_or_int)) { + return (mlir::cast(float_or_int).getSplatValue() * + mlir::cast(rhs).getSplatValue()) == 1; + } + return false; +} + +bool ValueGreaterThanZero(ElementsAttr float_or_int) { + if (IsDenseSplatIntAttr(float_or_int)) { + auto value = + mlir::cast(float_or_int).getSplatValue(); + return !value.isNegative() && !value.isZero(); + } else if (IsDenseSplatFloatAttr(float_or_int)) { + auto value = + mlir::cast(float_or_int).getSplatValue(); + return !value.isNaN() && !value.isNegative() && !value.isZero(); + } + return false; +} + #define GEN_PASS_DEF_LEGALIZEHLOTOTFLITEPASS #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" +bool SupportedComparisonType(mhlo::ComparisonTypeAttr comp_type) { + if (!comp_type) return true; + auto c_ty = comp_type.getValue(); + return c_ty == mhlo::ComparisonType::FLOAT || + c_ty == mhlo::ComparisonType::SIGNED || + c_ty == mhlo::ComparisonType::UNSIGNED || + c_ty == mhlo::ComparisonType::NOTYPE; +} + class LegalizeHloToTfLitePass : public impl::LegalizeHloToTfLitePassBase { public: @@ -62,10 +198,88 @@ std::optional IsCbrtLegal(mhlo::CbrtOp op) { return !op.getType().getElementType().isF32(); } +bool IsNotOpLegal(mhlo::NotOp op) { + return op.getType().getElementType().isInteger(64); +} + +// Mark possible target ops from rounding patterns as having "unknown" +// legality. This is required to schedule patterns on these ops even +// though MhloDialect is explicitly marked legal (which cannot be changed +// easily). +void AddRoundingOpsAsUnknown(ConversionTarget& target) { + target.addDynamicallyLegalOp< + // go/keep-sorted start + // clang-format off + mhlo::AddOp, + mhlo::BroadcastInDimOp, + mhlo::ConstantOp, + mhlo::DivOp, + mhlo::FloorOp, + mhlo::MulOp, + mhlo::RemOp, + mhlo::RoundOp, + mhlo::SelectOp, + mhlo::SignOp, + mhlo::SubtractOp, + mhlo::TupleOp + // clang-format on + // go/keep-sorted end + >([](Operation* op) { return std::nullopt; }); +} +bool IsCompareLegal(mhlo::CompareOp op) { + return !SupportedComparisonType(op.getCompareTypeAttr()); +} + +void SetUnaryOpLegal(ConversionTarget& target) { + auto is_legal = [](Operation* op) { + return !llvm::cast(op->getOperand(0).getType()) + .getElementType() + .isIntOrFloat(); + }; + target.addDynamicallyLegalOp< + // go/keep-sorted start + // clang-format off + mhlo::AbsOp, + mhlo::BitcastConvertOp, + mhlo::CeilOp, + mhlo::ConvertOp, + mhlo::CosineOp, + mhlo::ExpOp, + mhlo::Expm1Op, + mhlo::FloorOp, + mhlo::ImagOp, + mhlo::IsFiniteOp, + mhlo::Log1pOp, + mhlo::LogOp, + mhlo::LogisticOp, + mhlo::NegOp, + mhlo::RealOp, + mhlo::RsqrtOp, + mhlo::SignOp, + mhlo::SineOp, + mhlo::SqrtOp, + mhlo::TanhOp + // clang-format on + // go/keep-sorted end + >(is_legal); +} + +// mhlo "bitwise ops" can be both bitwise (floats/ints) or logical (bools). +// TFL ops are only one of logical or bitwise. +void SetBinaryBitwiseLegal(ConversionTarget& target) { + auto is_logical = [](Operation* op) { + return llvm::cast(op->getResultTypes()[0]) + .getElementType() + .isInteger(1); + }; + auto is_bitwise = [&](Operation* op) { return !is_logical(op); }; + target.addDynamicallyLegalOp(is_bitwise); + target.addDynamicallyLegalOp(is_logical); +} + #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_tflite_legalize_hlo.inc" void LegalizeHloToTfLitePass::runOnOperation() { MLIRContext* context = &getContext(); - RewritePatternSet patterns(context); patterns.add(context); populateWithGenerated(patterns); @@ -73,14 +287,44 @@ void LegalizeHloToTfLitePass::runOnOperation() { ConversionTarget target(*context); target.addLegalDialect(); target.addLegalOp(); + target.addDynamicallyLegalOp(IsCustomCallLegal); target.addDynamicallyLegalOp(IsCbrtLegal); - target.addIllegalOp(); + target.addDynamicallyLegalOp(IsNotOpLegal); + target.addDynamicallyLegalOp(IsCompareLegal); + + target.addIllegalOp< + // go/keep-sorted start + // clang-format off + mhlo::ClampOp, + mhlo::DotGeneralOp, + mhlo::DotOp, + mhlo::DynamicReshapeOp, + mhlo::MaxOp, + mhlo::MinOp, + mhlo::MulOp, + mhlo::PowOp, + mhlo::RemOp, + mhlo::ReshapeOp, + mhlo::ShiftRightArithmeticOp, + mhlo::ShiftRightLogicalOp, + mhlo::TransposeOp + // clang-format on + // go/keep-sorted end + >(); + + AddRoundingOpsAsUnknown(target); + SetUnaryOpLegal(target); + SetBinaryBitwiseLegal(target); PopulatePadPatterns(context, patterns, target); PopulateReducePatterns(context, patterns, target); + PopulateLegalizeReduceWindowPatterns(context, patterns, target); PopulateGatherPatterns(context, patterns, target); - PopulateConvPatterns(context, patterns, target); + PopulateLegalizeConvPatterns(context, patterns, target); + PopulateLegalizeSlicePatterns(context, patterns, target); + PopulateSortPatterns(context, patterns, target); + PopulateIotaPatterns(context, patterns, target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td index 671115d5c318f0..55e76560da365f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td @@ -13,13 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -include "mlir/IR/OpBase.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" -include "mhlo/IR/hlo_ops.td" include "mlir/IR/CommonAttrConstraints.td" -include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" +include "mlir/IR/CommonAttrConstraints.td" +include "mlir/IR/CommonTypeConstraints.td" +include "mlir/IR/OpBase.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td" -include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" +include "mhlo/IR/hlo_ops.td" + + +def ShapeToConst : NativeCodeCall<"ShapeToConst($_builder, $0)">; def CreateTFLCastToInt32Op : NativeCodeCall< "CreateCastToInt32($0, $_loc, $_builder)">; @@ -28,6 +33,143 @@ def LegalizeTranspose : Pat<(MHLO_TransposeOp $arg, $perm), (TFL_TransposeOp $arg, (CreateTFLCastToInt32Op (TFL_ConstOp $perm)))>; +def LegalizeReshape : Pat<(MHLO_ReshapeOp:$output $input), + (TFL_ReshapeOp $input, + (CreateTFLCastToInt32Op (ShapeToConst $output)))>; + +def LegalizeDynamicReshape : Pat<(MHLO_DynamicReshapeOp $input, $shape), + (TFL_ReshapeOp $input, (CreateTFLCastToInt32Op $shape))>; + +//===----------------------------------------------------------------------===// +// logical and bitwise ops +//===----------------------------------------------------------------------===// + +class GetRankedScalarAttr : + NativeCodeCall<"DenseElementsAttr::get<" # prefix # "int" # width # "_t>(" + "RankedTensorType::get({}, $_builder.getIntegerType(" + # width # signed # "))," # value # ")">; + +def : Pat<(MHLO_NotOp I1Tensor:$input), (TFL_LogicalNotOp $input)>; + +// TFL does not support bitwise negation. not(x) is equivalant to xor(x, y) if +// y has a 1 in every bit position (xor(1, 1) = 0 and xor(0, 1) = 1). + +// Signed: The 2s complement of -1 has a 1 in every bit position. +def : Pat<(MHLO_NotOp I8Tensor:$input), + (TFL_BitwiseXorOp $input, + (Arith_ConstantOp + (GetRankedScalarAttr<"", 8, "", "-1">)))>; + +def : Pat<(MHLO_NotOp I16Tensor:$input), + (TFL_BitwiseXorOp $input, + (Arith_ConstantOp + (GetRankedScalarAttr<"", 16, "", "-1">)))>; + +def : Pat<(MHLO_NotOp I32Tensor:$input), + (TFL_BitwiseXorOp $input, + (Arith_ConstantOp + (GetRankedScalarAttr<"", 32, "", "-1">)))>; + + +// Unsigned: 0xFFF... has a 1 in every bit position. +def : Pat<(MHLO_NotOp TensorOf<[UI8]>:$input), + (TFL_BitwiseXorOp $input, + (Arith_ConstantOp + (GetRankedScalarAttr<"u", 8, ", false", "0xFFU">)))>; + +def : Pat<(MHLO_NotOp TensorOf<[UI16]>:$input), + (TFL_BitwiseXorOp $input, + (Arith_ConstantOp + (GetRankedScalarAttr<"u", 16, ", false", "0xFFFFU">)))>; + +def : Pat<(MHLO_NotOp TensorOf<[UI32]>:$input), + (TFL_BitwiseXorOp $input, + (Arith_ConstantOp + (GetRankedScalarAttr<"u", 32, ", false", "0xFFFFFFFFUL">)))>; + +foreach pair = [ + [MHLO_AndOp, TFL_LogicalAndOp], + [MHLO_OrOp, TFL_LogicalOrOp], +] in { + def : Pat< + (pair[0] TFL_BoolTensor:$l, TFL_BoolTensor:$r), + (pair[1] $l, $r)>; +} + +def LegalizeXor : Pat< + (MHLO_XorOp + TFL_IntTensor:$l, + TFL_IntTensor:$r), + (TFL_BitwiseXorOp $l, $r)>; + +//===----------------------------------------------------------------------===// +// binary element-wise ops +//===----------------------------------------------------------------------===// + +def : Pat< + (MHLO_ShiftRightArithmeticOp $l, $r), + (TFL_RightShiftOp $l, $r)>; + +def : Pat< + (MHLO_ShiftRightLogicalOp $l, $r), + (TFL_RightShiftOp $l, $r)>; + +def : Pat< + (MHLO_RemOp $l, $r), + (TFL_FloorModOp $l, $r)>; + +// Binary ops with no attrs. +foreach pair = [ + [MHLO_MaxOp, TFL_MaximumOp], + [MHLO_MinOp, TFL_MinimumOp], + [MHLO_PowOp, TFL_PowOp], +] in { + def : Pat< + (pair[0] $l, $r), + (pair[1] $l, $r)>; +} + +// Binary ops with fused activiation attr. +foreach pair = [ + [MHLO_MulOp, TFL_MulOp], +] in { + def : Pat< + (pair[0] $l, $r), + (pair[1] $l, $r, TFL_AF_None)>; +} + + + +//===----------------------------------------------------------------------===// +// comparison ops +//===----------------------------------------------------------------------===// + +// Check implicit bool cast of `$_self` to ensure Attribute is non-null before +// casting. +def HasSupportedComparisonType : AttrConstraint< + CPred<"!$_self || SupportedComparisonType($_self.cast())">>; + +class MHLO_ComparisonDirectionValue : + ConstantAttr; + +foreach p = [ + [TFL_EqualOp, MHLO_ComparisonDirectionValue<"EQ">], + [TFL_NotEqualOp, MHLO_ComparisonDirectionValue<"NE">], + [TFL_GreaterEqualOp, MHLO_ComparisonDirectionValue<"GE">], + [TFL_LessEqualOp, MHLO_ComparisonDirectionValue<"LE">], + [TFL_GreaterOp, MHLO_ComparisonDirectionValue<"GT">], + [TFL_LessOp, MHLO_ComparisonDirectionValue<"LT">]] +in { + def : Pat< + (MHLO_CompareOp $l, $r, p[1], HasSupportedComparisonType), + (p[0] $l, $r)>; +} + +//===----------------------------------------------------------------------===// +// unary element-wise op +//===----------------------------------------------------------------------===// + def LowerCbrt : Pat<(MHLO_CbrtOp $opr), (TFL_PowOp $opr, (TFL_DivOp @@ -35,3 +177,436 @@ def LowerCbrt : Pat<(MHLO_CbrtOp $opr), (Arith_ConstantOp ConstantAttr, "3.0f">), TFL_AF_None)), [(F32Tensor $opr)]>; + + +foreach pair = [ + [MHLO_AbsOp, TFL_AbsOp], + [MHLO_BitcastConvertOp, TFL_BitcastOp], + [MHLO_CeilOp, TFL_CeilOp], + [MHLO_CosineOp, TFL_CosOp], + [MHLO_ExpOp, TFL_ExpOp], + [MHLO_FloorOp, TFL_FloorOp], + [MHLO_ImagOp, TFL_ImagOp], + [MHLO_LogOp, TFL_LogOp], + [MHLO_LogisticOp, TFL_LogisticOp], + [MHLO_NegOp, TFL_NegOp], + [MHLO_RealOp, TFL_RealOp], + [MHLO_RsqrtOp, TFL_RsqrtOp], + [MHLO_SineOp, TFL_SinOp], + [MHLO_SignOp, TFL_SignOp], + [MHLO_SqrtOp, TFL_SqrtOp], + [MHLO_TanhOp, TFL_TanhOp] +] in { + def : Pat< + (pair[0] $input), + (pair[1] $input)>; +} + +def : Pat< + (MHLO_ConvertOp $input), + (TFL_CastOp $input)>; + +def : Pat< + (MHLO_Expm1Op F32Tensor:$x), + (TFL_SubOp + (TFL_ExpOp $x), + (Arith_ConstantOp + ConstantAttr, "1.0f">), + TFL_AF_None)>; + +def : Pat< + (MHLO_IsFiniteOp F32Tensor:$x), + (TFL_EqualOp + (TFL_SubOp $x, $x, TFL_AF_None), + (Arith_ConstantOp + ConstantAttr, "0.0f">))>; + +def : Pat< + (MHLO_Log1pOp F32Tensor:$x), + (TFL_LogOp + (TFL_AddOp + $x, + (Arith_ConstantOp + ConstantAttr, "1.0f">), + TFL_AF_None))>; + +//===----------------------------------------------------------------------===// +// rounding +//===----------------------------------------------------------------------===// + +class ValueEquals : + Constraint>; + +def SameValue : + Constraint>; + +def FloatOrDefaultCompare : + Constraint>; + +def SameTypeOrDefaultCompare : + Constraint>; + +def ValueIsReciprocal : + Constraint>; + +def TensorIsSign : + Constraint>; + +def ValueGreaterThanZero : + Constraint>; + + +// Converts a dag of HLOs representing banker rounding (round x.5 to nearest +// even) to tfl.round. This only supports float types because mhlo.floor only +// supports float types. tf.round with integer input type will become an +// identity op, so we will never face an mhlo.floor with an integer input type. +// The pattern matched executes the following computation: +// frac = x - floor(x) +// to_even = (floor(x) - 2 * floor(0.5 * x)) == 1 +// if frac > 0.5 || (frac == 0.5 && to_even) +// return floor(x) + 1 +// else +// return floor(x) +def Round : Pat<(MHLO_SelectOp + (MHLO_OrOp + (MHLO_CompareOp (MHLO_SubtractOp:$frac + $input, + (MHLO_FloorOp:$floor $input)), + (MHLO_ConstantOp $half), + MHLO_ComparisonDirectionValue<"GT">, + $compare_type0), + (MHLO_AndOp + (MHLO_CompareOp + $frac1, + (MHLO_ConstantOp $half1), + MHLO_ComparisonDirectionValue<"EQ">, + $compare_type1), + (MHLO_CompareOp + (MHLO_SubtractOp + $floor1, + (MHLO_MulOp + (MHLO_FloorOp (MHLO_MulOp $input, (MHLO_ConstantOp $half2))), + (MHLO_ConstantOp $two))), + (MHLO_ConstantOp $one1), + MHLO_ComparisonDirectionValue<"EQ">, + $compare_type2))), + (MHLO_AddOp $floor2, (MHLO_ConstantOp $one)), + $floor3), + (TFL_RoundOp $input), + [(ValueEquals<"1.0"> $one), + (ValueEquals<"1.0"> $one1), + (ValueEquals<"2.0"> $two), + (ValueEquals<"0.5"> $half), + (ValueEquals<"0.5"> $half1), + (ValueEquals<"0.5"> $half2), + (SameValue $floor, $floor1), + (SameValue $floor, $floor2), + (SameValue $floor, $floor3), + (SameValue $frac, $frac1), + (FloatOrDefaultCompare $compare_type0), + (FloatOrDefaultCompare $compare_type1), + (FloatOrDefaultCompare $compare_type2)]>; + +// Converts a dag of HLOs representing floor_mod to tfl.floor_mod. +// The pattern matched executes the following computation: +// +// rem = remainder(arg0, arg1) +// for i in 0 to len(arg1): +// if ((rem[i] < 0) != (arg0[i] < 0) && arg0[i] != 0) +// rem[i] += arg1[i] +// return rem +def : Pat<(MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_CompareOp:$rltz + (MHLO_RemOp:$rem $arg, $arg1), + (MHLO_ConstantOp $cst), + MHLO_ComparisonDirectionValue<"LT">, + $compare_type), + (MHLO_CompareOp:$arg1ltz $arg1, (MHLO_ConstantOp $cst1), MHLO_ComparisonDirectionValue<"LT">, $compare_type1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type2), + (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)), + (MHLO_AddOp $rem2, $arg1), + $rem3), + (TFL_FloorModOp $arg, $arg1), + [(ValueEquals<"0.0"> $cst), + (ValueEquals<"0.0"> $cst1), + (ValueEquals<"0.0"> $cst2), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (SameValue $rem, $rem3), + (SameTypeOrDefaultCompare $compare_type, $cst), + (SameTypeOrDefaultCompare $compare_type1, $cst1)]>; + +// Converts a dag of HLOs representing floor_mod with a constant to +// tfl.floor_mod. The pattern matched executes the following computation: +// +// cst = value that is > 0 +// rem = remainder(arg0, cst) +// for i in 0 to len(arg1): +// if (rem[i] < 0 && rem[i] != 0) +// rem[i] += cst +// return rem +def : Pat<(MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp:$rltz + (MHLO_RemOp:$rem $arg, (MHLO_ConstantOp $cst)), + (MHLO_ConstantOp $cst1), + MHLO_ComparisonDirectionValue<"LT">, + $compare_type), + (MHLO_CompareOp:$rnz $rem1, (MHLO_ConstantOp $cst2), MHLO_ComparisonDirectionValue<"NE">, $compare_type3)), + (MHLO_AddOp $rem2, (MHLO_ConstantOp $cst3)), + $rem3), + (TFL_FloorModOp $arg, (Arith_ConstantOp $cst3)), + [(ValueGreaterThanZero $cst), + (ValueEquals<"0.0"> $cst1), + (ValueEquals<"0.0"> $cst2), + (SameValue $cst, $cst3), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (SameValue $rem, $rem3), + (SameTypeOrDefaultCompare $compare_type, $cst1), + (SameTypeOrDefaultCompare $compare_type3, $cst2)]>; + +// Converts a dag of HLOs representing floor_div to tfl.floor_div. +// The pattern matched executes the following computation: +// +// rem = remainder(arg0, arg1) +// for i in 0 to len(arg1): +// rem[i] = arg0[i] - rem[i] / arg1[i] +// if (rem[i] != 0 && sign(arg1[i]) != sign(rem[i])) +// rem[i] -= 1.0 +// return round_nearest_afz(rem) +// As a dag this looks like the following: +// round +// | +// -------- select +// | | \ +// && + div +// / | / \ +// != != div -1 +// / | / | / | +// rem 0.0 sn sn1 - $1 +// / | | | / | +// $0 $1 $1 rem $0 rem +// Note that named operators like 'sn' and 'sn1' are different values produced by +// the same function in this case the sign function. Named values like 'div' +// refer to the same value produced by the same function, in this case division. +// Mathematical symbols do not indicate a re-use of the value. +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, $arg1), + (MHLO_ConstantOp $cst), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type), + (MHLO_CompareOp + (MHLO_SignOp $arg1), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type1)), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), + $arg1b), + (MHLO_ConstantOp $cst_neg1)), + $div1)), + (TFL_FloorDivOp $arg0, $arg1), + [(ValueEquals<"0.0"> $cst), + (ValueEquals<"-1.0"> $cst_neg1), + (SameValue $div, $div1), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (FloatOrDefaultCompare $compare_type, $cst), + (FloatOrDefaultCompare $compare_type1, $cst)]>; + +// Converts a dag of HLOs representing floor_div with a splat constant to +// tfl.floor_div. The pattern matched executes the following computation: +// This particular pattern matches multiplication with the reciprocal of the +// constant instead of dividing by the constant. +// rem = remainder(arg0, cst) +// for i in 0 to len(arg0): +// rem[i] = (arg0[i] - rem[i]) * 1 / cst +// if (rem[i] != 0 && sign(cst) != sign(rem[i])) +// rem[i] += -1.0 +// return round_nearest_afz(rem) +// As a dag this looks like the following: +// round +// | +// -------- select +// | | \ +// && + mul +// / | / \ +// != != mul -1 +// / | / | / | +// rem 0.0 cs1 sn1 - cs2 +// / | | / | +// $0 cst rem $0 rem +// cs1 == sign(cst) +// cs2 = 1 / cst i.e. the reciprocal +// Note that named operators like 'sn' and 'sn1' are different values produced by +// the same function in this case the sign function. Named values like 'div' +// refer to the same value produced by the same function, in this case division. +// Mathematical symbols do not indicate a re-use of the value. +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)), + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type), + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type1)), + (MHLO_AddOp + (MHLO_MulOp:$mul + (MHLO_SubtractOp $arg0, $rem2), + (MHLO_ConstantOp $cst_recip)), + (MHLO_ConstantOp $cst_neg1)), + $mul1)), + (TFL_FloorDivOp $arg0, $cst), + [(ValueEquals<"0.0"> $cst_zero), + (ValueEquals<"-1.0"> $cst_neg1), + (TensorIsSign $cstv, $cst_sgn), + (ValueIsReciprocal $cstv, $cst_recip), + (SameValue $mul, $mul1), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (FloatOrDefaultCompare $compare_type, $cst_zero), + (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>; + +// Converts a dag of HLOs representing floor_div with a splat constant to +// tfl.floor_div. The pattern matched executes the following computation: +// This particular pattern matches division with the constant. +// . +// rem = remainder(arg0, cst) +// for i in 0 to len(arg0): +// rem[i] = (arg0[i] - rem[i]) / cst +// if (rem[i] != 0 && sign(cst) != sign(rem[i])) +// rem[i] -= 1.0 +// return round_nearest_afz(rem) +// As a dag this looks like the following: +// round +// | +// -------- select +// | | \ +// && + div +// / | / \ +// != != div -1 +// / | / | / | +// rem 0.0 cs1 sn1 - cs2 +// / | | / | +// $0 cst rem $0 rem +// cs1 == sign(cst) +// cs2 = 1 / cst i.e. the reciprocal +// Note that named operators like 'sn' and 'sn1' are different values produced by +// the same function in this case the sign function. Named values like 'div' +// refer to the same value produced by the same function, in this case division. +// Mathematical symbols do not indicate a re-use of the value. +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, (MHLO_ConstantOp:$cst $cstv)), + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type), + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type1)), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), + (MHLO_ConstantOp $cstv1)), + (MHLO_ConstantOp $cst_neg1)), + $div1)), + (TFL_FloorDivOp $arg0, $cst), + [(ValueEquals<"0.0"> $cst_zero), + (ValueEquals<"-1.0"> $cst_neg1), + (TensorIsSign $cstv, $cst_sgn), + (SameValue $div, $div1), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (SameValue $cstv1, $cstv), + (FloatOrDefaultCompare $compare_type, $cst_zero), + (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>; + +// Converts a dag of HLOs representing floor_div with a broadcasted vector +// constant to tfl.floor_div. The pattern matched executes the following +// computation: +// scs = sign(cst) +// bcst = broadcast(cst) +// rem = remainder(arg0, bcst) +// for i in 0 to len(arg0): +// rem[i] = arg0[i] - rem[i] * / bcst +// if (rem[i] != 0 && scs != sign(rem[i])) +// rem[i] -= 1.0 +// return round_nearest_afz(rem) +// Where scs is a splat constant folded sign on the unbroadcasted tensor. +// +// As a dag this looks like the following: +// round +// | +// -------- select +// | | \ +// && + div +// / | / \ +// != != div -1 +// / | / | / | +// rem 0.0 scs sn1 - bcst +// / | | / | +// $0 bcst rem $0 rem +// | +// cst +// scs == sign(cst) == sign(bcst) +// Note that named operators like 'sn' and 'sn1' are different values produced by +// the same function in this case the sign function. Named values like 'div' +// refer to the same value produced by the same function, in this case division. +// Mathematical symbols do not indicate a re-use of the value. +def : Pat<(MHLO_RoundOp + (MHLO_SelectOp + (MHLO_AndOp + (MHLO_CompareOp + (MHLO_RemOp:$rem $arg0, + (MHLO_BroadcastInDimOp:$bcst + (MHLO_ConstantOp $cstv), + $broadcast_dimension)), + (MHLO_ConstantOp $cst_zero), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type), + (MHLO_CompareOp + (MHLO_ConstantOp $cst_sgn), + (MHLO_SignOp $rem1), + MHLO_ComparisonDirectionValue<"NE">, + $compare_type1)), + (MHLO_AddOp + (MHLO_DivOp:$div + (MHLO_SubtractOp $arg0, $rem2), + $bcst1), + (MHLO_ConstantOp $cst_neg1)), + $div1)), + (TFL_FloorDivOp $arg0, $bcst), + [(ValueEquals<"0.0"> $cst_zero), + (ValueEquals<"-1.0"> $cst_neg1), + (TensorIsSign $cstv, $cst_sgn), + (SameValue $bcst, $bcst1), + (SameValue $div, $div1), + (SameValue $rem, $rem1), + (SameValue $rem, $rem2), + (FloatOrDefaultCompare $compare_type, $cst_zero), + (FloatOrDefaultCompare $compare_type1, $cst_sgn)]>; + + +//===----------------------------------------------------------------------===// +// ternary op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(MHLO_ClampOp $min, $arg, $max), + (TFL_MaximumOp (TFL_MinimumOp $arg, $max), $min)>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc new file mode 100644 index 00000000000000..e717114610b527 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/build_stablehlo_composite_pass.cc @@ -0,0 +1,555 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "json/json.h" +#include "json/reader.h" +#include "json/value.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Analysis/TopologicalSortUtils.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { + +#define GEN_PASS_DEF_BUILDSTABLEHLOCOMPOSITEPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +namespace { + +// Checks if this operation is a MarkTensor operation used to mark the +// boundaries of a composite. +static bool IsMarkTensorOp(mlir::Operation* op) { + if (op == nullptr) { + return false; + } + if (op->getNumOperands() != 1 || op->getNumResults() != 1) { + return false; + } + if (!llvm::isa(op)) { + return false; + } + auto target_name = + mlir::dyn_cast(op->getAttr("call_target_name")); + if (target_name == nullptr || target_name.str() != "mark_tensor") { + return false; + } + return true; +} + +struct BoundaryMetadata { + std::string name; + std::string id; + int64_t pos; + bool is_input; + std::unordered_map attrs; + + auto boundary_key() const { return absl::StrCat(name, "__@@__", id); } + + auto uid() const { return std::forward_as_tuple(name, id, pos, is_input); } + + bool operator==(const BoundaryMetadata& other) const { + return uid() == other.uid(); + } + bool operator<(const BoundaryMetadata& other) const { + return uid() < other.uid(); + } + + static std::unique_ptr Parse(llvm::StringRef str_ref) { + Json::Value root; + Json::Reader reader; + if (!reader.parse(str_ref.str(), root)) { + return nullptr; + } + return Build(root); + } + + private: + template + static bool CopyJsonValue(const Json::Value& json, llvm::StringRef key, + Json::ValueType expected_type, T* to) { + if (!json.isMember(key.str()) || json[key.str()].type() != expected_type) { + return false; + } + + *to = json[key.str()].as(); + return true; + } + + static std::unique_ptr Build(const Json::Value& json) { + BoundaryMetadata metadata; + + bool is_valid_metadata_json = + CopyJsonValue(json, "name", Json::stringValue, &metadata.name) && + CopyJsonValue(json, "id", Json::stringValue, &metadata.id) && + CopyJsonValue(json, "pos", Json::intValue, &metadata.pos) && + CopyJsonValue(json, "is_input", Json::booleanValue, &metadata.is_input); + + if (!is_valid_metadata_json) { + return nullptr; + } + + Json::Value attrs_value = json["attr"]; + if (attrs_value.type() == Json::objectValue) { + for (const auto& key_value : attrs_value.getMemberNames()) { + metadata.attrs.insert({key_value, attrs_value[key_value]}); + } + } + return std::make_unique(std::move(metadata)); + } +}; + +class BuildStableHLOCompositePass + : public impl::BuildStableHLOCompositePassBase< + BuildStableHLOCompositePass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(BuildStableHLOCompositePass); + + void runOnOperation() override { + mlir::ModuleOp module_op = getOperation(); + llvm::SmallVector func_ops( + module_op.getOps()); + for (mlir::func::FuncOp& func_op : func_ops) { + llvm::DenseMap op_order_map = + BuildOpOrderMap(func_op); + std::unordered_map> + boundary_output_ops_map = BuildBoundaryOutputOpsMap(func_op); + + for (const auto& [unused, ops] : boundary_output_ops_map) { + if (mlir::failed(BuildStableHLOComposite(ops, op_order_map))) { + func_op.emitError() << "failed to build composite."; + return signalPassFailure(); + } + } + } + + // Remove mark_tensor custom_call ops. + getOperation()->walk([](mlir::stablehlo::CustomCallOp op) { + if (!IsMarkTensorOp(op.getOperation())) { + return; + } + mlir::Value original_value = op.getOperand(0); + + for (mlir::Value result : op.getResults()) { + result.replaceAllUsesWith(original_value); + } + op.erase(); + }); + } + + private: + llvm::DenseMap BuildOpOrderMap( + mlir::func::FuncOp func_op) const { + llvm::DenseMap op_order_map; + for (const auto& op : llvm::enumerate(func_op.getOps())) { + op_order_map[&op.value()] = op.index(); + } + return op_order_map; + } + + std::unordered_map> + BuildBoundaryOutputOpsMap(mlir::func::FuncOp func_op) { + std::unordered_map> + boundary_output_ops; + + for (auto op : func_op.getOps()) { + auto metadata_or = GetBoundaryMetadata(op); + if (mlir::failed(metadata_or)) { + continue; + } + + std::unique_ptr metadata = std::move(*metadata_or); + if (metadata == nullptr || metadata->is_input) { + continue; + } + + auto& output_ops = boundary_output_ops[metadata->boundary_key()]; + if (metadata->pos >= output_ops.size()) { + output_ops.resize(metadata->pos + 1, nullptr); + } + output_ops[metadata->pos] = op.getOperation(); + } + return boundary_output_ops; + } + + mlir::FailureOr> GetBoundaryMetadata( + mlir::Operation* op) { + if (!IsMarkTensorOp(op)) { + return mlir::FailureOr>(nullptr); + } + auto backend_config = + mlir::dyn_cast(op->getAttr("backend_config")); + if (backend_config == nullptr) { + return mlir::FailureOr>(nullptr); + } + std::unique_ptr metadata = + BoundaryMetadata::Parse(backend_config); + if (metadata == nullptr) { + return op->emitError() << "invalid boundary metadata JSON."; + } + return metadata; + } + + mlir::FailureOr BuildAttrFromJson( + mlir::OpBuilder& builder, mlir::Operation* op, + const Json::Value& json_value) { + switch (json_value.type()) { + case Json::intValue: + case Json::uintValue: + return builder.getI64IntegerAttr(json_value.as()); + case Json::ValueType::realValue: + return builder.getF32FloatAttr(json_value.as()); + case Json::ValueType::booleanValue: + return builder.getBoolAttr(json_value.as()); + case Json::ValueType::stringValue: + return builder.getStringAttr(json_value.as()); + case Json::ValueType::arrayValue: { + if (json_value.empty()) { + return builder.getArrayAttr({}); + } + auto get_json_type = [](const Json::Value& json_value) { + auto ty = json_value.type(); + if (ty == Json::uintValue) { + return Json::intValue; + } + return ty; + }; + + auto head_type = get_json_type(json_value[0]); + bool is_homogeneous = llvm::all_of(json_value, [&](auto& el) { + return get_json_type(el) == head_type; + }); + if (!is_homogeneous) { + return op->emitError() + << "invalid JSON to MLIR, arrays must be homogeneous"; + } + + switch (head_type) { + case Json::intValue: { + llvm::SmallVector int_values; + for (const auto& json_value : json_value) { + int_values.push_back(json_value.as()); + } + return builder.getI64TensorAttr(int_values); + } + case Json::realValue: { + llvm::SmallVector float_values; + for (const auto& json_value : json_value) { + float_values.push_back(json_value.as()); + } + return mlir::DenseFPElementsAttr::get( + mlir::RankedTensorType::get(json_value.size(), + builder.getF32Type()), + float_values); + } + case Json::booleanValue: { + llvm::SmallVector bool_values; + for (const auto& json_value : json_value) { + bool_values.push_back(json_value.as()); + } + return mlir::DenseIntElementsAttr::get( + mlir::RankedTensorType::get(json_value.size(), + builder.getI1Type()), + bool_values); + } + default: + return op->emitError() + << "invalid JSON to MLIR: invalid array type. arrays must " + "be " + "1-D homogeneous arrays of supported primitive types"; + } + } + default: + return op->emitError() + << "invalid JSON to MLIR: unsupported json value type"; + } + } + + mlir::FailureOr BuildDictionaryAttrFromJsonMap( + mlir::OpBuilder& builder, mlir::Operation* op, + const std::unordered_map& json_map) { + llvm::SmallVector named_attrs; + for (auto& [key, json] : json_map) { + mlir::FailureOr attribute_or = + BuildAttrFromJson(builder, op, json); + if (mlir::failed(attribute_or)) { + return mlir::failure(); + } + named_attrs.push_back({builder.getStringAttr(key), *attribute_or}); + } + return builder.getDictionaryAttr(named_attrs); + } + + mlir::LogicalResult BuildStableHLOComposite( + const llvm::SmallVector& output_ops, + const llvm::DenseMap& op_order_map) { + if (output_ops.empty()) { + return mlir::success(); + } + + // Get the output op with minimum order num as the representative. + mlir::Operation* first_output_op = output_ops[0]; + for (mlir::Operation* op : output_ops) { + if (op_order_map.at(op) < op_order_map.at(first_output_op)) { + first_output_op = op; + } + } + + auto metadata_or = GetBoundaryMetadata(first_output_op); + if (mlir::failed(metadata_or)) { + return mlir::failure(); + } + + std::unique_ptr metadata = std::move(*metadata_or); + if (metadata == nullptr || metadata->is_input) { + // There should always be a valid boundary output metadata associated with + // each op in output_ops. + return mlir::failure(); + } + + auto args_ops_or = + GetBoundaryArgsAndOps(output_ops, *metadata, op_order_map); + if (mlir::failed(args_ops_or)) { + return mlir::failure(); + } + + auto [args, impl_ops] = *args_ops_or; + + mlir::func::FuncOp impl_func = BuildStableHLOCompositeImplFunc( + output_ops, absl::StrCat(metadata->name, ".impl"), args, impl_ops); + mlir::FailureOr composite_op_or = + BuildStableHLOCompositeOp(first_output_op, impl_func, args, *metadata); + if (mlir::failed(composite_op_or)) { + return mlir::failure(); + } + mlir::Operation* composite_op = *composite_op_or; + + // Updates all users of this op's result(s) to use the results(s) of impl + // func call. + size_t composite_result_i = 0; + for (mlir::Operation* op : output_ops) { + for (size_t i = 0; i < op->getNumResults(); ++i) { + mlir::OpResult result = op->getResult(i); + result.replaceAllUsesWith( + composite_op->getResult(composite_result_i++)); + } + } + + if (!mlir::sortTopologically(composite_op->getBlock())) { + composite_op->emitError() + << "The graph is not acyclic after BuildStableHLOCompositePass pass."; + return mlir::failure(); + } + // The unused impl_ops will be eliminated with canonicalizer. + return mlir::success(); + } + + mlir::FailureOr, + llvm::SmallVector>> + GetBoundaryArgsAndOps( + const llvm::SmallVector boundary_output_ops, + const BoundaryMetadata& metadata, + const llvm::DenseMap& op_order_map) { + llvm::SetVector impl_ops_setvec; + llvm::SetVector> arg_pos_setvec; + llvm::SmallVector processing(boundary_output_ops.begin(), + boundary_output_ops.end()); + + // Reverse graph traversal: from boundary output op to boundary input op, + // global function arg, or stablehlo constant. + while (!processing.empty()) { + mlir::Operation* curr_op = processing.back(); + processing.pop_back(); + if (impl_ops_setvec.contains(curr_op)) { + continue; + } + + auto curr_metadata_or = GetBoundaryMetadata(curr_op); + if (mlir::failed(curr_metadata_or)) { + return mlir::failure(); + } + std::unique_ptr curr_metadata = + std::move(*curr_metadata_or); + if (curr_metadata != nullptr) { + if (curr_metadata->is_input && + curr_metadata->boundary_key() == metadata.boundary_key()) { + // Terminal condition: boundary input op. + + arg_pos_setvec.insert( + {mlir::dyn_cast(curr_op->getResult(0)), + curr_metadata->pos}); + continue; + } + } + + impl_ops_setvec.insert(curr_op); + for (mlir::Value value : curr_op->getOperands()) { + mlir::Operation* def_op = value.getDefiningOp(); + if (def_op == nullptr) { + // Terminal condition: global function arg + arg_pos_setvec.insert({value, std::numeric_limits::max()}); + } else if (llvm::isa(def_op)) { + // Terminal condition: constant + impl_ops_setvec.insert(def_op); + } else { + processing.push_back(def_op); + } + } + } + // Sorts all ops within the boundary by their line numbers in the input + // MLIR. The ops will be duplicated to the impl function following this + // order. + llvm::SmallVector impl_ops = impl_ops_setvec.takeVector(); + for (auto& op : impl_ops) { + if (!op_order_map.contains(op)) { + return op->emitError() + << "does not have a ordering number in its outer func."; + } + } + std::sort(impl_ops.begin(), impl_ops.end(), + [&op_order_map](const auto& a, const auto& b) { + return op_order_map.at(a) < op_order_map.at(b); + }); + + // Sorts boundary args by their positions. Note that the args of the + // composite and impl function may be more than the boundary inputs, because + // the MLIR is lowered from the functionalized graph and additional args may + // be Pytorch constants. In such case the position of those args would be + // undetermined, while they would always come after boundary inputs. + auto arg_pos_pairs = arg_pos_setvec.takeVector(); + std::stable_sort( + arg_pos_pairs.begin(), arg_pos_pairs.end(), + [](const auto& a, const auto& b) { return a.second < b.second; }); + llvm::SmallVector args; + args.reserve(arg_pos_pairs.size()); + for (auto& [arg, unused] : arg_pos_pairs) { + args.push_back(arg); + } + + return std::make_pair(std::move(args), std::move(impl_ops)); + } + + mlir::func::FuncOp BuildStableHLOCompositeImplFunc( + const llvm::SmallVector boundary_output_ops, + llvm::StringRef func_name, const llvm::SmallVector& args, + const llvm::SmallVector& impl_ops) { + mlir::ModuleOp module_op = getOperation(); + mlir::MLIRContext* context = &getContext(); + mlir::OpBuilder builder(context); + + // Creates composite impl function and duplicates all ops within the + // boundary in the function. + llvm::SmallVector arg_locs; + llvm::SmallVector arg_types; + for (auto& arg : args) { + arg_types.push_back(arg.getType()); + arg_locs.push_back(arg.getLoc()); + } + llvm::SmallVector result_types; + for (mlir::Operation* op : boundary_output_ops) { + result_types.append(op->getResultTypes().begin(), + op->getResultTypes().end()); + } + + mlir::func::FuncOp impl_func = builder.create( + module_op.getLoc(), func_name, + mlir::FunctionType::get(context, arg_types, result_types)); + mlir::IRMapping mapping; + builder.createBlock(&impl_func.getBody(), impl_func.begin(), arg_types, + arg_locs); + for (const auto& arg : llvm::enumerate(args)) { + mapping.map(arg.value(), impl_func.getArgument(arg.index())); + } + for (mlir::Operation* original_op : impl_ops) { + mlir::Operation* cloned_op = builder.clone(*original_op, mapping); + mapping.map(original_op, cloned_op); + } + + llvm::SmallVector results; + for (mlir::Operation* op : boundary_output_ops) { + results.append(mapping.lookup(op)->getResults().begin(), + mapping.lookup(op)->getResults().end()); + } + builder.create(impl_func.getBody().getLoc(), results); + + // Adds the new function to symbol table. + mlir::SymbolTable symbol_table(module_op); + impl_func.setPrivate(); + symbol_table.insert(impl_func); + + return impl_func; + } + + mlir::FailureOr BuildStableHLOCompositeOp( + mlir::Operation* boundary_output_op, mlir::func::FuncOp impl_func, + const llvm::SmallVector& args, + const BoundaryMetadata& metadata) { + mlir::MLIRContext* context = &getContext(); + mlir::OpBuilder builder(context); + + mlir::FailureOr attributes_or = + BuildDictionaryAttrFromJsonMap(builder, boundary_output_op, + metadata.attrs); + if (mlir::failed(attributes_or)) { + return boundary_output_op->emitError() + << "failed to transform boundary attr " + "JSON into composite attributes."; + } + + // Creates and inserts composite call op. + builder.setInsertionPointAfter(boundary_output_op); + mlir::Operation* composite_op = + builder.create( + boundary_output_op->getLoc(), + impl_func.getFunctionType().getResults(), args, metadata.name, + *attributes_or, impl_func.getSymName()); + return composite_op; + } +}; + +} // namespace +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc new file mode 100644 index 00000000000000..89e23c6edc4a24 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/torch/lift_callsite_loc_caller_pass.cc @@ -0,0 +1,54 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" // IWYU pragma: keep + +namespace mlir { +namespace odml { +#define GEN_PASS_DEF_LIFTCALLSITELOCCALLERPASS +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc" + +namespace { + +// JAX bridge generates a func.call for each op lowering +// These are inlined but loc will be messed up after the inline pass. This pass +// normalize the loc after inline pass. + +class LiftCallSiteLocCallerPass + : public impl::LiftCallSiteLocCallerPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftCallSiteLocCallerPass); + + void runOnOperation() override { + getOperation()->walk([](func::FuncOp func_op) { + for (Operation& op : func_op.getOps()) { + if (!mlir::isa(op.getLoc())) { + continue; + } + + auto loc = op.getLoc().dyn_cast(); + op.setLoc(loc.getCaller()); + } + }); + } +}; + +} // namespace +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stateful_error_reporter.h b/tensorflow/compiler/mlir/lite/stateful_error_reporter.h new file mode 100644 index 00000000000000..fbb82d3e54d121 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/stateful_error_reporter.h @@ -0,0 +1,36 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ + +// LINT.IfChange +#include + +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" + +namespace tflite_migration { + +// Similar to tflite::ErrorReporter, except that it allows callers to get the +// last error message. +class StatefulErrorReporter : public tflite::ErrorReporter { + public: + // Returns last error message. Returns empty string if no error is reported. + virtual std::string message() = 0; +}; + +} // namespace tflite_migration +// LINT.ThenChange(//tensorflow/lite/stateful_error_reporter.h) + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_STATEFUL_ERROR_REPORTER_H_ diff --git a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir index 9626a292b8eb6d..7b949d3d551151 100644 --- a/tensorflow/compiler/mlir/lite/tests/const-fold.mlir +++ b/tensorflow/compiler/mlir/lite/tests/const-fold.mlir @@ -151,6 +151,75 @@ func.func @mul_f16() -> (tensor, tensor<4xf16>, tensor<4xf16>, tensor<4xf16 func.return %5, %6, %7, %8 : tensor, tensor<4xf16>, tensor<4xf16>, tensor<4xf16> } +// CHECK-LABEL: @mul_zero +func.func @mul_zero(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %zero_int = arith.constant dense<0> : tensor<4xi32> + %zero_float = arith.constant dense<0.0> : tensor<4xf32> + + // CHECK-NOT: tfl.mul + // CHECK: return %cst, %cst_0 + + %0 = "tfl.mul"(%arg0, %zero_int) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.mul"(%arg1, %zero_float) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @mul_zero_lhs +func.func @mul_zero_lhs(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %zero_int = arith.constant dense<0> : tensor<4xi32> + %zero_float = arith.constant dense<0.0> : tensor<4xf32> + + // CHECK-NOT: tfl.mul + // CHECK: return %cst, %cst_0 + + %0 = "tfl.mul"(%zero_int, %arg0) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.mul"(%zero_float, %arg1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @mul_one +func.func @mul_one(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %one_int = arith.constant dense<1> : tensor<4xi32> + %one_float = arith.constant dense<1.0> : tensor<4xf32> + + // CHECK-NOT: tfl.mul + // CHECK: return %arg0, %arg1 + + %0 = "tfl.mul"(%arg0, %one_int) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.mul"(%arg1, %one_float) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @mul_one_lhs +func.func @mul_one_lhs(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %one_int = arith.constant dense<1> : tensor<4xi32> + %one_float = arith.constant dense<1.0> : tensor<4xf32> + + // CHECK-NOT: tfl.mul + // CHECK: return %arg0, %arg1 + + %0 = "tfl.mul"(%one_int, %arg0) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.mul"(%one_float, %arg1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @mul_one_quant +func.func @mul_one_quant(%arg0: tensor<32x!quant.uniform>) -> tensor<32x!quant.uniform> { + %one = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<1> : tensor<32xi8>} : () -> tensor<32x!quant.uniform> + + // CHECK: %[[MUL:.*]] = tfl.mul + // CHECK: return %[[MUL]] + + %0 = "tfl.mul"(%one, %arg0) {fused_activation_function = "NONE"} : (tensor<32x!quant.uniform>, tensor<32x!quant.uniform>) -> tensor<32x!quant.uniform> + + func.return %0 : tensor<32x!quant.uniform> +} + + // CHECK-LABEL: @elementwise_unary_ops func.func @elementwise_unary_ops() -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor) { %0 = arith.constant dense<-1.0> : tensor @@ -191,6 +260,15 @@ func.func @max_with_neg_f32_max_val(%arg0 : tensor) -> (tensor, tensor // CHECK: return %[[ARG0]], %[[ARG0]] } +// CHECK-LABEL: @max_with_neg_inf +func.func @max_with_neg_inf(%arg0 : tensor) -> (tensor, tensor) { + %neg_inf = arith.constant dense<0xFF800000> : tensor + %0 = "tfl.maximum"(%arg0, %neg_inf) : (tensor, tensor) -> tensor + %1 = "tfl.maximum"(%neg_inf, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + // CHECK-LABEL: @min_with_f32_max_val // CHECK-SAME: (%[[ARG0:.+]]: tensor) func.func @min_with_f32_max_val(%arg0 : tensor) -> (tensor, tensor) { @@ -201,6 +279,15 @@ func.func @min_with_f32_max_val(%arg0 : tensor) -> (tensor, tensor) -> (tensor, tensor) { + %inf = arith.constant dense<0x7F800000> : tensor + %0 = "tfl.minimum"(%arg0, %inf) : (tensor, tensor) -> tensor + %1 = "tfl.minimum"(%inf, %arg0) : (tensor, tensor) -> tensor + func.return %0, %1 : tensor, tensor + // CHECK: return %[[ARG0]], %[[ARG0]] +} + // CHECK-LABEL: @max_with_neg_f64_max_val // CHECK-SAME: (%[[ARG0:.+]]: tensor) func.func @max_with_neg_f64_max_val(%arg0 : tensor) -> (tensor, tensor) { @@ -672,6 +759,32 @@ func.func @div_dense_different_rank() -> tensor<1x2x2xf32> { // CHECK: return %[[CST]] } +// CHECK-LABEL: @div_one +func.func @div_one(%arg0: tensor<4xi32>, %arg1: tensor<4xf32>) -> (tensor<4xi32>, tensor<4xf32>) { + %one_int = arith.constant dense<1> : tensor<4xi32> + %one_float = arith.constant dense<1.0> : tensor<4xf32> + + // CHECK-NOT: tfl.div + // CHECK: return %arg0, %arg1 + + %0 = "tfl.div"(%arg0, %one_int) {fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + %1 = "tfl.div"(%arg1, %one_float) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %0, %1 : tensor<4xi32>, tensor<4xf32> +} + +// CHECK-LABEL: @div_one_quant +func.func @div_one_quant(%arg0: tensor<32x!quant.uniform>) -> tensor<32x!quant.uniform> { + %one = "tfl.pseudo_qconst"() {qtype = tensor<32x!quant.uniform>, value = dense<1> : tensor<32xi8>} : () -> tensor<32x!quant.uniform> + + // CHECK: %[[DIV:.*]] = tfl.div + // CHECK: return %[[DIV]] + + %0 = "tfl.div"(%arg0, %one) {fused_activation_function = "NONE"} : (tensor<32x!quant.uniform>, tensor<32x!quant.uniform>) -> tensor<32x!quant.uniform> + + func.return %0 : tensor<32x!quant.uniform> +} + // CHECK-LABEL: @rsqrt_bf16 func.func @rsqrt_bf16() -> tensor { %cst = arith.constant dense<4.0> : tensor @@ -779,6 +892,51 @@ func.func @cast_ui8_to_i1() -> tensor<4xi1> { // CHECK: return %[[CST]] } +// CHECK-LABEL: @cast_f32_to_i32 +func.func @cast_f32_to_i32() -> tensor<8xi32> { + %cst = arith.constant dense<[-1.0, 0.0, 1.5, 0.99, 1.175494351e-38, 3.402823466e+38, -3.402823466e+38, -1.175494351e-38]> : tensor<8xf32> + %0 = "tfl.cast"(%cst) : (tensor<8xf32>) -> tensor<8xi32> + func.return %0 : tensor<8xi32> +} + +// CHECK: %cst = arith.constant dense<[-1, 0, 1, 0, 0, 2147483647, -2147483648, 0]> : tensor<8xi32> + +// CHECK-LABEL: @cast_i32_to_f32 +func.func @cast_i32_to_f32() -> tensor<5xf32> { + %cst = arith.constant dense<[-1, 0, 2, 2147483647, -2147483648]> : tensor<5xi32> + %0 = "tfl.cast"(%cst) : (tensor<5xi32>) -> tensor<5xf32> + func.return %0 : tensor<5xf32> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 2.000000e+00, 2.14748365E+9, -2.14748365E+9]> : tensor<5xf32> + +// CHECK-LABEL: @cast_bool_to_f32 +func.func @cast_bool_to_f32() -> tensor<2xf32> { + %cst = arith.constant dense<[true, false]> : tensor<2xi1> + %0 = "tfl.cast"(%cst) : (tensor<2xi1>) -> tensor<2xf32> + func.return %0 : tensor<2xf32> +} + +// CHECK: %cst = arith.constant dense<[1.000000e+00, 0.000000e+00]> : tensor<2xf32> + +// CHECK-LABEL: @cast_f64_to_f32 +func.func @cast_f64_to_f32() -> tensor<4xf32> { + %cst = arith.constant dense<[-1.0, 0.0, 1.5, 100.0]> : tensor<4xf64> + %0 = "tfl.cast"(%cst) : (tensor<4xf64>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.500000e+00, 1.000000e+02]> : tensor<4xf32> + +// CHECK-LABEL: @cast_f32_to_f64 +func.func @cast_f32_to_f64() -> tensor<4xf64> { + %cst = arith.constant dense<[-1.0, 0.0, 1.5, 100.0]> : tensor<4xf32> + %0 = "tfl.cast"(%cst) : (tensor<4xf32>) -> tensor<4xf64> + func.return %0 : tensor<4xf64> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 1.500000e+00, 1.000000e+02]> : tensor<4xf64> + // CHECK-LABEL: @ConstantFoldFullyConnectedSmall func.func @ConstantFoldFullyConnectedSmall() -> tensor<3xf32> { %cst_input = arith.constant dense<[2.0, 3.0]> : tensor<2xf32> @@ -942,3 +1100,336 @@ func.func @ConstFoldEmbeddingLookup() -> (tensor<5x2xf32>, tensor<3x2x2xf32>) { // CHECK-DAG: %[[LOOKUP1:.*]] = arith.constant dense<{{\[\[\[}}1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]], {{\[\[}}5.000000e+00, 6.000000e+00], [7.000000e+00, 8.000000e+00]], {{\[\[}}1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]]> : tensor<3x2x2xf32> // CHECK: return %[[LOOKUP0]], %[[LOOKUP1]] : tensor<5x2xf32>, tensor<3x2x2xf32> } + +// CHECK-LABEL: @less_int_both_splat +func.func @less_int_both_splat() -> tensor<4xi1> { + %0 = arith.constant dense<3> : tensor<4xi32> + %1 = arith.constant dense<10> : tensor<4xi32> + + %2 = "tfl.less"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense : tensor<4xi1> + +// CHECK-LABEL: @less_int_one_splat +func.func @less_int_one_splat() -> tensor<4xi1> { + %0 = arith.constant dense<3> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.less"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK:%cst = arith.constant dense<[true, false, false, false]> : tensor<4xi1> + +// CHECK-LABEL: @less_int +func.func @less_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.less"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, false, false, true]> : tensor<4xi1> + +// CHECK-LABEL: @less_float +func.func @less_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.less"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, false, false, true]> : tensor<4xi1> + +// CHECK-LABEL: @less_equal_int +func.func @less_equal_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.less_equal"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, true]> : tensor<4xi1> + +// CHECK-LABEL: @less_equal_float +func.func @less_equal_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.less_equal"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, true]> : tensor<4xi1> + +// CHECK-LABEL: @greater_int +func.func @greater_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.greater"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true, false]> : tensor<4xi1> + +// CHECK-LABEL: @greater_float +func.func @greater_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.greater"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true, false]> : tensor<4xi1> + +// CHECK-LABEL: @greater_equal_int +func.func @greater_equal_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.greater_equal"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, true, true, false]> : tensor<4xi1> + +// CHECK-LABEL: @greater_equal_float +func.func @greater_equal_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.greater_equal"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, true, true, false]> : tensor<4xi1> + +// CHECK-LABEL: @equal_int +func.func @equal_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.equal"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, false]> : tensor<4xi1> + +// CHECK-LABEL: @equal_float +func.func @equal_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.equal"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[false, true, false, false]> : tensor<4xi1> + +// CHECK-LABEL: @not_equal_int +func.func @not_equal_int() -> tensor<4xi1> { + %0 = arith.constant dense<[11, 2, 0, 2]> : tensor<4xi32> + %1 = arith.constant dense<[10, 2, -1, 3]> : tensor<4xi32> + + %2 = "tfl.not_equal"(%0, %1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true, true]> : tensor<4xi1> + +// CHECK-LABEL: @not_equal_float +func.func @not_equal_float() -> tensor<4xi1> { + %0 = arith.constant dense<[11.0, 2.0, 0.0, 2.0]> : tensor<4xf32> + %1 = arith.constant dense<[10.0, 2.0, -1.0, 3.0]> : tensor<4xf32> + + %2 = "tfl.not_equal"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + + func.return %2 : tensor<4xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true, true]> : tensor<4xi1> + +// CHECK-LABEL: @logical_or +func.func @logical_or() -> tensor<3xi1> { + %0 = arith.constant dense<[true, false, true]> : tensor<3xi1> + %1 = arith.constant dense<[false, false, true]> : tensor<3xi1> + + %2 = "tfl.logical_or"(%0, %1) : (tensor<3xi1>, tensor<3xi1>) -> tensor<3xi1> + + func.return %2 : tensor<3xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true]> : tensor<3xi1> + +// CHECK-LABEL: @logical_and +func.func @logical_and() -> tensor<3xi1> { + %0 = arith.constant dense<[true, false, true]> : tensor<3xi1> + %1 = arith.constant dense<[false, false, true]> : tensor<3xi1> + + %2 = "tfl.logical_and"(%0, %1) : (tensor<3xi1>, tensor<3xi1>) -> tensor<3xi1> + + func.return %2 : tensor<3xi1> +} + +// CHECK: %cst = arith.constant dense<[false, false, true]> : tensor<3xi1> + +// CHECK-LABEL: @select_splat_cond +func.func @select_splat_cond() -> tensor<4xi32> { + %cond = arith.constant dense : tensor<4xi1> + %0 = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + %1 = arith.constant dense<[-1, -2, -3, -4]> : tensor<4xi32> + + %2 = "tfl.select"(%cond, %0, %1) : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + func.return %2 : tensor<4xi32> +} + +// CHECK: %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> + +// CHECK-LABEL: select_splat_lhs +func.func @select_splat_lhs() -> tensor<4xi32> { + %cond = arith.constant dense<[true, true, false, false]> : tensor<4xi1> + %0 = arith.constant dense<0> : tensor<4xi32> + %1 = arith.constant dense<[-1, -2, -3, -4]> : tensor<4xi32> + + %2 = "tfl.select"(%cond, %0, %1) : (tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32> + + func.return %2 : tensor<4xi32> +} + +// CHECK: %cst = arith.constant dense<[0, 0, -3, -4]> : tensor<4xi32> + +// CHECK-LABEL: select_float +func.func @select_float() -> tensor<4xf32> { + %cond = arith.constant dense<[true, true, false, false]> : tensor<4xi1> + %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32> + %1 = arith.constant dense<[-1.0, -2.0, -3.0, -4.0]> : tensor<4xf32> + + %2 = "tfl.select"(%cond, %0, %1) : (tensor<4xi1>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + func.return %2 : tensor<4xf32> +} + +// CHECK: %cst = arith.constant dense<[1.000000e+00, 2.000000e+00, -3.000000e+00, -4.000000e+00]> : tensor<4xf32 + +// CHECK-LABEL: floor +func.func @floor() -> tensor<3xf32> { + %cst = arith.constant dense<[-1.0, 0.0, 0.99]> : tensor<3xf32> + %0 = "tfl.floor"(%cst) : (tensor<3xf32>) -> tensor<3xf32> + func.return %0 : tensor<3xf32> +} + +// CHECK: %cst = arith.constant dense<[-1.000000e+00, 0.000000e+00, 0.000000e+00]> : tensor<3xf32> + +// CHECK-LABEL: floor_f64 +func.func @floor_f64() -> tensor<3xf64> { + %cst = arith.constant dense<[-1.0, 0.0, 0.99]> : tensor<3xf64> + %0 = "tfl.floor"(%cst) : (tensor<3xf64>) -> tensor<3xf64> + func.return %0 : tensor<3xf64> +} + +// CHECK: tfl.floor + +// CHECK-LABEL: exp +func.func @exp() -> tensor<4xf32> { + %cst = arith.constant dense<[-1.0, 0.0, 0.99, 0.36787944117]> : tensor<4xf32> + %0 = "tfl.exp"(%cst) : (tensor<4xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK: %cst = arith.constant dense<[0.36787945, 1.000000e+00, 2.69123459, 1.44466782]> : tensor<4xf32> + +// CHECK-LABEL: exp_f64 +func.func @exp_f64() -> tensor<4xf64> { + %cst = arith.constant dense<[-1.0, 0.0, 0.99, 0.36787944117]> : tensor<4xf64> + %0 = "tfl.exp"(%cst) : (tensor<4xf64>) -> tensor<4xf64> + func.return %0 : tensor<4xf64> +} + +// CHECK: tfl.exp + +// CHECK-LABEL: pow_float +func.func @pow_float() -> tensor<3xf32> { + %0 = arith.constant dense<[1.0, 0.0, 2.0]> : tensor<3xf32> + %1 = arith.constant dense<[2.0, 3.0, -1.5]> : tensor<3xf32> + + %2 = "tfl.pow"(%0, %1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + + func.return %2 : tensor<3xf32> +} + +// CHECK: %cst = arith.constant dense<[1.000000e+00, 0.000000e+00, 0.353553385]> : tensor<3xf32> + +// CHECK-LABEL: pow_int +func.func @pow_int() -> tensor<3xi32> { + %0 = arith.constant dense<[1, 0, 2]> : tensor<3xi32> + %1 = arith.constant dense<[2, 3, -1]> : tensor<3xi32> + + %2 = "tfl.pow"(%0, %1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + + func.return %2 : tensor<3xi32> +} + +// CHECK: %cst = arith.constant dense<[1, 0, 0]> : tensor<3xi32> + +// CHECK-LABEL: logical_not +func.func @logical_not() -> tensor<3xi1> { + %cst = arith.constant dense<[false, true, false]> : tensor<3xi1> + %0 = "tfl.logical_not"(%cst) : (tensor<3xi1>) -> tensor<3xi1> + func.return %0 : tensor<3xi1> +} + +// CHECK: %cst = arith.constant dense<[true, false, true]> : tensor<3xi1> + +// CHECK-LABEL: logical_not_splat +func.func @logical_not_splat() -> tensor<3xi1> { + %cst = arith.constant dense : tensor<3xi1> + %0 = "tfl.logical_not"(%cst) : (tensor<3xi1>) -> tensor<3xi1> + func.return %0 : tensor<3xi1> +} + +// CHECK: %cst = arith.constant dense : tensor<3xi1> + +// CHECK-LABEL: bitwise_xor_i32 +func.func @bitwise_xor_i32() -> tensor<3xi32> { + %0 = arith.constant dense<[0, 5, 3]> : tensor<3xi32> + %1 = arith.constant dense<[5, 0, 7]> : tensor<3xi32> + + %2 = "tfl.bitwise_xor"(%0, %1) : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + + func.return %2 : tensor<3xi32> +} + +// CHECK: %cst = arith.constant dense<[5, 5, 4]> : tensor<3xi32> + +// CHECK-LABEL: bitwise_xor_ui8 +func.func @bitwise_xor_ui8() -> tensor<3xui8> { + %0 = arith.constant dense<[0, 5, 3]> : tensor<3xui8> + %1 = arith.constant dense<[5, 0, 7]> : tensor<3xui8> + + %2 = "tfl.bitwise_xor"(%0, %1) : (tensor<3xui8>, tensor<3xui8>) -> tensor<3xui8> + + func.return %2 : tensor<3xui8> +} + +// CHECK: %cst = arith.constant dense<[5, 5, 4]> : tensor<3xui8> diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD index 6cc9a623af0a3f..4e7fa53cf78c50 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD @@ -56,9 +56,9 @@ tf_native_cc_binary( "importer_test_min_max.cc", ], deps = [ + "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/schema:schema_utils", - "//tensorflow/lite:framework", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir index 7ea7e48777522e..77edd7a648fcaa 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/if_op.mlir @@ -16,3 +16,50 @@ func.func @cond_false(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf3 %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32> func.return %0 : tensor<*xf32> } + +// ----- + +func.func @tfl_if(%arg0: tensor) -> tensor { +// CHECK: %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}) <{else_branch = @tfl.if_else, is_stateless = false, then_branch = @tfl.if_then}> : (tensor, tensor) -> tensor + %cst = arith.constant dense<0> : tensor + %0 = tfl.add %cst, %cst {fused_activation_function = "NONE"} : tensor + %1 = "tfl.if"(%arg0) ({ + %2 = func.call @tfl.if_then(%0) : (tensor) -> tensor + "tfl.yield"(%2) : (tensor) -> () + }, { + %2 = func.call @tfl.if_else(%0) : (tensor) -> tensor + "tfl.yield"(%2) : (tensor) -> () + }) : (tensor) -> tensor + return %1 : tensor +} +func.func private @tfl.if_then(%arg0: tensor) -> tensor { + return %arg0 : tensor +} +func.func private @tfl.if_else(%arg0: tensor) -> tensor { + %0 = tfl.mul %arg0, %arg0 {fused_activation_function = "NONE"} : tensor + return %0 : tensor +} + +// ----- + +func.func @tfl_if_multi_args(%arg0: tensor) -> tensor { +// CHECK: %{{.*}} = "tf.If"(%{{.*}}, %{{.*}}, %{{.*}}) <{else_branch = @tfl.if_else_1, is_stateless = false, then_branch = @tfl.if_then_1}> : (tensor, tensor, tensor) -> tensor + %cst = arith.constant dense<0> : tensor + %0 = tfl.add %cst, %cst {fused_activation_function = "NONE"} : tensor + %1 = tfl.mul %cst, %cst {fused_activation_function = "NONE"} : tensor + %2 = "tfl.if"(%arg0) ({ + %2 = func.call @tfl.if_then_1(%0, %1) : (tensor, tensor) -> tensor + "tfl.yield"(%2) : (tensor) -> () + }, { + %2 = func.call @tfl.if_else_1(%0, %1) : (tensor, tensor) -> tensor + "tfl.yield"(%2) : (tensor) -> () + }) : (tensor) -> tensor + return %1 : tensor +} +func.func private @tfl.if_then_1(%arg0: tensor, %arg1: tensor) -> tensor { + return %arg0 : tensor +} +func.func private @tfl.if_else_1(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = tfl.mul %arg0, %arg1 {fused_activation_function = "NONE"} : tensor + return %0 : tensor +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc index 30890fb539bde0..6231088052aa78 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/importer_test_min_max.cc @@ -16,17 +16,15 @@ limitations under the License. #include #include #include -#include -#include "absl/strings/string_view.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/mlir/lite/core/absl_error_model_builder.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" -#include "tensorflow/lite/model.h" using llvm::cl::opt; @@ -52,7 +50,7 @@ namespace mlir { namespace { std::optional> InjectStatsToFullyConnected( llvm::StringRef buffer) { - auto model_ptr = tflite::FlatBufferModel::VerifyAndBuildFromBuffer( + auto model_ptr = TFL::FlatBufferModelAbslError::VerifyAndBuildFromBuffer( buffer.data(), buffer.size()); if (nullptr == model_ptr) { return std::nullopt; diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir index 32cd4552f0b15d..9d5bad8c7d6181 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/output_arrays.mlir @@ -3,7 +3,7 @@ func.func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): - %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %0 = "tfl.pseudo_const" () {value = dense<2.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference") // CHECK: %[[MUL:.*]] = tfl.mul %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul") diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir index 060d5fc871665a..cb87e4f0a2147f 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/pruning.mlir @@ -3,7 +3,7 @@ func.func @main(tensor<4xf32>) -> tensor<4xf32> { ^bb0(%arg0: tensor<4xf32>): - %0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") + %0 = "tfl.pseudo_const" () {value = dense<2.0> : tensor<4xf32>} : () -> tensor<4xf32> loc("Const") %1 = "tfl.squared_difference"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("squared_difference") // CHECK: %[[MUL:.*]] = tfl.mul %2 = "tfl.mul"(%0, %1) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> loc("mul") diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 3c2c24baba8972..4301cbf8627b79 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -1,12 +1,11 @@ // Run optimize pass only and check the results. -// RUN: tf-opt %s -tfl-optimize | FileCheck %s +// RUN: tf-opt %s -tfl-optimize='enable-canonicalization=false' | FileCheck %s // Run optimize pass and then canonicalize pass, and make sure some folding is applied. -// RUN: tf-opt %s -tfl-optimize='enable-canonicalization=true' | FileCheck --check-prefix=FOLD %s - +// RUN: tf-opt %s -tfl-optimize | FileCheck --check-prefix=FOLD %s // Run legalize pass and then optimize pass, and make sure some fusing is applied. -// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize | FileCheck --check-prefix=Fusing %s +// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize='enable-canonicalization=false' | FileCheck --check-prefix=Fusing %s // Run legalize pass and then optimize pass, and make sure some fusing is applied, but no mul->fc. -// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize='disable-fuse-mul-and-fc=true' | FileCheck --check-prefix=NoFusing %s +// RUN: tf-opt %s -tfl-legalize-tf -tfl-optimize='enable-canonicalization=false disable-fuse-mul-and-fc=true' | FileCheck --check-prefix=NoFusing %s // CHECK-LABEL: fusedConv2dRelu func.func @fusedConv2dRelu(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<256x32x32x16xf32> { @@ -4124,4 +4123,67 @@ func.func @StridedSliceToSliceBeginNeg(%arg0: tensor<5x5x5x5xf32>) -> tensor<*xf func.return %47 : tensor<*xf32> // CHECK-NOT: %[[slice:.*]] = "tfl.slice" -} \ No newline at end of file +} + +// CHECK-LABEL: conv3d_external_padding +func.func @conv3d_external_padding(%arg0: tensor<1x7x7x7x128xf32>, %arg1: tensor<3x3x3x128x256xf32>) -> tensor<1x7x7x7x256xf32> { + %cst = arith.constant dense<[[0, 0], [1, 1], [1, 1], [1, 1], [0, 0]]> : tensor<5x2xi64> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32> + %0 = "tfl.pad"(%arg0, %cst) : (tensor<1x7x7x7x128xf32>, tensor<5x2xi64>) -> tensor<1x9x9x9x128xf32> + %1 = "tfl.conv_3d"(%0, %arg1, %cst_0) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x9x9x9x128xf32>, tensor<3x3x3x128x256xf32>, tensor<256xf32>) -> tensor<1x7x7x7x256xf32> + return %1 : tensor<1x7x7x7x256xf32> +} + +// CHECK: %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x7x7x7x128xf32>, tensor<3x3x3x128x256xf32>, tensor<256xf32>) -> tensor<1x7x7x7x256xf32> + +// CHECK-LABEL: conv3d_external_padding_strided +func.func @conv3d_external_padding_strided(%arg0: tensor<1x8x56x56x128xf32>, %arg1: tensor<3x3x3x128x256xf32>) -> tensor<1x4x28x28x256xf32> { + %cst = arith.constant dense<[[0, 0], [0, 1], [0, 1], [0, 1], [0, 0]]> : tensor<5x2xi64> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32> + %0 = "tfl.pad"(%arg0, %cst) : (tensor<1x8x56x56x128xf32>, tensor<5x2xi64>) -> tensor<1x9x57x57x128xf32> + %1 = "tfl.conv_3d"(%0, %arg1, %cst_0) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x9x57x57x128xf32>, tensor<3x3x3x128x256xf32>, tensor<256xf32>) -> tensor<1x4x28x28x256xf32> + return %1 : tensor<1x4x28x28x256xf32> +} + +// CHECK: %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) <{dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x8x56x56x128xf32>, tensor<3x3x3x128x256xf32>, tensor<256xf32>) -> tensor<1x4x28x28x256xf32> + +// CHECK-LABEL: conv2d_external_padding +func.func @conv2d_external_padding(%arg0: tensor<1x7x7x128xf32>, %arg1: tensor<256x3x3x128xf32>) -> tensor<1x7x7x256xf32> { + %cst = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32> + %0 = "tfl.pad"(%arg0, %cst) : (tensor<1x7x7x128xf32>, tensor<4x2xi64>) -> tensor<1x9x9x128xf32> + %1 = "tfl.conv_2d"(%0, %arg1, %cst_0) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x9x9x128xf32>, tensor<256x3x3x128xf32>, tensor<256xf32>) -> tensor<1x7x7x256xf32> + return %1 : tensor<1x7x7x256xf32> +} + +// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x7x7x128xf32>, tensor<256x3x3x128xf32>, tensor<256xf32>) -> tensor<1x7x7x256xf32> + +// CHECK-LABEL: conv2d_external_padding_strided +func.func @conv2d_external_padding_strided(%arg0: tensor<1x8x8x128xf32>, %arg1: tensor<256x3x3x128xf32>) -> tensor<1x4x4x256xf32> { + %cst = arith.constant dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32> + %0 = "tfl.pad"(%arg0, %cst) : (tensor<1x8x8x128xf32>, tensor<4x2xi64>) -> tensor<1x9x9x128xf32> + %1 = "tfl.conv_2d"(%0, %arg1, %cst_0) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x9x9x128xf32>, tensor<256x3x3x128xf32>, tensor<256xf32>) -> tensor<1x4x4x256xf32> + return %1 : tensor<1x4x4x256xf32> +} + +// CHECK: %0 = "tfl.conv_2d"(%arg0, %arg1, %cst) <{dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}> : (tensor<1x8x8x128xf32>, tensor<256x3x3x128xf32>, tensor<256xf32>) -> tensor<1x4x4x256xf32> + +// CHECK-LABEL: depthwise_conv_external_same_padding +func.func @depthwise_conv_external_same_padding(%arg0: tensor<1x8x8x64xf32>, %arg1: tensor<1x3x3x64xf32>) -> tensor<1x8x8x64xf32> { + %cst = arith.constant dense<0.000000e+00> : tensor<64xf32> + %cst_0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + %0 = "tfl.pad"(%arg0, %cst_0) : (tensor<1x8x8x64xf32>, tensor<4x2xi64>) -> tensor<1x10x10x64xf32> + %1 = "tfl.depthwise_conv_2d"(%0, %arg1, %cst) <{ + depth_multiplier = 1 : i32, + dilation_h_factor = 1 : i32, + dilation_w_factor = 1 : i32, + fused_activation_function = "NONE", + padding = "VALID", + stride_h = 1 : i32, + stride_w = 1 : i32 + }> : (tensor<1x10x10x64xf32>, tensor<1x3x3x64xf32>, tensor<64xf32>) -> tensor<1x8x8x64xf32> + return %1 : tensor<1x8x8x64xf32> +} + +// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %cst) <{depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}> : (tensor<1x8x8x64xf32>, tensor<1x3x3x64xf32>, tensor<64xf32>) -> tensor<1x8x8x64xf32> \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir b/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir index a5da33ca90191b..8796d690f72796 100644 --- a/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir +++ b/tensorflow/compiler/mlir/lite/tests/push-tpose-through-ewise.mlir @@ -164,4 +164,35 @@ func.func @pushTposeBcastScalarCstInput(%arg0: tensor<2x3x4x5xf32>) -> tensor<5x // CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x3x4x5xf32>, tensor) -> tensor<2x3x4x5xf32> // CHECK: %1 = "tfl.transpose"(%0, %cst_0) : (tensor<2x3x4x5xf32>, tensor<4xi32>) -> tensor<5x2x3x4xf32> +// ----- + +// CHECK-LABEL: pushTposeDynamicBcastScalarCstInput +func.func @pushTposeDynamicBcastScalarCstInput(%arg0: tensor) -> tensor<5x?x?x4xf32> { + %perm = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %perm) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> + %cst = arith.constant dense<1.0> : tensor + %1 = "tfl.add"(%0, %cst) { fused_activation_function = "NONE" } : (tensor<5x?x?x4xf32>, tensor) -> tensor<5x?x?x4xf32> + func.return %1 : tensor<5x?x?x4xf32> +} + +// CHECK: %cst = arith.constant dense<1.000000e+00> : tensor +// CHECK: %cst_0 = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor) -> tensor +// CHECK: %1 = "tfl.transpose"(%0, %cst_0) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> +// ----- + +// CHECK-LABEL: doubleTposeDynamicInput +func.func @doubleTposeDynamicInput(%arg0: tensor, %arg1: tensor) -> tensor<5x?x?x4xf32> { + %perm = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %perm) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> + %perm1 = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> + %1 = "tfl.transpose"(%arg1, %perm1) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> + %2 = tfl.add %0, %1 { fused_activation_function = "NONE" } : tensor<5x?x?x4xf32> + func.return %2 : tensor<5x?x?x4xf32> +} + +// CHECK: %cst = arith.constant dense<[3, 0, 1, 2]> : tensor<4xi32> +// CHECK: %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor +// CHECK: %1 = "tfl.transpose"(%0, %cst) : (tensor, tensor<4xi32>) -> tensor<5x?x?x4xf32> +// CHECK: return %1 : tensor<5x?x?x4xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir b/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir index 477315d696783c..d9382fdeb3341b 100644 --- a/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/raise-custom-ops.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: custom_op func.func @custom_op(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %0 = "arith.constant" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32> + %0 = "arith.constant" () {value = dense<2.0> : tensor<4xf32>} : () -> tensor<4xf32> %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> // will be preserved since it has uses. %2 = "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> @@ -11,7 +11,7 @@ func.func @custom_op(%arg0: tensor<4xf32>) -> tensor<4xf32> { "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> func.return %2 : tensor<4xf32> -// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32> +// CHECK-NEXT: %[[CST:.*]] = arith.constant dense<2.000000e+00> : tensor<4xf32> // CHECK-NEXT: %[[MUL:.*]] = tfl.mul %arg0, %[[CST]] {fused_activation_function = "NONE"} : tensor<4xf32> // CHECK-NEXT: %[[CUSTOM_1:.*]] = "tfl.custom_tf"(%[[MUL]], %[[CST]]) ({ // CHECK-NEXT: ^bb0(%arg1: tensor<4xf32>, %arg2: tensor<4xf32>): diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index c756e59313156a..878026e6b47913 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -59,12 +59,18 @@ void AddQuantizationPasses(const mlir::TFL::PassConfig& pass_config, mlir::TFL::CreatePrepareQuantizePass(quant_specs)); if (quant_specs.default_ranges.first.has_value() || quant_specs.default_ranges.second.has_value()) { + mlir::TFL::DefaultQuantParamsPassOptions default_quant_params_pass_options; + default_quant_params_pass_options.default_min_ = + quant_specs.default_ranges.first.value_or(0.0); + default_quant_params_pass_options.default_max_ = + quant_specs.default_ranges.second.value_or(0.0); + default_quant_params_pass_options.is_signed_ = + quant_specs.IsSignedInferenceType(); pass_manager.addNestedPass( mlir::TFL::CreateDefaultQuantParamsPass( - quant_specs.default_ranges.first.value_or(0.0), - quant_specs.default_ranges.second.value_or(0.0), - quant_specs.IsSignedInferenceType())); + default_quant_params_pass_options)); } + pass_manager.addNestedPass( mlir::TFL::CreateQuantizePass(quant_specs)); bool emit_quant_adaptor_ops = @@ -95,8 +101,9 @@ void AddQuantizationPasses(const mlir::TFL::PassConfig& pass_config, pass_manager.addNestedPass( mlir::TFL::CreateOptimizeBatchMatmulPass()); } + // Add TFLite optimize pass. pass_manager.addNestedPass( - mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true)); + mlir::TFL::CreateOptimizePass()); } void AddVariableFreezingFromGlobalTensorsPasses( @@ -107,12 +114,10 @@ void AddVariableFreezingFromGlobalTensorsPasses( pass_manager->addPass( mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); - if (!pass_config.disable_variable_freezing) { - // This pass 'freezes' immutable global tensors and inlines them as tf - // constant ops. - pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass( - /*allow_mutable_tensors=*/pass_config.enable_tflite_variables)); - } + // This pass 'freezes' immutable global tensors and inlines them as tf + // constant ops. + pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass( + /*allow_mutable_tensors=*/pass_config.enable_tflite_variables)); pass_manager->addPass(mlir::TFL::CreateUnfreezeMutableGlobalTensorsPass()); } @@ -152,8 +157,18 @@ void AddDynamicRangeQuantizationPasses(const mlir::TFL::PassConfig& pass_config, pass_manager.addNestedPass( mlir::TFL::CreateOptimizeBatchMatmulPass()); } + + // Add TFLite optimize pass. pass_manager.addNestedPass( - mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true)); + mlir::TFL::CreateOptimizePass()); +} + +void AddPytorchPasses(mlir::OpPassManager& pass_manager) { + pass_manager.addNestedPass(mlir::createCSEPass()); + pass_manager.addPass(mlir::odml::createBuildStableHLOCompositePass()); + pass_manager.addPass(mlir::createInlinerPass()); + pass_manager.addPass(mlir::odml::createLiftCallSiteLocCallerPass()); + pass_manager.addNestedPass(mlir::createCSEPass()); } void AddPreQuantizationStableHloToTfPasses( @@ -163,6 +178,10 @@ void AddPreQuantizationStableHloToTfPasses( pass_manager.addPass( mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); + if (pass_config.model_origin_framework == toco::TocoFlags::PYTORCH) { + AddPytorchPasses(pass_manager); + } + // Legalize MHLO to StableHLO should be moved closer to where it is needed // There are some entry points that start with HLO->MHLO like // jax_to_tfl_flatbuffer.cc which can likely be updated to emit StableHLO @@ -300,14 +319,17 @@ void AddPreVariableFreezingTFToTFLConversionPasses( mlir::OpPassManager* pass_manager) { // This pass wraps all the tf.FakeQuant ops in a custom op so they are not // folded before being converted to tfl.quantize and tfl.dequantize ops. - auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps(); + std::vector target_ops = mlir::TFL::AllTfFakeQuantOps(); + mlir::TFL::RaiseCustomOpsPassOptions raise_custom_ops_pass_options; + raise_custom_ops_pass_options.target_ops_ = target_ops; pass_manager->addNestedPass( - mlir::TFL::CreateRaiseCustomOpsPass(wrapped_ops)); + mlir::TFL::CreateRaiseCustomOpsPass(raise_custom_ops_pass_options)); mlir::TF::StandardPipelineOptions standard_pipeline_options; standard_pipeline_options.enable_inliner = false; standard_pipeline_options.form_clusters = pass_config.form_clusters; mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options); + pass_manager->addNestedPass( mlir::TF::CreateDeviceIndexSelectorPass()); @@ -372,13 +394,16 @@ void AddPostVariableFreezingTFToTFLConversionPasses( if (pass_config.lower_tensor_list_ops && toco_flags.tf_quantization_mode().empty()) { // TODO(haoliang): Add this pass by default. + mlir::TFL::LowerStaticTensorListPassOptions + lower_static_tensor_list_pass_options; + lower_static_tensor_list_pass_options.allow_tensorlist_pass_through_ = + toco_flags.force_select_tf_ops() || toco_flags.enable_select_tf_ops(); + lower_static_tensor_list_pass_options.default_to_single_batch_ = + toco_flags.default_to_single_batch_in_tensor_list_ops(); + lower_static_tensor_list_pass_options.enable_dynamic_update_slice_ = + toco_flags.enable_dynamic_update_slice(); pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass( - /*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() || - toco_flags.enable_select_tf_ops(), - /*default_to_single_batch=*/ - toco_flags.default_to_single_batch_in_tensor_list_ops(), - /*enable_dynamic_update_slice=*/ - toco_flags.enable_dynamic_update_slice())); + lower_static_tensor_list_pass_options)); } if (pass_config.shape_inference) { @@ -433,6 +458,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( if (pass_config.shape_inference) { pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); } + // Force layout supported by TFLite, this will transpose the data // to match 'kTFLiteDataLayout' mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options; @@ -440,13 +466,19 @@ void AddPostVariableFreezingTFToTFLConversionPasses( layout_optimization_options.skip_fold_transpose_in_ops = true; mlir::TF::CreateLayoutOptimizationPipeline( pass_manager->nest(), layout_optimization_options); + // Prepare for TFLite dialect, rerun canonicalization, and then legalize to // the TFLite dialect. + mlir::TFL::PrepareTFPassOptions prepare_tf_pass_options; + prepare_tf_pass_options.unfold_batch_matmul_ = + pass_config.unfold_batch_matmul; + prepare_tf_pass_options.allow_bf16_and_f16_type_legalization_ = + !pass_config.runtime_verification; + prepare_tf_pass_options.use_fake_quant_num_bits_ = + toco_flags.use_fake_quant_num_bits(); pass_manager->addNestedPass( - mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul, - /*allow_bf16_and_f16_type_legalization=*/ - !pass_config.runtime_verification, - toco_flags.use_fake_quant_num_bits())); + mlir::TFL::CreatePrepareTFPass(prepare_tf_pass_options)); + pass_manager->addNestedPass( mlir::createCanonicalizerPass()); if (pass_config.shape_inference) { @@ -465,22 +497,42 @@ void AddPostVariableFreezingTFToTFLConversionPasses( pass_manager->addNestedPass( mlir::TF::CreateInitTextFileToImportPass(saved_model_dir.str())); + // Add legalize TF pass to TFL dialect. + mlir::TFL::LegalizeTFPassOptions legalize_tf_pass_options; + legalize_tf_pass_options.run_tfl_runtime_verification_ = + pass_config.runtime_verification; + legalize_tf_pass_options.preserve_assert_op_ = + pass_config.preserve_assert_op; pass_manager->addNestedPass( - mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification, - pass_config.preserve_assert_op)); + mlir::TFL::CreateLegalizeTFPass(legalize_tf_pass_options)); + pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass()); pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass()); pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass()); - if (!pass_config.unfold_batch_matmul) { - // Enable an optimization pass that transforms FC to BatchMatmul only when - // `unfold_batch_matmul=false`. + + mlir::TFL::OptimizePassOptions optimize_pass_options; + optimize_pass_options.disable_fuse_mul_and_fc = + toco_flags.disable_fuse_mul_and_fc(); + + auto add_tfl_optimization_passes = [&]() { + if (!pass_config.unfold_batch_matmul) { + // Enable an optimization pass that transforms FC to BatchMatmul only + // when `unfold_batch_matmul=false`. + pass_manager->addNestedPass( + mlir::TFL::CreateOptimizeBatchMatmulPass()); + } + pass_manager->addPass(mlir::TFL::CreatePushTransposeThroughEwisePass()); + + // Add TFLite optimize pass. pass_manager->addNestedPass( - mlir::TFL::CreateOptimizeBatchMatmulPass()); - } - pass_manager->addPass(mlir::TFL::CreatePushTransposeThroughEwisePass()); - pass_manager->addNestedPass( - mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true, - toco_flags.disable_fuse_mul_and_fc())); + mlir::TFL::CreateOptimizePass(optimize_pass_options)); + }; + + // Run TFL optimization passes set multiple times as op fusion and + // reordering in later passes may enable further optimizations with earlier + // passes. + add_tfl_optimization_passes(); + add_tfl_optimization_passes(); // This pass operates on TensorFlow ops but is triggered after legalization // so that it can target constants introduced once TensorFlow Identity ops diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index 853606c1119436..bf3353e874bc87 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -61,10 +61,12 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/debug/debug.h" +#include "tensorflow/compiler/mlir/lite/experimental/remat/metadata_util.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/metrics/converter_error_data.pb.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" @@ -96,7 +98,6 @@ limitations under the License. #include "tensorflow/core/ir/types/dialect.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/lite/toco/toco_flags.pb.h" -#include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/statusor.h" @@ -107,7 +108,6 @@ using mlir::MLIRContext; using mlir::ModuleOp; using mlir::Operation; using mlir::OwningOpRef; -using ::stablehlo::quantization::QuantizationConfig; using ::tensorflow::quantization::PyFunctionLibrary; bool IsControlFlowV1Op(Operation* op) { @@ -310,13 +310,13 @@ absl::Status ApplyDynamicRangeQuantizationFromOldQuantizer( reinterpret_cast(translated_result.c_str()); const ::tflite::Model* input_model = ::tflite::GetModel(buffer); - ::tflite::optimize::BufferType quantized_type; + mlir::lite::toco_legacy::BufferType quantized_type; switch (quant_specs.inference_type) { case DT_QINT8: - quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8; + quantized_type = mlir::lite::toco_legacy::BufferType::QUANTIZED_INT8; break; case DT_HALF: - quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16; + quantized_type = mlir::lite::toco_legacy::BufferType::QUANTIZED_FLOAT16; break; default: return absl::InvalidArgumentError("Quantized type not supported"); @@ -324,9 +324,10 @@ absl::Status ApplyDynamicRangeQuantizationFromOldQuantizer( } bool use_updated_hybrid_scheme = !quant_specs.disable_per_channel; - absl::Status quantize_weights_status = ::tflite::optimize::QuantizeWeights( - &q_builder, input_model, quantized_type, use_updated_hybrid_scheme, - ::tflite::optimize::QuantizerType::OLD_QUANTIZER); + absl::Status quantize_weights_status = + mlir::lite::toco_legacy::QuantizeWeights( + &q_builder, input_model, quantized_type, use_updated_hybrid_scheme, + mlir::lite::toco_legacy::QuantizerType::OLD_QUANTIZER); if (!quantize_weights_status.ok()) return quantize_weights_status; const uint8_t* q_buffer = q_builder.GetBufferPointer(); *result = diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 94ed4b1e0340a5..3c52ea8b61c235 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -63,6 +63,13 @@ class DefaultQuantParamsPass this->is_signed_ = is_signed; } + explicit DefaultQuantParamsPass( + const DefaultQuantParamsPassOptions &options) { + this->default_min_ = options.default_min_; + this->default_max_ = options.default_max_; + this->is_signed_ = options.is_signed_; + } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(DefaultQuantParamsPass) void runOnOperation() override; @@ -237,6 +244,11 @@ std::unique_ptr> CreateDefaultQuantParamsPass( is_signed); } +std::unique_ptr> CreateDefaultQuantParamsPass( + const DefaultQuantParamsPassOptions &options) { + return std::make_unique(options); +} + std::unique_ptr> CreateDefaultQuantParamsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index 5cac14867482bb..73d102a0502f1f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -24,8 +24,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h" //===----------------------------------------------------------------------===// // The DenseToSparse Pass. @@ -125,8 +125,8 @@ float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, &b_size); if (type.getElementType().isF32()) { - tflite::internal::sparsity::FormatConverter format_converter( - shape, traversal_order, format, b_size, b_map); + tflite_migration::internal::sparsity::FormatConverter + format_converter(shape, traversal_order, format, b_size, b_map); std::vector data; data.reserve(type.getNumElements()); for (const auto val : attr.getValues()) data.push_back(val); @@ -135,8 +135,8 @@ float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, GetSparsity(type.getNumElements() - format_converter.GetData().size(), type.getNumElements()); } else if (type.getElementType().isF16()) { - tflite::internal::sparsity::FormatConverter format_converter( - shape, traversal_order, format, b_size, b_map); + tflite_migration::internal::sparsity::FormatConverter + format_converter(shape, traversal_order, format, b_size, b_map); std::vector data; data.reserve(type.getNumElements()); for (const auto& val : attr.getValues()) @@ -146,8 +146,8 @@ float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, GetSparsity(type.getNumElements() - format_converter.GetData().size(), type.getNumElements()); } else if (mlir::isa(type.getElementType())) { - tflite::internal::sparsity::FormatConverter format_converter( - shape, traversal_order, format, b_size, b_map); + tflite_migration::internal::sparsity::FormatConverter + format_converter(shape, traversal_order, format, b_size, b_map); std::vector data; data.reserve(type.getNumElements()); for (const auto val : attr.getValues()) data.push_back(val); @@ -250,7 +250,7 @@ std::vector BuildSparsityParameterAttribute( PopulateEncodingParams(block_size, &traversal_order, &format, &b_map, &b_size); - tflite::internal::sparsity::FormatConverter format_converter( + tflite_migration::internal::sparsity::FormatConverter format_converter( shape, traversal_order, format, b_size, b_map); format_converter.DenseToSparse(dense_buffer); const auto& metadata = format_converter.GetDimMetadata(); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index a2ec3624064dd4..78f52b63f09243 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include "absl/status/statusor.h" @@ -88,6 +89,11 @@ class LegalizeTFPass : public impl::LegalizeTFPassBase { this->preserve_assert_op_ = preserve_assert_op; } + explicit LegalizeTFPass(const LegalizeTFPassOptions& options) { + this->run_tfl_runtime_verification_ = options.run_tfl_runtime_verification_; + this->preserve_assert_op_ = options.preserve_assert_op_; + } + /// Performs the lowering to TFLite dialect. void runOnOperation() override; }; @@ -1139,6 +1145,12 @@ std::unique_ptr> CreateLegalizeTFPass( preserve_assert_op); } +// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. +std::unique_ptr> CreateLegalizeTFPass( + const LegalizeTFPassOptions& options) { + return std::make_unique(options); +} + std::unique_ptr> CreateLegalizeTFPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 7875eecb17d776..2b5b7537f5154c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -20,25 +20,20 @@ limitations under the License. // be represented using a TensorFlow op. Otherwise, TensorFlow Lite dialect op // is used. -#include #include +#include #include #include -#include "absl/container/inlined_vector.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -57,14 +52,11 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" -#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" @@ -89,25 +81,6 @@ namespace mlir { namespace { -// TODO(b/355062942): This a temporary solution to unblock LLVM intergration. -// https://github.com/llvm/llvm-project/commit/bbd4af5da2b741672a8e6f625eb12ea5c2d6220f -// changed the behavior of `applySignatureConversion`. Before, an op adaptor -// would have the new block arguments directly as operands. Now, there is an -// `UnrealizedConversionCastOp` inserts from the new type to the old type. -// The new behaviour is correct, but passes in this file depended on the old -// bahavior and worked by coincidence. -llvm::SmallVector GetOperandsAndSkipUnrealizedConversionCasts( - ValueRange operands) { - llvm::SmallVector result; - for (Value operand : operands) { - if (auto cast = operand.getDefiningOp()) { - operand = cast.getInputs().front(); - } - result.push_back(operand); - } - return result; -} - /// Lower TensorList ops in functions for subsequent legalization. struct LowerStaticTensorListPass : public impl::LowerStaticTensorListPassBase { @@ -123,6 +96,14 @@ struct LowerStaticTensorListPass this->enable_dynamic_update_slice_ = enable_dynamic_update_slice; } + explicit LowerStaticTensorListPass( + const TFL::LowerStaticTensorListPassOptions &options) { + this->allow_tensorlist_pass_through_ = + options.allow_tensorlist_pass_through_; + this->default_to_single_batch_ = options.default_to_single_batch_; + this->enable_dynamic_update_slice_ = options.enable_dynamic_update_slice_; + } + void runOnOperation() override; }; @@ -371,9 +352,7 @@ struct ConvertTensorListSetItem ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); - + auto operands = adaptor.getOperands(); Value input = operands[0]; Value index = operands[1]; Value item = operands[2]; @@ -433,8 +412,7 @@ struct ConvertTensorListSetItem ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); + auto operands = adaptor.getOperands(); Value input = operands[0]; Value index = operands[1]; Value item = operands[2]; @@ -721,8 +699,7 @@ struct ConvertTensorListPushBack LogicalResult matchAndRewrite( TF::TensorListPushBackOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); + auto operands = adaptor.getOperands(); Value input_handle = operands[0]; Value item = operands[1]; @@ -764,8 +741,7 @@ struct ConvertTensorListResize LogicalResult matchAndRewrite( TF::TensorListResizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); + auto operands = adaptor.getOperands(); Value input_handle = operands[0]; Value size = operands[1]; @@ -929,9 +905,7 @@ struct ConvertTensorListGetItem LogicalResult matchAndRewrite( TF::TensorListGetItemOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); - + auto operands = adaptor.getOperands(); Value input = operands[0]; Value index = operands[1]; rewriter.replaceOpWithNewOp(op, op.getType(), input, index, @@ -948,8 +922,7 @@ struct ConvertTensorListLength TF::TensorListLengthOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value input_handle = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands())[0]; + Value input_handle = adaptor.getOperands()[0]; BoolAttr true_attr = rewriter.getBoolAttr(true); auto shape = rewriter.create(loc, input_handle, @@ -970,8 +943,7 @@ struct ConvertTensorListStack ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); + auto operands = adaptor.getOperands(); Value input = operands[0]; Value element_shape = operands[1]; @@ -1021,8 +993,7 @@ struct ConvertTensorListConcatV2 ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); + auto operands = adaptor.getOperands(); Value input = operands[0]; Value element_shape = operands[1]; @@ -1084,8 +1055,7 @@ struct ConvertIdentity : public OpConversionPattern { LogicalResult matchAndRewrite( TF::IdentityOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value input = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands())[0]; + Value input = adaptor.getOperands()[0]; rewriter.replaceOpWithNewOp(op, input.getType(), input, op->getAttrs()); return success(); @@ -1098,9 +1068,7 @@ struct ConvertReturn : public OpConversionPattern { LogicalResult matchAndRewrite( func::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); - + auto operands = adaptor.getOperands(); rewriter.replaceOpWithNewOp(op, ValueRange{}, operands, op->getAttrs()); return success(); @@ -1113,8 +1081,7 @@ struct ConvertYield : public OpConversionPattern { LogicalResult matchAndRewrite( TF::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto operands = - GetOperandsAndSkipUnrealizedConversionCasts(adaptor.getOperands()); + auto operands = adaptor.getOperands(); rewriter.replaceOpWithNewOp(op, operands); return success(); } @@ -1661,6 +1628,11 @@ std::unique_ptr> TFL::CreateLowerStaticTensorListPass( enable_dynamic_update_slice); } +std::unique_ptr> TFL::CreateLowerStaticTensorListPass( + const LowerStaticTensorListPassOptions &options) { + return std::make_unique(options); +} + std::unique_ptr> TFL::CreateLowerStaticTensorListPass() { return std::make_unique(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 606be04a0f7d6b..b36fe6b55bbd93 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -16,14 +16,14 @@ limitations under the License. // This transformation pass takes operations in TensorFlowLite dialect and // optimizes them to resulting operations in TensorFlowLite dialect. +#include "tensorflow/compiler/mlir/lite/transforms/optimize.h" + #include #include -#include #include #include #include #include -#include #include #include #include @@ -37,7 +37,6 @@ limitations under the License. #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -57,7 +56,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" @@ -73,8 +71,6 @@ namespace TFL { //===----------------------------------------------------------------------===// // The actual Optimize Pass. namespace { -#define GEN_PASS_DEF_OPTIMIZEPASS -#include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" constexpr char kRelu[] = "RELU"; constexpr char kRelu6[] = "RELU6"; @@ -123,26 +119,124 @@ bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) { return true; } -using ::llvm::cast; +// Is rankx2xi32 padding array "balanced" +// i.e. 0 <= [d][1] - [d][0] <= 1 for all spatial dims d (and 0 elsewhere). +template +bool IsBalancedPaddingArray(int spatials_start, int spatials_end, + llvm::ArrayRef data) { + for (int i = 0; i < data.size() / 2; ++i) { + const T pad_low = data[2 * i]; + const T pad_hi = data[2 * i + 1]; + if ((i < spatials_start || i >= spatials_end) && + (pad_low != 0 || pad_hi != 0)) { + return false; + } + const T pad_diff = pad_hi - pad_low; + if (pad_diff > 1 || pad_diff < 0) { + return false; + } + } + return true; +} + +bool IsBalancedPaddingArray(int spatials_start, int spatials_end, + DenseElementsAttr data) { + if (data.isSplat()) { + return false; + } + if (data.getElementType().isInteger(64)) { + return IsBalancedPaddingArray( + spatials_start, spatials_end, + llvm::SmallVector(data.value_begin(), + data.value_end())); + } + if (data.getElementType().isInteger(32)) { + return IsBalancedPaddingArray( + spatials_start, spatials_end, + llvm::SmallVector(data.value_begin(), + data.value_end())); + } + return false; +} -// Optimize TFLite operations in functions. -class OptimizePass : public impl::OptimizePassBase { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass) +bool HasSameStridedDim(int in, int dilate, int stride, int k, int p) { + const int effective_filter = (k - 1) * dilate + 1; + const int out_size = (in + stride - 1) / stride; + const int padding_needed = (out_size - 1) * stride + effective_filter - in; + return padding_needed == p; +} - OptimizePass() = default; - OptimizePass(const OptimizePass &) {} - explicit OptimizePass(bool enable_canonicalization, - bool disable_fuse_mul_and_fc = false) { - this->enable_canonicalization_ = enable_canonicalization; - this->disable_fuse_mul_and_fc_ = disable_fuse_mul_and_fc; +// Is the pre pad shape amenable to given conv with SAME padding. +bool HasSameStridedShape(TFL::Conv2DOp op, ArrayRef pre_pad_shape) { + auto conv_in_shape = + llvm::dyn_cast(op.getInput().getType()).getShape(); + auto kernel_shape = + llvm::dyn_cast(op.getFilter().getType()).getShape(); + if (conv_in_shape.size() != kernel_shape.size()) { + return false; + } + if (conv_in_shape.size() < 3) { + return false; } - void runOnOperation() override; -}; + const int64_t h_pad = conv_in_shape[1] - pre_pad_shape[1]; + const bool h_strided = + HasSameStridedDim(pre_pad_shape[1], op.getDilationHFactor(), + op.getStrideH(), kernel_shape[1], h_pad); + + const int64_t w_pad = conv_in_shape[2] - pre_pad_shape[2]; + const bool w_strided = + HasSameStridedDim(pre_pad_shape[2], op.getDilationWFactor(), + op.getStrideW(), kernel_shape[2], w_pad); + return h_strided && w_strided; +} -// Return true if the product of dimension values of a subsection of the tensor -// is equal to the non-contracting dimension after a reshape +bool HasSameStridedShape(TFL::DepthwiseConv2DOp op, + ArrayRef pre_pad_shape) { + auto conv_in_shape = + llvm::dyn_cast(op.getInput().getType()).getShape(); + auto kernel_shape = + llvm::dyn_cast(op.getFilter().getType()).getShape(); + + const int64_t h_pad = conv_in_shape[1] - pre_pad_shape[1]; + const bool h_strided = + HasSameStridedDim(pre_pad_shape[1], op.getDilationHFactor(), + op.getStrideH(), kernel_shape[1], h_pad); + + const int64_t w_pad = conv_in_shape[2] - pre_pad_shape[2]; + const bool w_strided = + HasSameStridedDim(pre_pad_shape[2], op.getDilationWFactor(), + op.getStrideW(), kernel_shape[2], w_pad); + return h_strided && w_strided; +} + +bool HasSameStridedShape(TFL::Conv3DOp op, ArrayRef pre_pad_shape) { + auto conv_in_shape = + llvm::dyn_cast(op.getInput().getType()).getShape(); + auto kernel_shape = + llvm::dyn_cast(op.getFilter().getType()).getShape(); + + const int64_t d_pad = conv_in_shape[1] - pre_pad_shape[1]; + const bool d_strided = + HasSameStridedDim(pre_pad_shape[1], op.getDilationDFactor(), + op.getStrideD(), kernel_shape[0], d_pad); + + const int64_t h_pad = conv_in_shape[2] - pre_pad_shape[2]; + const bool h_strided = + HasSameStridedDim(pre_pad_shape[2], op.getDilationHFactor(), + op.getStrideH(), kernel_shape[1], h_pad); + + const int64_t w_pad = conv_in_shape[3] - pre_pad_shape[3]; + const bool w_strided = + HasSameStridedDim(pre_pad_shape[3], op.getDilationWFactor(), + op.getStrideW(), kernel_shape[2], w_pad); + return h_strided && w_strided && d_strided; +} + +using ::llvm::cast; + +// Return true if the product of dimension values of a subsection of the +// tensor is equal to the non-contracting dimension after a reshape bool BroadcastDimsProductEqual(Value input, Value output, size_t agg_start_idx) { ArrayRef input_shape = @@ -231,17 +325,16 @@ bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef filter_shape, } auto elements_depth = elements_shape.empty() ? 1 : elements_shape.back(); - // If elements depth equals 1 (i.e., scalar or tensor with 1 element), then we - // can let binary op to broadcast elements. + // If elements depth equals 1 (i.e., scalar or tensor with 1 element), then + // we can let binary op to broadcast elements. if (elements_depth == 1) { return true; } - // In TFLite Conv2D uses OHWI format for filter, and 1HWO for Depthwise Conv. - // For conv: - // Check if last dimension in filter equals the first dimension - // For depthwise conv: - // Check if the first in filter dimension equals the first dimension. + // In TFLite Conv2D uses OHWI format for filter, and 1HWO for Depthwise + // Conv. For conv: Check if last dimension in filter equals the first + // dimension For depthwise conv: Check if the first in filter dimension + // equals the first dimension. if (filter_shape.empty() || (is_depthwise ? filter_shape.back() != elements_depth : filter_shape[0] != elements_depth)) @@ -275,15 +368,15 @@ bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, } // Returns true if we can eliminate the GatherNdOp or ScatterNdOp. When the -// value of `indices` are from 0 to n-1, the output tensor are identical to the -// `params`. +// value of `indices` are from 0 to n-1, the output tensor are identical to +// the `params`. bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, DenseIntElementsAttr indices, Type output_type) { auto params_type = mlir::dyn_cast(params.getType()); auto indices_type = mlir::dyn_cast(indices.getType()); - // Checks the shape of `params` is [n, ...], shape of `indices` is [n, 1]. 2D - // `indices` means it gets the first row of `params`. As long as indices + // Checks the shape of `params` is [n, ...], shape of `indices` is [n, 1]. + // 2D `indices` means it gets the first row of `params`. As long as indices // iterate the first row of `params`, the output is identical to input. if (!params_type || !indices_type || indices_type.getRank() != 2 || indices_type.getDimSize(0) != params_type.getDimSize(0) || @@ -304,9 +397,9 @@ bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, return true; } -// Returns true if we can eliminate the SliceOp. When the values of `begin` are -// all 0s and `size[i]` is equal to either -1 or `input.shape[i]` -// for each dim i, the output tensor is identical to `input`. +// Returns true if we can eliminate the SliceOp. When the values of `begin` +// are all 0s and `size[i]` is equal to either -1 or `input.shape[i]` for each +// dim i, the output tensor is identical to `input`. bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { // Checks if `begin` and `size` are i32 or i64. auto begin_attr = mlir::dyn_cast(begin); @@ -324,8 +417,8 @@ bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { return false; } - // Checks if `input` is ranked and its rank is equal to number of elements in - // `begin` and `size`. + // Checks if `input` is ranked and its rank is equal to number of elements + // in `begin` and `size`. auto input_ty = mlir::cast(input.getType()); if (!input_ty.hasRank()) { return false; @@ -380,8 +473,8 @@ TypeAttr RescaleQtype(Type input, Attribute factor) { return quant::RescaleQuantizedType(input, factor); } -// Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in -// the specified `shape` and `false` otherwise. +// Returns `true` if reducing `axes` in `input` with `keep_dims=true` results +// in the specified `shape` and `false` otherwise. static bool ShapeMatchesReduceWithKeepAxes(Value input, const mlir::Attribute &axes, const mlir::Attribute &shape) { @@ -481,8 +574,8 @@ bool IsF32Value(Value value) { return mlir::cast(value.getType()).getElementType().isF32(); } -// Returns the number of elements in attr if it is a static shape, 1 otherwise, -// as an unranked int32 Attribute. +// Returns the number of elements in attr if it is a static shape, 1 +// otherwise, as an unranked int32 Attribute. TypedAttr GetNumElementsOrOne(Type type) { auto shaped_type = mlir::cast(type); int32_t num_elements = @@ -497,8 +590,8 @@ TypedAttr GetNumElementsOrOne(Type type) { // Reshapes value to a given shape. Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value) { - // This function is always guarded with HasTrivialShapeExceptSecondLastDim(), - // so we could cast safely here. + // This function is always guarded with + // HasTrivialShapeExceptSecondLastDim(), so we could cast safely here. auto type = mlir::cast(value.getType()); SmallVector new_shape; if (type.hasStaticShape()) { @@ -548,10 +641,11 @@ bool HasOneUseOrUsedByOnlyBinaryOps(Value out_value) { return true; } -// Returns true if attr is a DenseIntElementsAttr of int32 or int64 values or an -// incrementing sequence from 0 to N-1. +// Returns true if attr is a DenseIntElementsAttr of int32 or int64 values or +// an incrementing sequence from 0 to N-1. // -// If such a value is used in an Equal operator, it can be replaced with OneHot. +// If such a value is used in an Equal operator, it can be replaced with +// OneHot. bool IsOneHotIndexAttribute(Attribute attr) { const auto dense_attr = mlir::dyn_cast_or_null(attr); if (!dense_attr) { @@ -638,8 +732,8 @@ bool IsF32Splat(Attribute input_splat) { } // Converts an Attribute with a single value of float or integral type to an -// Attribute holding a single value of float type. If attr has no elements, the -// result is 0.0f. +// Attribute holding a single value of float type. If attr has no elements, +// the result is 0.0f. TypedAttr ConvertSingleElementAttrToFloatAttr(Attribute attr) { const auto dense_fp_attr = mlir::dyn_cast_or_null(attr); if (dense_fp_attr) { @@ -2520,6 +2614,7 @@ void AddCanonicalizationPatterns(MLIRContext *context, for (auto op : context->getRegisteredOperations()) op.getCanonicalizationPatterns(*patterns, context); } +} // namespace void OptimizePass::runOnOperation() { RewritePatternSet patterns(&getContext()); @@ -2576,13 +2671,11 @@ void OptimizePass::runOnOperation() { AddCanonicalizationPatterns(ctx, &phase_2_patterns); (void)applyPatternsAndFoldGreedily(func, std::move(phase_2_patterns)); } -} // namespace // Creates an instance of the TensorFlow Lite dialect Optimize pass. std::unique_ptr> CreateOptimizePass( - bool enable_canonicalization, bool disable_fuse_mul_and_fc) { - return std::make_unique(enable_canonicalization, - disable_fuse_mul_and_fc); + const OptimizePassOptions &options) { + return std::make_unique(options); } std::unique_ptr> CreateOptimizePass() { diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.h b/tensorflow/compiler/mlir/lite/transforms/optimize.h new file mode 100644 index 00000000000000..477d2d23d7ad07 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.h @@ -0,0 +1,95 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_H_ + +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" + +namespace mlir { +namespace TFL { + +struct OptimizePassOptions { + bool enable_canonicalization = true; + bool disable_fuse_mul_and_fc = false; +}; + +class OptimizePass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass) + + OptimizePass() = default; + OptimizePass(const OptimizePass &) {} + explicit OptimizePass(bool enable_canonicalization, + bool disable_fuse_mul_and_fc = false) { + this->enable_canonicalization_ = enable_canonicalization; + this->disable_fuse_mul_and_fc_ = disable_fuse_mul_and_fc; + } + + explicit OptimizePass(const OptimizePassOptions &options) { + this->enable_canonicalization_ = options.enable_canonicalization; + this->disable_fuse_mul_and_fc_ = options.disable_fuse_mul_and_fc; + } + + void runOnOperation() final; + + /// Returns the command-line argument attached to this pass. + static constexpr llvm::StringLiteral getArgumentName() { + return llvm::StringLiteral("tfl-optimize"); + } + llvm::StringRef getArgument() const final { return "tfl-optimize"; } + + llvm::StringRef getDescription() const final { + return "Optimize within the TensorFlow Lite dialect"; + } + + /// Returns the derived pass name. + static constexpr llvm::StringLiteral getPassName() { + return llvm::StringLiteral("OptimizePass"); + } + llvm::StringRef getName() const final { return "OptimizePass"; } + + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(mlir::DialectRegistry ®istry) const final { + registry.insert(); + } + + private: + mlir::Pass::Option enable_canonicalization_{ + *this, "enable-canonicalization", + llvm::cl::desc("Enable canonicalization during optimization pass."), + llvm::cl::init(true)}; + mlir::Pass::Option disable_fuse_mul_and_fc_{ + *this, "disable-fuse-mul-and-fc", + llvm::cl::desc("Disable folding mul and fully connected ops during " + "optimization pass."), + llvm::cl::init(false)}; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 121322c761753f..f63c3e92af59a9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -22,6 +22,7 @@ include "mlir/Dialect/Func/IR/FuncOps.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/lite/utils/utils.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td" // Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< @@ -1651,3 +1652,123 @@ def FuseSliceAndPack : Pat<( (Arith_ConstantOp ConstantAttr,"1">)), $_) ), $_, $_), (replaceWithValue $input0), [(IsSame $input0, $input1), (IsSame $input0, $input2)]>; + +// Given a value, checks if dim `d` is static. +class HasStaticDim : Constraint().isDynamicDim(" # d # ")">>; + +class IsBalancedPaddingArray : + Constraint())">>; + +// Given in_shape, out_shape, stride checks ceil(in_shape[d] / stride) == out_shape[d] +def IsSameStridedShape2D : Constraint()," + "$1.getType().cast().getShape())">>; + +def IsSameStridedShapeDepthwise : Constraint()," + "$1.getType().cast().getShape())">>; + +def IsSameStridedShape3D : Constraint()," + "$1.getType().cast().getShape())">>; + +def IsValidPadding : Constraint>; + +// Fuse explicit tfl.pad ops into standard convolutions when it implies "SAME" +// padding. "SAME" padding is defined to be any non-trivial padding where +// ceil(in_dim_i / stride_i) == out_dim_i +// and 0 <= (pad_i_hi - pad_i_lo) <= 1 for all spatial dims i. + +def FuseSamePaddingConv2D : Pat< + (TFL_Conv2DOp:$conv_out + (TFL_PadOp $input, (Arith_ConstantOp $paddings)), + $filter, + $bias, + $h_dilate, + $w_dilate, + $faf, + $padding, + $stride_h, + $stride_w + ), (TFL_Conv2DOp + $input, + $filter, + $bias, + $h_dilate, + $w_dilate, + $faf, + TFL_PAD_Same, + $stride_h, + $stride_w + ), + [(HasStaticDim<1> $input), + (HasStaticDim<2> $input), + (IsBalancedPaddingArray<1, 3> $paddings), + (IsValidPadding $padding), + (IsSameStridedShape2D $conv_out, $input)]>; + +def FuseSamePaddingDepthwiseConv : Pat< + (TFL_DepthwiseConv2DOp:$conv_out + (TFL_PadOp $input, (Arith_ConstantOp $paddings)), + $filter, + $bias, + $h_dilate, + $w_dilate, + $faf, + $padding, + $stride_h, + $stride_w, + $depth + ), (TFL_DepthwiseConv2DOp + $input, + $filter, + $bias, + $h_dilate, + $w_dilate, + $faf, + TFL_PAD_Same, + $stride_h, + $stride_w, + $depth + ), + [(HasStaticDim<1> $input), + (HasStaticDim<2> $input), + (IsBalancedPaddingArray<1, 3> $paddings), + (IsValidPadding $padding), + (IsSameStridedShapeDepthwise $conv_out, $input)]>; + +def FuseSamePaddingConv3D : Pat< + (TFL_Conv3DOp:$conv_out + (TFL_PadOp $input, (Arith_ConstantOp $paddings)), + $filter, + $bias, + $d_dilate, + $h_dilate, + $w_dilate, + $faf, + $padding, + $stride_d, + $stride_h, + $stride_w + ), (TFL_Conv3DOp + $input, + $filter, + $bias, + $d_dilate, + $h_dilate, + $w_dilate, + $faf, + TFL_PAD_Same, + $stride_d, + $stride_h, + $stride_w + ), + [(HasStaticDim<1> $input), + (HasStaticDim<2> $input), + (HasStaticDim<3> $input), + (IsBalancedPaddingArray<1, 4> $paddings), + (IsValidPadding $padding), + (IsSameStridedShape3D $conv_out, $input)]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index b81ffa3fe1c7e0..a11b20000222ad 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -21,6 +21,8 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/optimize.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" namespace mlir { @@ -60,8 +62,6 @@ std::unique_ptr> CreateLegalizeTFPass( std::unique_ptr> CreateLegalizeTFPass(); // Creates an instance of the TensorFlow Lite dialect Optimize pass. -std::unique_ptr> CreateOptimizePass( - bool enable_canonicalization, bool disable_fuse_mul_and_fc = false); std::unique_ptr> CreateOptimizePass(); // Creates an instance of the Tensorflow Lite batch matmul Optimize pass. @@ -248,6 +248,7 @@ CreatePartitionedTopologicalSortPass(); #define GEN_PASS_DECL_DEFAULTQUANTPARAMSPASS #define GEN_PASS_DECL_DENSETOSPARSEPASS #define GEN_PASS_DECL_LEGALIZETFPASS +#define GEN_PASS_DECL_LOWERSTATICTENSORLISTPASS #define GEN_PASS_DECL_MODIFYIONODESPASS #define GEN_PASS_DECL_OPTIMIZEPASS #define GEN_PASS_DECL_POSTQUANTIZEPASS @@ -260,6 +261,43 @@ CreatePartitionedTopologicalSortPass(); #define GEN_PASS_DECL_TRIMFUNCTIONSPASS #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" + +// Creates an instance of the TensorFlow Lite dialect LegalizeTF pass. +std::unique_ptr> CreateLegalizeTFPass( + const LegalizeTFPassOptions& options); + +// Creates an instance of the TensorFlow Lite dialect Optimize pass. +std::unique_ptr> CreateOptimizePass( + const OptimizePassOptions& options); + +// Creates an instance of the TensorFlow Lite dialect PrepareTF pass. +std::unique_ptr> CreatePrepareTFPass( + const PrepareTFPassOptions& options); + +// Creates an instance of the TensorFlow Lite dialect LowerStaticTensorList +// pass. +std::unique_ptr> CreateLowerStaticTensorListPass( + const LowerStaticTensorListPassOptions& options); + +// Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp +std::unique_ptr> CreateRaiseCustomOpsPass( + const RaiseCustomOpsPassOptions& options); + +// Creates an instance of the TensorFlow Lite dialect pass to add default +// quantization parameters. +std::unique_ptr> CreateDefaultQuantParamsPass( + const DefaultQuantParamsPassOptions& options); + +inline void registerOptimizePass() { + mlir::registerPass( + []() -> std::unique_ptr<::mlir::Pass> { return CreateOptimizePass(); }); +} + +inline void registerTensorFlowLitePasses() { + registerTensorFlowLiteTdPasses(); + registerOptimizePass(); +} + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 1d8403e32d6f56..8db083175a2226 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -147,10 +147,7 @@ def LiftTfliteFlexOpsPass : Pass<"tfl-lift-tflite-flex-ops", "mlir::func::FuncOp def LowerStaticTensorListPass : Pass<"tfl-lower-static-tensor-list", "mlir::ModuleOp"> { let summary = "Lower TensorList ops within TensorFlow Lite dialect."; let constructor = "CreateLowerStaticTensorListPass()"; - let dependentDialects = ["TFL::TensorFlowLiteDialect", - "quant::QuantizationDialect", - "quantfork::QuantizationForkDialect" - ]; + let dependentDialects = ["TFL::TensorFlowLiteDialect"]; let options = [ Option<"allow_tensorlist_pass_through_", "allow-tensorlist-pass-through", "bool", "false", @@ -183,20 +180,6 @@ def ModifyIONodesPass : Pass<"tfl-modify-io-nodes", "mlir::func::FuncOp"> { ]; } -def OptimizePass : Pass<"tfl-optimize", "mlir::func::FuncOp"> { - let summary = "Optimize within the TensorFlow Lite dialect"; - let constructor = "CreateOptimizePass()"; - let dependentDialects = ["TFL::TensorFlowLiteDialect"]; - let options = [ - Option<"enable_canonicalization_", "enable-canonicalization", - "bool", "false", - "Enable canonicalization during optimization pass.">, - Option<"disable_fuse_mul_and_fc_", "disable-fuse-mul-and-fc", - "bool", "false", - "Disable folding mul and fully connected ops during optimization pass.">, - ]; -} - def OptimizeBatchMatmulPass : Pass<"tfl-optimize-batch-matmul", "mlir::func::FuncOp"> { let summary = "Optimize FC with BatchMatmul within the TensorFlow Lite dialect"; let constructor = "CreateOptimizeBatchMatmulPass()"; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index d1bb76fa431ee2..f323b34df278f1 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -31,6 +31,7 @@ limitations under the License. #include #include +#include #include #include "absl/algorithm/container.h" @@ -123,6 +124,13 @@ class PrepareTFPass : public impl::PrepareTFPassBase { this->use_fake_quant_num_bits_ = use_fake_quant_num_bits; } + explicit PrepareTFPass(const PrepareTFPassOptions &options) { + this->unfold_batch_matmul_ = options.unfold_batch_matmul_; + this->allow_bf16_and_f16_type_legalization_ = + options.allow_bf16_and_f16_type_legalization_; + this->use_fake_quant_num_bits_ = options.use_fake_quant_num_bits_; + } + void runOnOperation() override; }; @@ -1571,6 +1579,12 @@ std::unique_ptr> CreatePrepareTFPass( use_fake_quant_num_bits); } +// Creates an instance of the TensorFlow Lite dialect PrepareTF pass. +std::unique_ptr> CreatePrepareTFPass( + const PrepareTFPassOptions &options) { + return std::make_unique(options); +} + // Creates an instance of the TensorFlow Lite dialect PrepareTF pass. std::unique_ptr> CreatePrepareTFPass() { return std::make_unique(); diff --git a/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc b/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc index 7a8b35e4be7cde..f01d8aa737d2e0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc +++ b/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc @@ -72,7 +72,7 @@ llvm::SmallVector PermuteShape(llvm::ArrayRef shape, // Determine if op commutes with transposes. Requires a strict // definition of Elementwise, all i/o shapes and types must be same-rank -// broadcastable and fully static. Consider moving this into attribute later. +// broadcastable. Consider moving this into attribute later. bool IsElementwise(Operation *op) { if (!(llvm::isa(op))) { @@ -90,11 +90,6 @@ bool IsElementwise(Operation *op) { return false; } - if (!opr1_type.hasStaticShape() && opr2_type.hasStaticShape() && - res_type.hasStaticShape()) { - return false; - } - return true; } diff --git a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc index ce2d51a66e5097..25d9b15fec858a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/raise_custom_ops.cc @@ -17,9 +17,7 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project @@ -30,7 +28,6 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { namespace TFL { @@ -50,6 +47,10 @@ struct RaiseCustomOpsPass this->target_ops_ = target_ops; } + explicit RaiseCustomOpsPass(const RaiseCustomOpsPassOptions &options) { + this->target_ops_ = options.target_ops_; + } + void runOnOperation() override; }; @@ -113,6 +114,11 @@ std::unique_ptr> CreateRaiseCustomOpsPass( return std::make_unique(target_ops); } +std::unique_ptr> CreateRaiseCustomOpsPass( + const RaiseCustomOpsPassOptions &options) { + return std::make_unique(options); +} + static PassRegistration pass; } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc index 1629000ff181df..613dfa8878d549 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index 21da9d071e906c..0a450f9c28152b 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -45,6 +45,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/string_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tsl/platform/statusor.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h index 6834e0542a0a90..135ddb1faef32e 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/meta/type_traits.h" #include "absl/status/statusor.h" #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index 6a4dbf3e505ba6..76b00825628b2e 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project @@ -35,8 +36,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" -#include "tsl/platform/statusor.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.h b/tensorflow/compiler/mlir/lite/utils/constant_utils.h index 99a2f610b82ad6..1340aa0f5bebee 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.h @@ -16,8 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_CONSTANT_UTILS_H_ +#include "absl/status/statusor.h" +#include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index 1edb08cff57423..b118bab483048a 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -16,9 +16,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "absl/status/statusor.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" diff --git a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc index 28c6106dcb7c34..5e070a04b1e11e 100644 --- a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc @@ -18,8 +18,15 @@ limitations under the License. #include #include -#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" diff --git a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h index d1dcf8c304b0a9..146cae1f2c4770 100644 --- a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h @@ -21,12 +21,19 @@ limitations under the License. #include #include +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" +#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index bada49a68a9e55..940d30c9c7929b 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -20,17 +20,15 @@ limitations under the License. #include #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project @@ -41,6 +39,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h index 9e01b5dbf75e5f..f2266f8920669a 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.h @@ -26,8 +26,10 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 7fe7ae8404137c..f85ea68d621ef6 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -23,13 +23,11 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -38,7 +36,8 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/core/platform/test.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc index cab3df456c0e00..6677f57c6fdd0d 100644 --- a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc @@ -17,9 +17,14 @@ limitations under the License. #include +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.h b/tensorflow/compiler/mlir/lite/utils/nms_utils.h index a0739ea10b25d4..e3487ba9a2ec14 100644 --- a/tensorflow/compiler/mlir/lite/utils/nms_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.h @@ -24,6 +24,7 @@ limitations under the License. #include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc index 3a3d20eff2d2c4..5f680c7db9be58 100644 --- a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc @@ -16,8 +16,11 @@ limitations under the License. #include +#include "flatbuffers/base.h" // from @flatbuffers +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc index 4448be5cdd4df6..650c372e42b2b4 100644 --- a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils_test.cc @@ -18,16 +18,18 @@ limitations under the License. #include #include -#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/core/platform/test.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/size_utils.cc b/tensorflow/compiler/mlir/lite/utils/size_utils.cc index a5ffb64eaf4300..9793097ee4362f 100644 --- a/tensorflow/compiler/mlir/lite/utils/size_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/size_utils.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/size_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/size_utils_test.cc index 49c3fc70cd0381..07175a653b7b53 100644 --- a/tensorflow/compiler/mlir/lite/utils/size_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/size_utils_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/size_utils.h" -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "tensorflow/core/platform/test.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc index 2eead55a51415f..f43ed208e69b8a 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.cc @@ -17,8 +17,9 @@ limitations under the License. #include -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h index b78f7c86e45436..e7e3e721ae8bac 100644 --- a/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index 5e9bcc16d27537..24314630e65154 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -23,16 +23,12 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project @@ -40,7 +36,9 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/ir/types/dialect.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc index 5138d7475452cd..2acb4dccb88a18 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils_test.cc @@ -18,8 +18,11 @@ limitations under the License. #include #include -#include "tensorflow/core/lib/core/status_test_util.h" +#include "absl/status/status.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/platform/test.h" +#include "tsl/platform/status.h" namespace mlir { namespace TFL { diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc index 902d7b144ba69d..536762c3e44292 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.cc +++ b/tensorflow/compiler/mlir/lite/utils/validators.cc @@ -18,8 +18,10 @@ limitations under the License. #include #include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index 0e7370c5fa499b..86306ab5a454ce 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -20,7 +20,12 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VALIDATORS_H_ #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/lite/utils/variables_utils.cc b/tensorflow/compiler/mlir/lite/utils/variables_utils.cc index fbe0c1822c0b32..dca30aee7fe606 100644 --- a/tensorflow/compiler/mlir/lite/utils/variables_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/variables_utils.cc @@ -14,8 +14,11 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/lite/utils/variables_utils.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" diff --git a/tensorflow/compiler/mlir/lite/utils/variables_utils.h b/tensorflow/compiler/mlir/lite/utils/variables_utils.h index 2dd972495f8023..570f9afdb20d69 100644 --- a/tensorflow/compiler/mlir/lite/utils/variables_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/variables_utils.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VARIABLES_UTILS_H_ #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_VARIABLES_UTILS_H_ +#include "mlir/IR/AffineMap.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project namespace mlir { diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index cf8d4487f2f593..8d9802deaeaa66 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -360,7 +360,7 @@ Status MlirFunctionOptimizationPass::Run( timings.Reset({kTfMlirCategory, "convert_mlir_to_graph"}); // Some or all passes are enabled. Convert MLIR module and return back // resulted graph. - Status status = tensorflow::tf2xla::v2::ConvertMlirToGraph( + Status status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( *module_ref, export_config, graph, flib_def, &control_ret_nodes); if (!status.ok()) { errors::AppendToMessage(&status, @@ -476,10 +476,11 @@ Status MlirV1CompatGraphOptimizationPass::Run( GraphExportConfig export_config; absl::flat_hash_set control_ret_nodes; - TF_RETURN_WITH_CONTEXT_IF_ERROR(tensorflow::tf2xla::v2::ConvertMlirToGraph( - *module_ref, export_config, options.graph, - options.flib_def, &control_ret_nodes), - "Error converting MLIR module back to graph"); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( + *module_ref, export_config, options.graph, options.flib_def, + &control_ret_nodes), + "Error converting MLIR module back to graph"); return absl::OkStatus(); } diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD index ac10197df7d2c0..b6b1d17d17a4a7 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD @@ -31,11 +31,11 @@ cc_library( ":quantization_config", ":quantization_interfaces_inc_gen", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:portable_tensor_utils", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/tools/optimize:quantization_utils", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", - "//tensorflow/lite/kernels/internal:tensor_utils", - "//tensorflow/lite/tools/optimize:quantization_utils", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc index 8e5496106c5279..5223b6200fb5a8 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc @@ -47,11 +47,11 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" #include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" #include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" -#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" -#include "tensorflow/lite/tools/optimize/quantization_utils.h" +#include "tensorflow/compiler/mlir/tools/optimize/quantization_utils.h" namespace mlir { @@ -580,7 +580,7 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, std::vector quantized_values(real_values_attr.getNumElements()); if (auto uniform_type = dyn_cast(q_type)) { float min, max, scale; - tflite::tensor_utils::SymmetricQuantizeFloats( + mlir::lite::toco_legacy::PortableSymmetricQuantizeFloats( real_values.data(), real_values.size(), quantized_values.data(), &min, &max, &scale); // The scale has been adjusted, so the adjusted scale should be respected. @@ -598,7 +598,7 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, std::back_inserter(scales_inv), [](float scale) { return 1.0 / scale; }); - tflite::optimize::utils::SymmetricPerChannelQuantizeValues( + tflite_migration::optimize::utils::SymmetricPerChannelQuantizeValues( real_values.data(), scales_inv, dimension, uniform_type.getQuantizedDimension(), &quantized_values); } else { @@ -619,7 +619,7 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, } else if (width == 16) { if (const auto uniform_type = dyn_cast(q_type)) { const auto quantized_values = - tflite::optimize::utils::SymmetricQuantizeFloatsToInt16( + tflite_migration::optimize::utils::SymmetricQuantizeFloatsToInt16( real_values.data(), real_values.size(), uniform_type.getScale()); std::transform(quantized_values.begin(), quantized_values.end(), std::back_inserter(quantized_attr), @@ -640,7 +640,7 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, return {}; } const auto quantized_bias = - tflite::optimize::utils::SymmetricBiasQuantize( + tflite_migration::optimize::utils::SymmetricBiasQuantize( real_values.data(), real_values.size(), scales); std::transform(quantized_bias.begin(), quantized_bias.end(), std::back_inserter(quantized_attr), diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index 6f4c2dd637a297..1bbf67389366f5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -129,7 +129,7 @@ void AddShapeLegalizationPasses(OpPassManager& pm) { void AddStablehloQuantToIntPasses(OpPassManager& pm) { pm.addNestedPass( - mlir::stablehlo::createStablehloLegalizeQuantToIntPass()); + mlir::stablehlo::createStablehloLegalizeQuantToMathPass()); // StableHLO -> MHLO legalization. pm.addPass(mhlo::createStablehloLegalizeToHloPass()); pm.addNestedPass(createCanonicalizerPass()); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc index d6859b7b95c84e..491fcb9f5e7946 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/saved_model_export.cc @@ -240,7 +240,7 @@ absl::StatusOr ConvertMlirModuleToExportedModel( FunctionDefLibrary()}; std::unique_ptr graph; absl::flat_hash_set control_ret_nodes{}; - TF_RETURN_IF_ERROR(tensorflow::tf2xla::v2::ConvertMlirToGraph( + TF_RETURN_IF_ERROR(tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module_op, config, &graph, &flib_def, &control_ret_nodes)); GraphDef graph_def{}; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc index 318ccbd5300d81..e71af87690a287 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types_test.cc @@ -27,9 +27,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace mlir::quant::stablehlo { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc index 8cbd48d29715cc..2e9d7a8794bb6f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/legalize_tf_quant_test.cc @@ -25,12 +25,12 @@ limitations under the License. #include "xla/shape.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tsl/lib/core/status_test_util.h" namespace mlir::quant::stablehlo { namespace { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc index e96ab83369588e..d8878ff8149356 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.cc @@ -31,7 +31,7 @@ void AddQuantizationLoweringPasses(mlir::OpPassManager& pm) { pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addPass(mhlo::createHloLegalizeToStablehloPass()); pm.addNestedPass( - mlir::stablehlo::createStablehloLegalizeQuantToIntPass()); + mlir::stablehlo::createStablehloLegalizeQuantToMathPass()); pm.addPass(mhlo::createStablehloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(CreateVerifyQuantLegalizationPass()); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td index 75940a24cf484f..5f6449dbfa03bb 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.td @@ -30,14 +30,15 @@ include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" def LiftDotGeneralWithBiasSameShape : Pat< (StableHLO_AddOp:$res (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), $bias), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>; def LiftConvWithBiasSameShape : Pat< @@ -86,14 +87,15 @@ def LiftConvWithBias : Pat< def LiftDotGeneralWithBias : Pat< (StableHLO_AddOp:$res (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_BroadcastInDimOp $bias, $dims)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_fn"> (ArgumentList $lhs, $rhs, $bias), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias)], [], (addBenefit 5)>; def LiftConvWithBiasDynamic : Pat< @@ -121,7 +123,7 @@ def LiftConvWithBiasDynamic : Pat< def LiftDotGeneralWithBiasDynamic : Pat< (StableHLO_AddOp:$res - (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_DynamicBroadcastInDimOp $bias, (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), @@ -130,7 +132,8 @@ def LiftDotGeneralWithBiasDynamic : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 10)>; //===----------------------------------------------------------------------===// @@ -161,14 +164,15 @@ def LiftConvWithRelu : Pat< def LiftDotGeneralWithRelu : Pat< (StableHLO_MaxOp:$res (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_ConstantOp $cst)), (LiftAsTFXlaCallModule<"composite_dot_general_with_relu_fn"> (ArgumentList $lhs, $rhs), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst)], [], (addBenefit 10)>; @@ -198,7 +202,7 @@ def LiftConvWithReluDynamic : Pat< def LiftDotGeneralWithReluDynamic : Pat< (StableHLO_MaxOp:$res - (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_DynamicBroadcastInDimOp (StableHLO_ConstantOp $cst), (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), @@ -207,7 +211,8 @@ def LiftDotGeneralWithReluDynamic : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 15)>; @@ -237,14 +242,15 @@ def LiftDotGeneralWithRelu6 : Pat< (StableHLO_ClampOp:$res (StableHLO_ConstantOp $cst_0), (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_dot_general_with_relu6_fn"> (ArgumentList $lhs, $rhs), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>; //===----------------------------------------------------------------------===// @@ -255,7 +261,7 @@ def LiftDotGeneralWithBiasSameShapeAndRelu : Pat< (StableHLO_MaxOp:$res (StableHLO_AddOp (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), $bias), (StableHLO_ConstantOp $cst)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_and_relu_fn"> @@ -263,7 +269,8 @@ def LiftDotGeneralWithBiasSameShapeAndRelu : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; @@ -320,7 +327,7 @@ def LiftDotGeneralWithBiasAndRelu : Pat< (StableHLO_MaxOp:$res (StableHLO_AddOp (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_BroadcastInDimOp $bias, $dims)), (StableHLO_ConstantOp $cst)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu_fn"> @@ -328,7 +335,8 @@ def LiftDotGeneralWithBiasAndRelu : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias)], [], (addBenefit 10)>; @@ -363,7 +371,7 @@ def LiftConvWithBiasAndReluDynamic : Pat< def LiftDotGeneralWithBiasAndReluDynamic : Pat< (StableHLO_MaxOp:$res (StableHLO_AddOp:$add_0 - (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (StableHLO_DotGeneralOp:$dot_general_0 $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_DynamicBroadcastInDimOp $bias, (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), @@ -375,7 +383,8 @@ def LiftDotGeneralWithBiasAndReluDynamic : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (FloatValueEquals<"0"> $cst), (IsStableHLOConstantOp $bias), (AreTheSameValue $dot_general_0, $dot_general_1), (AreTheSameValue $add_0, $add_1)], [], (addBenefit 15)>; @@ -384,7 +393,7 @@ def LiftDotGeneralWithBiasSameShapeAndRelu6 : Pat< (StableHLO_ConstantOp $cst_0), (StableHLO_AddOp (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), $bias), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_same_shape_and_relu6_fn"> @@ -392,7 +401,8 @@ def LiftDotGeneralWithBiasSameShapeAndRelu6 : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>; def LiftConvWithBiasAndRelu6 : Pat< @@ -424,7 +434,7 @@ def LiftDotGeneralWithBiasAndRelu6 : Pat< (StableHLO_ConstantOp $cst_0), (StableHLO_AddOp (StableHLO_DotGeneralOp - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_BroadcastInDimOp $bias, $dims)), (StableHLO_ConstantOp $cst_1)), (LiftAsTFXlaCallModule<"composite_dot_general_with_bias_and_relu6_fn"> @@ -432,7 +442,8 @@ def LiftDotGeneralWithBiasAndRelu6 : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1)], [], (addBenefit 10)>; def LiftConvWithBiasAndRelu6Dynamic : Pat< @@ -466,7 +477,7 @@ def LiftDotGeneralWithBiasAndRelu6Dynamic : Pat< (StableHLO_ConstantOp $cst_0), (StableHLO_AddOp (StableHLO_DotGeneralOp:$dot_general_0 - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (StableHLO_DynamicBroadcastInDimOp $bias, (Shape_ShapeOfOp $dot_general_1), $_, $_, $_)), @@ -476,5 +487,6 @@ def LiftDotGeneralWithBiasAndRelu6Dynamic : Pat< (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res), (IsStableHLOConstantOp $bias), (FloatValueEquals<"0"> $cst_0), (FloatValueEquals<"6"> $cst_1), (AreTheSameValue $dot_general_0, $dot_general_1)], [], (addBenefit 15)>; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td index eaa8a9092f41f2..db0103fea2b7e5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.td @@ -47,13 +47,14 @@ def LiftConv : Pat< def LiftDotGeneral : Pat< (StableHLO_DotGeneralOp:$res - $lhs, $rhs, $dot_dimension_numbers, $precision_config), + $lhs, $rhs, $dot_dimension_numbers, $precision_config, $algorithm), (LiftAsTFXlaCallModule<"composite_dot_general_fn"> (ArgumentList $lhs, $rhs), (ResultList $res), (NamedAttributeList (NamedAttr<"dot_dimension_numbers"> $dot_dimension_numbers), - (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)))), + (NamedAttr<"precision_config"> (DefaultOrNullAttr $precision_config)), + (NamedAttr<"algorithm"> (DefaultOrNullAttr $algorithm)))), [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; def LiftGather : Pat< diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD index 0999d37da524c2..5ca03bfc209656 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD @@ -46,10 +46,7 @@ pytype_strict_library( # testonly = 1, # srcs = ["integration_test/quantize_model_test_base.py"], # tags = ["no_pip"], -# visibility = [ -# "//learning/brain/mlir/quantization/stablehlo:__subpackages__", -# "//tensorflow/compiler/mlir/quantization:__subpackages__", -# ], +# visibility = ["//visibility:private"], # deps = [ # "//third_party/py/mlir:ir", # "//third_party/py/mlir:stablehlo_dialect", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD index b465fe15e8d57c..80167b43fb9c5e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/BUILD @@ -51,11 +51,11 @@ tf_cc_test( "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", "@stablehlo//:stablehlo_ops", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc index c5daf8455c3753..2510fb96e39591 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.cc @@ -102,7 +102,7 @@ class WritableFileWrapper : public llvm::raw_ostream { } void write_impl(const char* ptr, size_t size) override { - if (file_ && !file_->Append(tsl::StringPiece(ptr, size)).ok()) { + if (file_ && !file_->Append(absl::string_view(ptr, size)).ok()) { file_ = nullptr; } } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc index c3034f4294b13d..be49ddb7e03b08 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/test.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index ddfc905acd2365..fcd42b88cc30c9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -79,7 +79,7 @@ cc_library( hdrs = ["tf_to_xla_attribute_utils.h"], compatible_with = get_compatible_with_portable(), deps = [ - "//tensorflow/compiler/mlir/lite/core/c:common", + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", "//tensorflow/compiler/mlir/lite/kernels:padding", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:constant_fold", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index b3fab1cc04e7a8..e7dc98b8e9535d 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1162,6 +1162,7 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core/platform:status_matchers", "//tensorflow/core/protobuf/tpu:topology_proto_cc", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 7b90a434bad98e..cbde97456aca43 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -1137,7 +1137,7 @@ to be batched.}]>:$captured_tensors, DefaultValuedOptionalAttr:$low_priority_allowed_batch_sizes, DefaultValuedOptionalAttr:$low_priority_max_enqueued_batches, DefaultValuedOptionalAttr, "\"low_priority_padding_with_max_batch_size\"">:$mixed_priority_policy, - DefaultValuedOptionalAttr, "\"PAD_UP\"">:$batch_padding_policy, + DefaultValuedOptionalAttr, "\"PAD_UP\"">:$batch_padding_policy, DefaultValuedOptionalAttr:$enable_large_batch_splitting ); @@ -11293,6 +11293,7 @@ elements from `input_dataset` in parallel.}]>:$num_parallel_calls, DefaultValuedOptionalAttr:$use_inter_op_parallelism, DefaultValuedOptionalAttr:$deterministic, DefaultValuedOptionalAttr:$preserve_cardinality, + DefaultValuedOptionalAttr:$use_unbounded_threadpool, DefaultValuedOptionalAttr:$metadata ); diff --git a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir index 4e820220598398..a458a20d49e510 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/freeze_variables.mlir @@ -424,7 +424,7 @@ module { } // CHECK: func private @f_callee(%[[ARG0:.*]]: tensor<0xf32>) -> tensor<0xf32> - // CHECK-SAME: tf._input_shapes = [#tf_type.shape<00>] + // CHECK-SAME: tf._input_shapes = [#tf_type.shape<0>] func.func private @f_callee(%arg0: tensor<0xf32>, %arg1: tensor<*x!tf_type.resource>) -> tensor<0xf32> attributes {tf._input_shapes = [#tf_type.shape<0>, #tf_type.shape<>]} { %0 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf_type.resource>) -> tensor<0xf32> %1 = "tf.AddV2"(%arg0, %0) : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index f7bac4ba31b50a..b6e8e1c9b9ca07 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -974,7 +974,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: func @call_in_graph_1 func.func @call_in_graph_1(%arg0: tensor, %arg1: tensor<5x5x1x32xbf16>) -> tensor<*xbf16> { - // CHECK: tf_executor.fetch %outputs : tensor + // CHECK: tf_executor.fetch %outputs : tensor %0 = tf_executor.graph { %1:2 = tf_executor.island wraps "tf.PartitionedCall"(%arg0, %arg1) { config = "", config_proto = "", executor_type = "", f = @call_in_graph_func_1} : (tensor, tensor<5x5x1x32xbf16>) -> tensor<*xbf16> @@ -985,7 +985,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr // CHECK-LABEL: func @call_in_graph_func_1 func.func @call_in_graph_func_1(%arg0: tensor, %arg1: tensor<5x5x1x32xbf16>) -> tensor { - // CHECK: tf_executor.fetch %outputs : tensor + // CHECK: tf_executor.fetch %outputs : tensor %0 = tf_executor.graph { %1:2 = tf_executor.island wraps "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}: (tensor, tensor<5x5x1x32xbf16>) -> tensor tf_executor.fetch %1#0 : tensor @@ -2265,4 +2265,14 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr return %3#1, %3#2, %4, %5 : tensor<1x1120x?xi32>, tensor<1x1120x?xi32>, tensor<1120x?xi32>, tensor<2xi32> } + // CHCK-LABEL: func @infer_return_type_static_out + func.func @infer_return_type_static_out(%arg0: tensor, %arg1: tensor) -> tensor<1x28x28x3xf32> { + %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}: (tensor, tensor) -> tensor<1x28x28x3xf32> + func.return %0 : tensor<1x28x28x3xf32> + } + + // CHCK: %0 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor, tensor) -> tensor<1x28x28x3xf32> + + } + diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir index 1789b1d0e65418..4a31291d4c4ae6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -289,7 +289,7 @@ module attributes {tf_saved_model.semantics} { %1 = "tf.AddV2"(%arg0, %0) : (tensor, tensor) -> tensor func.return %1 : tensor } - // CHECK: func.func private @f_callee(%arg0: tensor) -> tensor attributes {tf._input_shapes = [#tf_type.shape<00>]} { + // CHECK: func.func private @f_callee(%arg0: tensor) -> tensor attributes {tf._input_shapes = [#tf_type.shape<0>]} { // CHECK: %cst = "tf.Const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor // CHECK: %0 = "tf.AddV2"(%arg0, %cst) : (tensor, tensor) -> tensor // CHECK: return %0 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index a7ee2b16944deb..f4004240779ff8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -829,6 +829,7 @@ cc_library( "@local_xla//xla/service:shape_inference", "@local_xla//xla/translate/hlo_to_mhlo:hlo_utils", "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", + "@local_xla//xla/tsl/util:env_var", ], ) @@ -1094,3 +1095,14 @@ cc_library( "@local_tsl//tsl/platform:path", ], ) + +tf_cc_test( + name = "shape_inference_test", + srcs = ["shape_inference_test.cc"], + deps = [ + ":shape_inference_pass", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:env", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc index 54fb7cf4c1e845..77f9361ab94a27 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc @@ -36,13 +36,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/tsl/framework/device_type.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/debug_data_dumper.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace tfrt_compiler { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 4bcd7e03167055..89e00142aa4a8f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -87,6 +87,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/tsl/util/env_var.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/shape_inference.h" @@ -110,10 +111,24 @@ namespace mlir { namespace TF { namespace { +MLIRContext::Threading GetMlirContextThreading() { + bool enable_single_thread_mlir_context = []() { + bool result = false; + if (auto status = tsl::ReadBoolFromEnvVar(kMLIRContextSingleThreadVar, + /*default_val=*/false, &result); + status.ok()) { + return result; + } + return false; + }(); + return enable_single_thread_mlir_context ? MLIRContext::Threading::DISABLED + : MLIRContext::Threading::ENABLED; +} + // Compute a refined type between two types `lhs` and `rhs`, the result type -// is always more refined (i.e. has more static information) than `lhs` -// This method will actually merge the information contained in the -// types, it is capable of refining: +// is always at least as refined as (i.e. has more static information) than +// `lhs` This method will actually merge the information contained in the types, +// it is capable of refining: // tensor>> // and: // tensor>> @@ -443,6 +458,11 @@ Type GetType(Attribute shape_attr, Attribute type_attr) { } } // namespace +// Create a MLIRContext based on the threading setup in the env var. +std::unique_ptr MakeMLIRContextWithThreading() { + return std::make_unique(GetMlirContextThreading()); +} + // Returns whether type can be further refined. bool CanBeRefined(Type type) { auto shape_type = mlir::dyn_cast(type); @@ -1024,7 +1044,7 @@ class ShapeInference { // each `XlaCallModule` op. Uses its own MLIRContext since the loader needs to // load additional dialects, which is not allowed for the main context since // shape inference may be called from a pass. - MLIRContext xla_call_module_context_; + std::unique_ptr xla_call_module_context_; DenseMap> xla_call_module_loaders_; }; @@ -1036,6 +1056,7 @@ ShapeInference::ShapeInference(int64_t graph_version, ModuleOp module, symbol_users_(symbol_table_, module), graph_version_(graph_version), propagate_caller_callee_constants_(propagate_caller_callee_constants) { + xla_call_module_context_ = MakeMLIRContextWithThreading(); for (const auto& op_type : ops_to_skip) { ops_to_skip_.insert(op_type); } @@ -1242,10 +1263,10 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) { mlir::DialectRegistry registry; registry.insert(); mlir::func::registerAllExtensions(registry); - xla_call_module_context_.appendDialectRegistry(registry); + xla_call_module_context_->appendDialectRegistry(registry); auto l = tensorflow::XlaCallModuleLoader::Create( - &xla_call_module_context_, op.getVersion(), op.getModule().str(), + xla_call_module_context_.get(), op.getVersion(), op.getModule().str(), std::move(disabled_checks), std::move(platforms), /*num_invocation_args=*/op.getArgs().size(), op.getHasTokenInputOutput()); @@ -2308,12 +2329,15 @@ bool ShapeInference::RefineWithInferTypeOpInterface( // Map each of the results of the call to the returned type of the // function. bool changed = false; - for (auto result : zip(op->getResults(), inferred)) { - if (std::get<0>(result).getType() == std::get<1>(result)) continue; - - if (!UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result), - std::get<0>(result))) + for (auto [result, inferred_type] : zip(op->getResults(), inferred)) { + auto result_type = result.getType(); + auto new_type = TypeMeet(inferred_type, result_type); + if (new_type == result_type) { continue; + } + if (!UpdateTypeAndInsertIncompatibleUseCasts(new_type, result)) { + continue; + } changed = true; } return changed; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h index 46c1bc9c00e55a..9075754d7f550a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_SHAPE_INFERENCE_H_ #include +#include #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -31,6 +32,9 @@ limitations under the License. namespace mlir { namespace TF { +inline constexpr char kMLIRContextSingleThreadVar[] = + "TF_USE_SINGLE_THREAD_MLIR_CONTEXT"; + // Returns whether type can be further refined. bool CanBeRefined(Type type); @@ -71,6 +75,9 @@ FailureOr InferShapeForFunction(func::FuncOp func, int64_t max_iterations = 10, ArrayRef ops_to_skip = {}); +// Create a MLIRContext based on the threading setup in the env var. +std::unique_ptr MakeMLIRContextWithThreading(); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc new file mode 100644 index 00000000000000..416807c3708488 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference_test.cc @@ -0,0 +1,39 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h" + +#include + +#include +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { +namespace TF { +namespace { + +TEST(ShapeInferenceTest, CreateMultiThreadedMLIRContext) { + std::unique_ptr ctx = MakeMLIRContextWithThreading(); + EXPECT_TRUE(ctx->isMultithreadingEnabled()); +} + +TEST(ShapeInferenceTest, CreateSingleThreadedMLIRContext) { + setenv(kMLIRContextSingleThreadVar, "true", 1); + std::unique_ptr ctx = MakeMLIRContextWithThreading(); + EXPECT_FALSE(ctx->isMultithreadingEnabled()); +} + +} // namespace +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc index 3735199d8a33c8..1ccfc8775d1c44 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.cc @@ -75,7 +75,7 @@ void GraphOptPass::runOnOperation() { GraphExportConfig confs; auto graph = std::make_unique(flib_def); absl::flat_hash_set control_ret_nodes; - Status status = tensorflow::tf2xla::v2::ConvertMlirToGraph( + Status status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module_in, confs, &graph, &flib_def, &control_ret_nodes); if (!status.ok()) { mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.message(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD index ad9befdfe5fb28..5ffc11344ba3ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -254,3 +254,44 @@ cc_library( "@llvm-project//llvm:Support", ], ) + +cc_library( + name = "node_order", + srcs = ["node_order.cc"], + hdrs = ["node_order.h"], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cc_test( + name = "node_order_test", + size = "small", + srcs = [ + "node_order_test.cc", + ], + deps = [ + ":node_order", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:direct_session_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 3e72550a88749a..e50cb03a892ce8 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -127,6 +127,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/stack_frame.h" #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -1795,7 +1796,7 @@ mlir::Location ImporterBase::GetLocation(const Node& node) { if (stack_trace != nullptr) { DVLOG(1) << "Stack available for " << node.name(); - for (const StackFrame& frame : stack_trace->ToFrames()) { + for (const StackFrame& frame : stack_trace->ToUncachedFrames()) { auto file_name = mlir::StringAttr::get(context_, frame.file_name); // Use col 1 as there is no column info in StackTrace. auto file_line_loc = diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order.cc b/tensorflow/compiler/mlir/tensorflow/translate/node_order.cc new file mode 100644 index 00000000000000..58a2751cb4b2a5 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/node_order.cc @@ -0,0 +1,108 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/translate/node_order.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +void TopologicalOrdering( + const Graph& g, const std::function& emit, + const std::function& get_grouping_key) { + std::unordered_map group_key_string_to_integer; + absl::flat_hash_map node_to_group; + absl::flat_hash_map remaining_incoming_nodes; + using Ready = std::vector; + std::vector group_members_that_are_ready; + std::set groups_that_are_ready; + + // Visit all nodes once, for initialization. It doesn't matter whether we use + // BFS or DFS. + DFS( + g, [](Node*) {}, + [&](Node* n) { + // Find which group this node belongs to. + std::string group_key_string = get_grouping_key(n); + auto entry = group_key_string_to_integer.try_emplace( + group_key_string, group_key_string_to_integer.size()); + int group_key = entry.first->second; + node_to_group[n] = group_key; + if (!entry.second) { + group_members_that_are_ready.push_back({}); + } + + // Count the incoming nodes and store. Also remember nodes ("sources") + // that don't have any inputs. + auto in_nodes = n->in_nodes(); + int num_incoming = std::distance(in_nodes.begin(), in_nodes.end()); + remaining_incoming_nodes[n] = num_incoming; + if (num_incoming == 0) { + // NO_CDC: This array is max(group_key) + 1. + group_members_that_are_ready[group_key].push_back(n); + groups_that_are_ready.emplace(group_key); + } + }); + + int num_nodes = remaining_incoming_nodes.size(); + + // We emit one node per step, thus we just run this as often as we have nodes. + int current_group = 0; + for (int i = 0; i < num_nodes; i++) { + if (groups_that_are_ready.find(current_group) == + groups_that_are_ready.end()) { + current_group = *groups_that_are_ready.begin(); + } + + // NO_CDC: This array is max(group_key) + 1. + int size = group_members_that_are_ready[current_group].size(); + assert(size); + // NO_CDC: This array is max(group_key) + 1. + Node* node = group_members_that_are_ready[current_group][--size]; + // NO_CDC: This array is max(group_key) + 1. + group_members_that_are_ready[current_group].pop_back(); + if (size == 0) { + groups_that_are_ready.erase(current_group); + } + + // Emit the operation and make its results available. + emit(node); + + for (Node* out : node->out_nodes()) { + remaining_incoming_nodes[out]--; + if (remaining_incoming_nodes[out] == 0) { + int group_key = node_to_group[out]; + // NO_CDC: This array is max(group_key) + 1. + if (group_members_that_are_ready[group_key].empty()) { + groups_that_are_ready.emplace(group_key); + } + // NO_CDC: This array is max(group_key) + 1. + group_members_that_are_ready[group_key].push_back(out); + } + } + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order.h b/tensorflow/compiler/mlir/tensorflow/translate/node_order.h new file mode 100644 index 00000000000000..4cb8e75efa7613 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/node_order.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +struct GroupByDevice { + std::string operator()(const Node* node) const { + return node->requested_device(); + } +}; + +// Performs a topological ordering of nodes. +// This has the property that any child node of a parent node p is emitted +// before p. A grouping function is used to break ties if multiple child nodes +// (of possibly different parents) are ready to be emitted at some point, which +// is when we prefer to stay in the current group. +// The "emit" function is used for outputing the result, and is called once +// for each node. +// This algorithm is O(n). +void TopologicalOrdering( + const Graph& g, const std::function& emit, + const std::function& get_grouping_key); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSLATE_NODE_ORDER_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc b/tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc new file mode 100644 index 00000000000000..fc1d6e177f1dcd --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/translate/node_order_test.cc @@ -0,0 +1,239 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/tensorflow/translate/node_order.h" + +#include +#include +#include + +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/graph_def_builder_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +REGISTER_OP("TestParams").Output("o: float"); +REGISTER_OP("TestInput").Output("a: float").Output("b: float"); +REGISTER_OP("TestMul").Input("a: float").Input("b: float").Output("o: float"); +REGISTER_OP("TestUnary").Input("a: float").Output("o: float"); +REGISTER_OP("TestTwoOutputs").Output("a: float").Output("b: float"); +REGISTER_OP("TestBinary") + .Input("a: float") + .Input("b: float") + .Output("o: float"); + +// Compares that the order of nodes in 'inputs' respects the +// pair orders described in 'ordered_pairs'. +bool ExpectBefore(const std::vector>& ordered_pairs, + const std::vector& inputs, string* error) { + for (const std::pair& pair : ordered_pairs) { + const string& before_node = pair.first; + const string& after_node = pair.second; + bool seen_before = false; + bool seen_both = false; + for (const Node* node : inputs) { + if (!seen_before && after_node == node->name()) { + *error = std::string("Saw ") + after_node + std::string(" before ") + + before_node; + return false; + } + + if (before_node == node->name()) { + seen_before = true; + } else if (after_node == node->name()) { + seen_both = seen_before; + break; + } + } + if (!seen_both) { + *error = std::string("didn't see either ") + before_node + + std::string(" or ") + after_node; + return false; + } + } + + return true; +} + +TEST(AlgorithmTest, TopologicalOrdering) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + Node* n1 = SourceOp("TestParams", b.opts().WithName("n1")); + Node* n2 = + SourceOp("TestParams", b.opts().WithName("n2").WithControlInput(n1)); + Node* n3 = + SourceOp("TestParams", b.opts().WithName("n3").WithControlInput(n2)); + Node* n4 = BinaryOp("TestMul", n1, {n3, 0}, b.opts().WithName("n4")); + Node* n5 = BinaryOp("TestMul", n1, {n3, 0}, + b.opts().WithName("n5").WithControlInput(n1)); + Node* n6 = BinaryOp("TestMul", n2, {n3, 0}, b.opts().WithName("n6")); + n3->set_requested_device("a"); + n4->set_requested_device("a"); + n5->set_requested_device("b"); + n6->set_requested_device("b"); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order; + + TopologicalOrdering(g, [&](Node* n) { order.push_back(n); }, GroupByDevice()); + + std::vector> desired_order = { + {"n1", "n2"}, // because of control dependency + {"n2", "n3"}, // because of control dependency + {"n3", "n4"}, // because of NodeScorerDevice + {"n1", "n4"}, // data dependency + {"n1", "n5"}, // data dependency + {"n2", "n6"}, // data dependency + {"n3", "n4"}, // data dependency + {"n3", "n5"}, // data dependency + {"n3", "n6"}, // data dependency + }; + string error; + EXPECT_TRUE(ExpectBefore(desired_order, order, &error)) << error; +} + +TEST(AlgorithmTest, TopologicalOrderingOnShallowTree) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + Node* n1 = SourceOp("TestParams", b.opts().WithName("n1").WithDevice("a")); + Node* n2 = + SourceOp("TestParams", + b.opts().WithName("n2").WithDevice("b").WithControlInput(n1)); + Node* n3 = + SourceOp("TestParams", + b.opts().WithName("n3").WithDevice("c").WithControlInput(n2)); + Node* n4 = + SourceOp("TestParams", + b.opts().WithName("n4").WithDevice("a").WithControlInput(n1)); + Node* n5 = + SourceOp("TestParams", + b.opts().WithName("n5").WithDevice("b").WithControlInput(n2)); + Node* n6 = + SourceOp("TestParams", + b.opts().WithName("n6").WithDevice("c").WithControlInput(n3)); + Node* n7 = + SourceOp("TestParams", + b.opts().WithName("n7").WithDevice("a").WithControlInput(n4)); + Node* n8 = + SourceOp("TestParams", + b.opts().WithName("n8").WithDevice("b").WithControlInput(n5)); + Node* n9 = + SourceOp("TestParams", + b.opts().WithName("n9").WithDevice("c").WithControlInput(n6)); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order; + + TopologicalOrdering(g, [&](Node* n) { order.push_back(n); }, GroupByDevice()); + + std::vector desired_order = { + g.source_node(), n1, n4, n7, n2, n5, n8, n3, n6, n9, g.sink_node()}; + for (int i = 0; i < desired_order.size(); i++) { + desired_order[i] = g.FindNodeId(desired_order[i]->id()); + } + EXPECT_EQ(order, desired_order); +} + +TEST(AlgorithmTest, TopologicalOrderingGivesTheSameResultIfCalledTwice) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + SourceOp("TestParams", b.opts().WithName("n1")); + SourceOp("TestParams", b.opts().WithName("n2")); + SourceOp("TestParams", b.opts().WithName("n3")); + SourceOp("TestParams", b.opts().WithName("n4")); + SourceOp("TestParams", b.opts().WithName("n5")); + SourceOp("TestParams", b.opts().WithName("n6")); + SourceOp("TestParams", b.opts().WithName("n7")); + SourceOp("TestParams", b.opts().WithName("n8")); + SourceOp("TestParams", b.opts().WithName("n9")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order1; + std::vector order2; + + TopologicalOrdering( + g, [&](Node* n) { order1.push_back(n); }, + [&](const Node* node) { return std::string("same"); }); + + TopologicalOrdering( + g, [&](Node* n) { order2.push_back(n); }, + [&](const Node* node) { return std::string("same"); }); + + EXPECT_EQ(order1, order2); +} + +TEST(AlgorithmTest, TopologicalOrderingOnChain) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + Node* n1 = SourceOp("TestParams", b.opts().WithName("n1")); + Node* n2 = UnaryOp("TestUnary", n1, b.opts().WithName("n2")); + Node* n3 = UnaryOp("TestUnary", n2, b.opts().WithName("n3")); + Node* n4 = UnaryOp("TestUnary", n3, b.opts().WithName("n4")); + Node* n5 = UnaryOp("TestUnary", n4, b.opts().WithName("n5")); + Node* n6 = UnaryOp("TestUnary", n5, b.opts().WithName("n6")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order; + TopologicalOrdering(g, [&](Node* n) { order.push_back(n); }, GroupByDevice()); + + std::vector desired_order = {g.source_node(), n1, n2, n3, n4, n5, n6, + g.sink_node()}; + for (int i = 0; i < desired_order.size(); i++) { + desired_order[i] = g.FindNodeId(desired_order[i]->id()); + } + EXPECT_EQ(order, desired_order); +} + +TEST(AlgorithmTest, TopologicalOrderingOnMultipleOutputs) { + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + using namespace ::tensorflow::ops; // NOLINT + Node* n1 = SourceOp("TestTwoOutputs", b.opts().WithName("n1")); + UnaryOp("TestUnary", {n1, 0}, b.opts().WithName("n2")); + UnaryOp("TestUnary", {n1, 1}, b.opts().WithName("n3")); + UnaryOp("TestUnary", {n1, 0}, b.opts().WithName("n4")); + UnaryOp("TestUnary", {n1, 1}, b.opts().WithName("n5")); + + Graph g(OpRegistry::Global()); + TF_ASSERT_OK(GraphDefBuilderToGraph(b, &g)); + + std::vector order; + TopologicalOrdering(g, [&](Node* n) { order.push_back(n); }, GroupByDevice()); + + std::vector> desired_order = { + {"n1", "n2"}, + {"n1", "n3"}, + {"n1", "n4"}, + {"n1", "n5"}, + }; + string error; + EXPECT_TRUE(ExpectBefore(desired_order, order, &error)) << error; +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc index ab73156f29c4b3..92ecf3082588ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate_registration.cc @@ -130,7 +130,7 @@ static LogicalResult MlirToGraphTranslateFunction(ModuleOp module, auto graph = std::make_unique(tensorflow::OpRegistry::Global()); absl::flat_hash_set control_ret_nodes; - auto status = tensorflow::tf2xla::v2::ConvertMlirToGraph( + auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module, confs, &graph, flib_def.get(), &control_ret_nodes); if (!status.ok()) { LOG(ERROR) << "Export to Graph failed: " << status; @@ -179,7 +179,7 @@ static LogicalResult MlirToGraphdefTranslateFunction( std::make_unique(tensorflow::OpRegistry::Global()); absl::flat_hash_set control_ret_nodes; - auto status = tensorflow::tf2xla::v2::ConvertMlirToGraph( + auto status = tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module, confs, &graph, &flib_def, &control_ret_nodes); if (!status.ok()) { LOG(ERROR) << "Export to Graph failed: " << status; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc index 9262f87edb46bd..3720a09a11d09b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/cluster_util_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/errors.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace mlir::TF { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index ba538b2f470fee..bcf9a21d26efc1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -74,9 +74,11 @@ namespace { llvm::SmallVector FindMatchingDevices( ParsedDevices devices, const ParsedDevice& spec) { llvm::SmallVector matching_devices; - for (const auto& device : devices) - if (DeviceNameUtils::IsCompleteSpecification(spec, device)) + for (const auto& device : devices) { + if (DeviceNameUtils::IsCompleteSpecification(spec, device)) { matching_devices.push_back(device); + } + } return matching_devices; } @@ -623,6 +625,71 @@ absl::StatusOr> GetDeviceCoordinates( return device_coordinates; } +absl::StatusOr GetXlaDeviceAssignmentProto( + llvm::StringRef topology_attr, int num_replicas, int num_cores_per_replica, + llvm::ArrayRef device_assignment_attr) { + tpu::TopologyProto topology_proto; + if (!topology_proto.ParseFromString(topology_attr.str())) + return absl::InvalidArgumentError(absl::StrCat( + "failed to parse '", kTopologyAttr, "' attribute to TopologyProto")); + + if (topology_proto.mesh_shape_size() < 4) { + return absl::InvalidArgumentError(absl::StrCat( + "The size of mesh_shape must be larger than or equal to 4, but got ", + topology_proto.mesh_shape_size())); + } + + const int bound_x = topology_proto.mesh_shape(0); + const int bound_y = topology_proto.mesh_shape(1); + const int bound_z = topology_proto.mesh_shape(2); + const int bound_core = topology_proto.mesh_shape(3); + + const int expected_device_assignment_size = + num_replicas * num_cores_per_replica * kTPUTopologyRank; + const int device_assignment_attr_size = device_assignment_attr.size(); + if (device_assignment_attr_size != expected_device_assignment_size) + return absl::InvalidArgumentError(absl::StrCat( + "length of '", kDeviceAssignmentAttr, + "' must be 'num_replicas' * 'num_cores_per_replica' * ", + kTPUTopologyRank, " (", num_replicas, " * ", num_cores_per_replica, + " * ", kTPUTopologyRank, "), got ", device_assignment_attr.size())); + + // TPU XLA device ID is determined by its device coordinate, from major to + // minor coordinates (z, y, x, core). + auto location_to_id = [&](int x, int y, int z, int core) { + return (x + bound_x * (y + bound_y * z)) * bound_core + core; + }; + + std::vector used_device_ids(bound_x * bound_y * bound_z * bound_core, + false); + + xla::DeviceAssignment device_assignment(num_replicas, num_cores_per_replica); + int pos = 0; + for (int replica = 0; replica < num_replicas; ++replica) { + for (int logical_core = 0; logical_core < num_cores_per_replica; + ++logical_core) { + int x = device_assignment_attr[pos++]; + int y = device_assignment_attr[pos++]; + int z = device_assignment_attr[pos++]; + int core = device_assignment_attr[pos++]; + if (DeviceCoordinateOutOfBound(x, y, z, core, bound_x, bound_y, bound_z, + bound_core)) + return DeviceCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z, core, + bound_x, bound_y, bound_z, bound_core); + const int device_id = location_to_id(x, y, z, core); + if (used_device_ids[device_id]) + return DuplicateCoordinateErrorMsg(kDeviceAssignmentAttr, x, y, z, + core); + + used_device_ids[device_id] = true; + device_assignment(replica, logical_core) = device_id; + } + } + xla::DeviceAssignmentProto device_assignment_proto; + device_assignment.Serialize(&device_assignment_proto); + return device_assignment_proto; +} + absl::StatusOr GetTPUCompilationAndExecutionDevices( ParsedDevices devices, int num_replicas, int num_cores_per_replica, llvm::StringRef topology_attr, diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h index f7c9b29d6cfdcc..250c21d627c7ed 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h @@ -32,6 +32,7 @@ limitations under the License. #include "xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/statusor.h" +#include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -239,6 +240,11 @@ absl::StatusOr GetTPUCompilationAndExecutionDevices( int num_cores_per_replica, llvm::StringRef topology_attr, llvm::ArrayRef device_assignment_attr); +// Converts a device assignment attribute to an XLA device assignment proto. +absl::StatusOr GetXlaDeviceAssignmentProto( + llvm::StringRef topology_attr, int num_replicas, int num_cores_per_replica, + llvm::ArrayRef device_assignment_attr); + // Virtual device name of the passed logical core. The logical core is the index // of a core within a replica. std::string GetDeviceAliasForLogicalCore(int core_index); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index c6d80802b2aa0a..8527ae80b967d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/tpu/topology.pb.h" #include "tensorflow/core/util/device_name_utils.h" @@ -515,6 +516,130 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh2x2x2) { EXPECT_EQ(computation_device_1.replica_device_ids(2), 3); EXPECT_EQ(computation_device_1.replica_device_ids(3), 7); } +TEST(TPURewriteDeviceUtilTest, ValidXLADeviceAssignmentMesh1x2x1x3) { + tpu::TopologyProto topology_proto; + { + topology_proto.add_mesh_shape(1); + topology_proto.add_mesh_shape(2); + topology_proto.add_mesh_shape(1); + topology_proto.add_mesh_shape(3); + topology_proto.set_num_tasks(3); + topology_proto.set_num_tpu_devices_per_task(2); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(2); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(2); + } + + std::string topology_attr = topology_proto.SerializeAsString(); + std::vector device_assignment_attr{ + 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0}; + + llvm::SmallVector devices; + std::vector device_names = + MakeDeviceSet(/*num_tasks=*/3, /*num_devices_per_task=*/2); + ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices)); + + auto xla_device_assignment = GetXlaDeviceAssignmentProto( + topology_attr, /*num_replicas=*/2, /*num_cores_per_replica=*/3, + device_assignment_attr); + + TF_ASSERT_OK(xla_device_assignment.status()); + EXPECT_EQ(xla_device_assignment->replica_count(), 2); + EXPECT_EQ(xla_device_assignment->computation_count(), 3); + ASSERT_EQ(xla_device_assignment->computation_devices_size(), 3); + const auto& computation_device_0 = + xla_device_assignment->computation_devices(0); + ASSERT_EQ(computation_device_0.replica_device_ids_size(), 2); + const auto& computation_device_1 = + xla_device_assignment->computation_devices(1); + ASSERT_EQ(computation_device_1.replica_device_ids_size(), 2); + const auto& computation_device_2 = + xla_device_assignment->computation_devices(2); + ASSERT_EQ(computation_device_2.replica_device_ids_size(), 2); + + EXPECT_EQ(computation_device_0.replica_device_ids(0), 1); + EXPECT_EQ(computation_device_0.replica_device_ids(1), 5); + EXPECT_EQ(computation_device_1.replica_device_ids(0), 4); + EXPECT_EQ(computation_device_1.replica_device_ids(1), 0); + EXPECT_EQ(computation_device_2.replica_device_ids(0), 2); + EXPECT_EQ(computation_device_2.replica_device_ids(1), 3); +} + +TEST(TPURewriteDeviceUtilTest, InvalidXLADeviceAssignmentMesh1x2x1x3) { + tpu::TopologyProto topology_proto; + { + topology_proto.add_mesh_shape(1); + topology_proto.add_mesh_shape(2); + topology_proto.add_mesh_shape(1); + topology_proto.add_mesh_shape(3); + topology_proto.set_num_tasks(3); + topology_proto.set_num_tpu_devices_per_task(2); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(2); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(1); + topology_proto.add_device_coordinates(0); + topology_proto.add_device_coordinates(2); + } + + std::string topology_attr = topology_proto.SerializeAsString(); + std::vector device_assignment_attr{ + 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 2, 0, 1, 0, 2, 0, 0, 0, 0, 0, 1, 0, 0}; + + llvm::SmallVector devices; + std::vector device_names = + MakeDeviceSet(/*num_tasks=*/3, /*num_devices_per_task=*/2); + ASSERT_TRUE(DeviceNamesToParsedNames(device_names, &devices)); + + auto xla_device_assignment = GetXlaDeviceAssignmentProto( + topology_attr, /*num_replicas=*/2, /*num_cores_per_replica=*/2, + device_assignment_attr); + + EXPECT_THAT(xla_device_assignment, + testing::StatusIs( + absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr( + "must be 'num_replicas' * 'num_cores_per_replica' * "))); +} TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { tpu::TopologyProto topology_proto; diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD index 20f771d3307800..4763746dcda061 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD @@ -97,12 +97,12 @@ tf_cc_test( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_xla//xla:shape_util", "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -179,7 +179,6 @@ tf_cc_test( "//tensorflow/core/tpu/kernels:tpu_compile_op_support", "//tensorflow/core/tpu/kernels/xla:host_compute_ops", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/lib/monitoring:test_utils", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", @@ -187,6 +186,7 @@ tf_cc_test( "@local_xla//xla/client:client_library", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/translate/mhlo_to_hlo:type_to_shape", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -247,8 +247,8 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -299,7 +299,7 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc index 916b568a698de8..bce19eace57119 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf_test.cc @@ -28,9 +28,9 @@ limitations under the License. #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc index 161c0278c1db82..8ed7d1ea727867 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc @@ -1051,8 +1051,7 @@ Status CompileGraphToXlaHlo( absl::StatusOr> GraphToModule( bool unconditionally_use_set_output_shapes, const Graph& graph, llvm::ArrayRef control_rets, - const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, - mlir::MLIRContext* context) { + const FunctionLibraryDefinition& flib_def, mlir::MLIRContext* context) { mlir::DialectRegistry registry; RegisterDialects(registry); context->appendDialectRegistry(registry); @@ -1070,6 +1069,7 @@ absl::StatusOr> GraphToModule( // do it optionally. config.unconditionally_use_set_output_shapes = unconditionally_use_set_output_shapes; + GraphDebugInfo debug_info; return ConvertGraphToMlir(graph, debug_info, flib_def, config, context); } @@ -1078,12 +1078,10 @@ Status BuildHloFromGraph( mlir::MLIRContext& mlir_context, llvm::ArrayRef xla_params, std::vector& returns, bool unconditionally_use_output_shapes, llvm::ArrayRef args, llvm::ArrayRef control_rets, - llvm::StringRef device_type, const FunctionLibraryDefinition& flib_def, - const GraphDebugInfo& debug_info) { - TF_ASSIGN_OR_RETURN( - mlir::OwningOpRef module, - GraphToModule(unconditionally_use_output_shapes, graph, control_rets, - flib_def, debug_info, &mlir_context)); + llvm::StringRef device_type, const FunctionLibraryDefinition& flib_def) { + TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, + GraphToModule(unconditionally_use_output_shapes, graph, + control_rets, flib_def, &mlir_context)); return BuildHloFromModule(module.get(), builder, xla_params, returns, args, device_type); } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h index 13fe61b8f69c59..4b8df7c35fe611 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/xla_computation.h" #include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/tensor_shape.h" namespace tensorflow { @@ -201,8 +200,7 @@ Status BuildHloFromGraph( mlir::MLIRContext& mlir_context, llvm::ArrayRef xla_params, std::vector& returns, bool unconditionally_use_output_shapes, llvm::ArrayRef args, llvm::ArrayRef control_rets, - llvm::StringRef device_type, const FunctionLibraryDefinition& flib_def, - const GraphDebugInfo& debug_info); + llvm::StringRef device_type, const FunctionLibraryDefinition& flib_def); static inline Status CompileToHloGraphAnalysisFailedError() { return errors::Internal("disabled after graph analysis"); diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc index 3da6e13005b171..57769d2363bc18 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util_test.cc @@ -39,6 +39,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -49,7 +50,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -263,8 +263,7 @@ absl::StatusOr BuildHloFromGraph( BuildHloFromGraph(graph, builder, mlir_context, xla_params, returns, use_output_shapes, xla_args, /*control_rets=*/{}, DEVICE_TPU, - FunctionLibraryDefinition(OpRegistry::Global()), - /*debug_info=*/{})); + FunctionLibraryDefinition(OpRegistry::Global()))); return builder.Build(); } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index 9fc2207d83dee8..265f83caf03727 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -201,7 +201,7 @@ Status PrepareAndExportToLibrary(mlir::ModuleOp module, GraphExportConfig config; config.export_entry_func_to_flib = true; absl::flat_hash_set control_ret_nodes; - return tensorflow::tf2xla::v2::ConvertMlirToGraph( + return tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( module, config, /*graph=*/nullptr, flib_def, &control_ret_nodes); } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc index 06208be8fc5893..c3598e6ba60c38 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph_test.cc @@ -29,11 +29,11 @@ limitations under the License. #include "xla/shape.h" #include "xla/stream_executor/platform_manager.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/monitoring/test_utils.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc index cad1edf2b89018..1da6d58cfb0bf9 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/tf_dialect_to_executor_test.cc @@ -27,9 +27,9 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD index 74219b2c1a87f1..49b12fab37f60f 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/BUILD @@ -168,8 +168,8 @@ tf_cc_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -227,7 +227,7 @@ tf_cc_test( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc index 233b112503780a..20da4302fa27e2 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf_test.cc @@ -33,9 +33,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc index 3f586433c39b01..58306763c15ccd 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/legalize_tf_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/client/client_library.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/lib/monitoring/test_utils.h" #include "tensorflow/core/platform/env.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/util/debug_data_dumper.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/monitoring/test_utils.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc index 4e5199b05ff285..22d36ea74401ea 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_dialect_to_executor_test.cc @@ -32,9 +32,9 @@ limitations under the License. #include "mlir/Parser/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/testing/utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/resource_loader.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace tf2xla { diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.cc index b0a770802ac34d..7645125770fae8 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.cc @@ -796,11 +796,11 @@ Status Exporter::Convert(mlir::ModuleOp module, } // namespace -Status ConvertMlirToGraph(mlir::ModuleOp module, - const GraphExportConfig& configs, - std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, - absl::flat_hash_set* control_ret_nodes) { +Status ConvertTfExecutorToGraph(mlir::ModuleOp module, + const GraphExportConfig& configs, + std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + absl::flat_hash_set* control_ret_nodes) { mlir::StatusScopedDiagnosticHandler sh(module.getContext()); if (failed(VerifyExportSuitable(module))) return sh.ConsumeStatus(); return sh.Combine( diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h index 7ee67aa221a91b..bd59770e8164fb 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h @@ -37,11 +37,11 @@ namespace v2 { // The "main" function of the module is stored in the graph and the rest of // functions are stored in the library. Control ret nodes are stored separately // in `control_ret_nodes`. -Status ConvertMlirToGraph(mlir::ModuleOp module, - const GraphExportConfig& configs, - std::unique_ptr* graph, - FunctionLibraryDefinition* flib_def, - absl::flat_hash_set* control_ret_nodes); +Status ConvertTfExecutorToGraph(mlir::ModuleOp module, + const GraphExportConfig& configs, + std::unique_ptr* graph, + FunctionLibraryDefinition* flib_def, + absl::flat_hash_set* control_ret_nodes); // Converts an MLIR function and adds it to a FunctionLibraryDefinition. Status ConvertMlirFunctionToFunctionLibraryDef(mlir::func::FuncOp func, diff --git a/tensorflow/compiler/mlir/tf2xla/internal/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/BUILD index ea2bf17026db23..e6352135a4eec4 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/BUILD @@ -244,8 +244,8 @@ tf_cc_test( "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -296,7 +296,7 @@ tf_cc_test( "//tensorflow/core:testlib", "//tensorflow/core/platform:enable_tf2_utils", # "//tensorflow/core/platform:resource_loader", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc index 840d4c971e7bb5..3365a85e5868fb 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/logging_hooks_test.cc @@ -29,10 +29,10 @@ limitations under the License. #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/file_statistics.h" #include "tsl/platform/status.h" diff --git a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc index 6cbc67d4ec395c..78d027a0ca702b 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/mlir_bridge_pass_util_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/cc/ops/tpu_functional_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/enable_tf2_utils.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index 6ef95e7633c9cd..8be5db976b6075 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -129,8 +129,8 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -390,7 +390,6 @@ tf_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -399,6 +398,7 @@ tf_cc_test( "@local_xla//xla/client:xla_builder", "@local_xla//xla/client:xla_computation", "@local_xla//xla/mlir_hlo", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -507,8 +507,8 @@ tf_cc_test( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 9021095cadd8e9..a17ef4308db7d9 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index dca10693d74b01..9f45164ba4dfe3 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -3145,7 +3145,8 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { // (The batch dimensions are checked by the broadcasting logic) rewriter.replaceOpWithNewOp( op, op.getType(), lhs, rhs, dimension_numbers, - /*precision_config=*/GetPrecisionConfig(&rewriter)); + /*precision_config=*/GetPrecisionConfig(&rewriter), + /*algorithm=*/DotAlgorithmAttr{}); return success(); } }; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 401d1e8b954e40..185216448a15ed 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -60,6 +60,8 @@ def CastElementsToI64Elements : NativeCodeCall< "hlo::convertElementsAttr(" "$0.cast(), $_builder.getIntegerType(64)).cast()">; +def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; + //===----------------------------------------------------------------------===// // ApproximateEqual op pattern. //===----------------------------------------------------------------------===// @@ -760,7 +762,8 @@ def HasValidPrecisionConfig : Constraint>; def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), (MHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), - (ToPrecisionConfigsAttr $precision_config)), + (ToPrecisionConfigsAttr $precision_config), + (EmptyDotAlgorithmAttr)), [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; //===----------------------------------------------------------------------===// @@ -770,7 +773,8 @@ def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), (MHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), - (ToPrecisionConfigsAttr $precision_config)), + (ToPrecisionConfigsAttr $precision_config), + (EmptyDotAlgorithmAttr)), [(HasValidDotDims $dimension_numbers), (HasValidPrecisionConfig $precision_config)]>; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index aecf9db3f0d5fe..c9d06801bb088f 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -42,9 +42,9 @@ limitations under the License. #include "xla/client/xla_computation.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir index dbb77732a3d6f6..730d09694997e9 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/rewrite_cluster_to_ifrt_call.mlir @@ -11,6 +11,7 @@ // // CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>) // CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { } is_bounded_dynamic_dim: false } retvals { sharding { } } num_replicas: 1 num_cores_per_replica: 1 " +// CHECK-SAME: device_assignment = [] // CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK: return @@ -38,7 +39,8 @@ func.func private @_func(%arg0: tensor<1x3xf32>) -> (tensor<1x3xf32>) { // CHECK: return // // CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>) -// CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } num_replicas: 1 num_cores_per_replica: 2 device_assignment { replica_count: 1 computation_count: 2 computation_devices { replica_device_ids: 0 } computation_devices { replica_device_ids: 1 } } use_spmd_for_xla_partitioning: true " +// CHECK-SAME: __tpu_compile_metadata_text = "args { dtype: DT_FLOAT shape { dim { size: 1 } dim { size: 3 } } kind: PARAMETER sharding { type: OTHER tile_assignment_dimensions: 2 tile_assignment_dimensions: 1 tile_assignment_devices: 0 tile_assignment_devices: 1 } is_bounded_dynamic_dim: false } num_replicas: 1 num_cores_per_replica: 2 use_spmd_for_xla_partitioning: true " +// CHECK-SAME: device_assignment = [0, 0, 0, 0, 0, 0, 0, 1] // CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK: return @@ -70,6 +72,7 @@ func.func private @_func(%arg0: tensor<1x3xf32>) -> () { // CHECK: return // // CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> +// CHECK-SAME: device_assignment = [0, 0, 0, 0, 0, 0, 0, 1] // CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK-NEXT: %0 = "tf.MatMul"(%arg0, %arg1) // CHECK: return @@ -102,6 +105,7 @@ func.func private @_func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> (ten // CHECK: return // // CHECK: func.func @_ifrt_program__func(%arg0: tensor<1x3xf32>, %arg1: tensor<3x1xf32>) -> tensor<1x1xf32> +// CHECK-SAME: device_assignment = [] // CHECK-SAME: tfrt_ifrt_serving.program_id = [[PROGRAM_ID]] : i64 // CHECK-NEXT: %0 = "tf.MatMul"(%arg0, %arg1) // CHECK: return diff --git a/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir new file mode 100644 index 00000000000000..02afa969970004 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/tests/ifrt/tf_device_cleanup.mlir @@ -0,0 +1,8 @@ +// RUN: tf-tfrt-opt %s -tf-device-cleanup | FileCheck %s + +// CHECK-LABEL: func @ops_with_device +func.func @ops_with_device() { + %0 = "tf.VarHandleOp"() {container = "", shared_name = "var", device = "/device/..."} : () -> tensor>> + // CHECK-NOT: device = "/device/..." + func.return +} diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 80969fec73cba5..2ec0fdd9d4c215 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -69,6 +69,7 @@ cc_library( "lower_to_ifrt_restore_variable.cc", "rewrite_cluster_to_ifrt_call.cc", "sink_variable_as_named_array.cc", + "tf_device_cleanup.cc", "tf_identity_propagation.cc", "tf_ifrt_passes.cc", "tf_restore_merging.cc", diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc index 6da9eda3240d82..3b190d326ce58f 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_backend_compiler_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" #include "xla/tsl/framework/test_util/mock_serving_device_selector.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/runtime/runtime.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td index 7cdc5576ae5465..9c37c58c0e37ba 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.td @@ -129,3 +129,14 @@ def TfIdentityPropagationPass let constructor = "CreateTfIdentityPropagationPass()"; } +def TfDeviceCleanupPass : Pass<"tf-device-cleanup", "mlir::func::FuncOp"> { + let summary = "Cleans up device attributes from all ops"; + + let description = [{ + This pass removes `device` attributes from all TF ops. Some Serving + doesn't rely on `device` attributes from SavedModel. + }]; + + let constructor = "CreateTfDeviceCleanupPass()"; +} + diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc index 62b50b2115dae2..2fc2c173fed8ba 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/rewrite_cluster_to_ifrt_call.cc @@ -21,11 +21,13 @@ limitations under the License. #include #include "absl/base/casts.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -127,38 +129,9 @@ class RewriteClusterToIfrtCallPass << " is missing"; int num_cores_per_replica = num_cores_per_replica_attr.getInt(); - std::optional xla_device_assignment; - auto topology_attr = cluster_func->getAttrOfType( - tensorflow::kTopologyAttr); - // Get device assignment. - auto device_assignment_attr = cluster_func->getAttrOfType( - tensorflow::kDeviceAssignmentAttr); - if (topology_attr && device_assignment_attr && !topology_attr.empty() && - !device_assignment_attr.empty()) { - auto device_coordinates = - tensorflow::GetDeviceCoordinates(device_assignment_attr); - if (!device_coordinates.ok()) - return cluster_func.emitError() - << "error in parsing tpu device coordinates: " - << device_coordinates.status().message(); - - auto device_assignment = tensorflow::GetTPUCompilationAndExecutionDevices( - devices.device_names(), num_replicas, num_cores_per_replica, - topology_attr.getValue(), *device_coordinates); - if (!device_assignment.ok()) - return cluster_func.emitError() - << "error in parsing TPU compilation/execution devices: " - << device_assignment.status().message(); - if (!device_assignment->xla_device_assignment) { - return cluster_func.emitError() - << "Unexpected empty xla_device_assignment"; - } - xla_device_assignment = device_assignment->xla_device_assignment; - } - return mlir::TFTPU::SetMetadataProtoFromClusterFuncOp( cluster_func, num_replicas, num_cores_per_replica, - std::move(xla_device_assignment), metadata); + /*xla_device_assignment=*/std::nullopt, metadata); } void Rewrite(mlir::SymbolTable &symbol_table, @@ -192,10 +165,16 @@ class RewriteClusterToIfrtCallPass auto metadata_attr = ifrt_program->getAttrOfType(kMetadataTextAttrName); - if (!metadata_attr) { + auto device_assignment_attr = + ifrt_program->getAttrOfType(kDeviceAssignmentAttr); + if (!metadata_attr || !device_assignment_attr) { return signalPassFailure(); } + + // For better debuggability, attach attributes such as + // tpu_compile_metadata and device_assignment to IfrtCallOp. ifrt_call_op->setAttr(kMetadataTextAttrName, metadata_attr); + ifrt_call_op->setAttr(kDeviceAssignmentAttr, device_assignment_attr); // TODO(b/304839793): populate variable names after adding a variable // hoisting pass. @@ -228,6 +207,13 @@ class RewriteClusterToIfrtCallPass cloned_ifrt_program->setAttr(kMetadataTextAttrName, builder.getStringAttr(serialized_metadata)); + auto device_assignment_attr = + cluster_func->getAttrOfType(kDeviceAssignmentAttr); + if (!device_assignment_attr) { + device_assignment_attr = builder.getArrayAttr({}); + } + cloned_ifrt_program->setAttr(kDeviceAssignmentAttr, device_assignment_attr); + cloned_ifrt_program.setName(ifrt_program_name); int64_t program_id = NewProgramId(); @@ -248,10 +234,11 @@ class RewriteClusterToIfrtCallPass // hoisting pass. ifrt_call_op.setVariableArgIndicesAttr(builder.getI32ArrayAttr({})); ifrt_call_op.setProgramId(program_id); - // Additionally attach tpu_compile_metadata to IfrtCallOp. Some subsequent - // pass such as SinkVariableAsNamedArrayPass relies on this attribute. + // For better debuggability, attach attributes such as tpu_compile_metadata + // and device_assignment to IfrtCallOp. ifrt_call_op->setAttr(kMetadataTextAttrName, builder.getStringAttr(serialized_metadata)); + ifrt_call_op->setAttr(kDeviceAssignmentAttr, device_assignment_attr); cluster_func->replaceAllUsesWith(ifrt_call_op.getResults()); cluster_func->erase(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index b201370ea3ae7b..1de61abd9e9385 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -37,9 +37,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc new file mode 100644 index 00000000000000..b40c94e6a1de07 --- /dev/null +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_device_cleanup.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { + +#define GEN_PASS_DEF_TFDEVICECLEANUPPASS +#define GEN_PASS_DECL_TFDEVICECLEANUPPASS +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep + +class TfDeviceCleanupPass + : public impl::TfDeviceCleanupPassBase { + public: + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + func.walk([](mlir::Operation* op) { + if (llvm::isa(op->getDialect())) { + op->removeAttr("device"); + } + }); + } +}; + +} // namespace + +std::unique_ptr> +CreateTfDeviceCleanupPass() { + return std::make_unique(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc index 2802cb5a94503a..6d49f9a06141c9 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.cc @@ -81,6 +81,10 @@ void AddClusterToIfrtRuntimeOpsPassPipeline(OpPassManager& pm, pm.addPass(CreateRewriteClusterToIfrtCallPass()); + // After device program is extracted, we can clean up device attributes from + // all ops. + pm.addNestedPass(CreateTfDeviceCleanupPass()); + // Sink VarHandle with ReadVariableOp: subsequent SinkVariableAsNamedArrayPass // rely on the co-existence of VarHandle and ReadVariable in the same // function. diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h index 93713fbdc13646..92d9b06dc6765a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf_ifrt_passes.h @@ -57,6 +57,10 @@ CreateTfRestorePruningPass(); std::unique_ptr> CreateLowerToIfrtRestoreVariablePass(); +// Creates a pass that cleans up device attributes from all ops. +std::unique_ptr> +CreateTfDeviceCleanupPass(); + #define GEN_PASS_REGISTRATION #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/passes.h.inc" // IWYU pragma: keep diff --git a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc index 51abf57bc00951..2b97ec6a9536ac 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/passes.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/passes.cc @@ -100,14 +100,14 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( AddTfDeviceAssignmentPasses(pm, options); // After the standard pass, we now have MLIR in TF dialect, and now we convert - // reference variable to resource variables, which is besteffort. + // reference variable to resource variables, which is best effort. pm.addPass(CreateConvertReferenceVariableToResourceVariablePass()); // Move the tf.Assert op to the end of the function, so that it does not // impose unnecessary control dependencies on other ops. pm.addPass(tfrt_compiler::CreateReorderTfAssertPass()); - // Optimze the side-effects of control flow ops by examining the ops in its + // Optimize the side-effects of control flow ops by examining the ops in its // callees. pm.addPass(tfrt_compiler::CreateOptimizeTfControlFlowSideEffectPass()); @@ -117,10 +117,11 @@ void CreateTFExecutorToTFPreInvariantOptimizationPipelineHelper( // Merge non-side-effecting tf.If ops if their operands are the same. pm.addPass(tfrt_compiler::CreateMergeTfIfOpsPass()); - // Lower bound on the number of batch threads in `tf.BatchFunction`. - pm.addPass(tfrt_compiler::CreateReconfigBatchOpPass( - {.min_num_batch_threads = options.min_num_batch_threads, - .min_max_enqueued_batches = options.min_max_enqueued_batches})); + pm.addPass(tfrt_compiler::CreateReconfigBatchOpPass({ + .min_num_batch_threads = options.min_num_batch_threads, + .min_max_enqueued_batches = options.min_max_enqueued_batches, + .batch_padding_policy = options.batch_padding_policy, + })); // Deduplicate functions invoked by tf.BatchFunction with the same // shared_name diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD index cb517d1039711f..83b70c251d8bf7 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD @@ -20,11 +20,15 @@ cc_library( deps = [ "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", + "//tensorflow/core/tfrt/mlrt/bytecode:function", + "//tensorflow/core/tfrt/mlrt/bytecode:kernel", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -38,10 +42,15 @@ tf_cc_test( data = glob(["testdata/**"]), deps = [ ":mlir_to_bytecode", + "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", "//tensorflow/core/tfrt/mlrt/interpreter:attribute_span", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:resource_loader", @@ -57,10 +66,15 @@ cc_library( hdrs = ["test_utils.h"], deps = [ # copybara:uncomment "//learning/brain/experimental/tfrt/native_lowering/stubs:tfrt_native_lowering_impl", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/tfrt/graph_executor:sync_resource_state", "//tensorflow/core/tfrt/mlrt/attribute", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:kernel", @@ -70,7 +84,9 @@ cc_library( "//tensorflow/core/tfrt/stubs:tfrt_native_lowering_stub", "//tensorflow/core/tfrt/utils:tensor_util", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", "@tf_runtime//:hostcontext", + "@tf_runtime//:support", "@tf_runtime//:tensor", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc index d3b19eb3447cf7..52b1826f4a1f65 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -25,14 +25,26 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/function.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" namespace mlrt { namespace { diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h index 7f5416d230cb05..950865644effcc 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h @@ -22,9 +22,16 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" namespace mlrt { diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc index 9f02f1d3c2a531..d7d3065d847d7f 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc @@ -19,9 +19,20 @@ limitations under the License. #include #include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h" #include "tsl/platform/resource_loader.h" diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc index b5a3cb9550c558..e4f9e6f77ba2fc 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.cc @@ -22,10 +22,19 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/mlrt/attribute/attribute.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace mlrt { namespace testing { diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h index d569f32175f78c..6140c71149c9ee 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/test_utils.h @@ -21,10 +21,15 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/tfrt/graph_executor/sync_resource_state.h" #include "tensorflow/core/tfrt/mlrt/attribute/attribute.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/kernel.h" @@ -34,10 +39,13 @@ limitations under the License. #include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h" #include "tensorflow/core/tfrt/utils/tensor_util.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime #include "tfrt/host_context/host_allocator.h" // from @tf_runtime #include "tfrt/host_context/host_context.h" // from @tf_runtime +#include "tfrt/support/string_util.h" // from @tf_runtime +#include "tfrt/tensor/dense_host_tensor.h" // from @tf_runtime #include "tfrt/tensor/dense_tensor_utils.h" // from @tf_runtime namespace mlrt { diff --git a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h index 9f71dce30675fe..69ff39c3dcf95e 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h +++ b/tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h @@ -125,7 +125,7 @@ struct TfrtCompileOptions { // For TFRT, if true, tf.While's iterations will be parallelized on a // best-effort basis. This is currently experimental. MLRT attempts to convert // tf.while to tf_mlrt.map_fn regardless of this flag. For tf.While that - // cannot be onverted tf_mlrt.map_fn, MLRT try to parallerize tf.while's + // cannot be converted tf_mlrt.map_fn, MLRT try to parallelize tf.while's // iterations on a best-effort basis. bool enable_while_parallel_iterations = false; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 404d134223ab48..b973259a38ba16 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -337,6 +337,7 @@ Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module, kernelPm.addPass(::mlir::createConvertSCFToCFPass()); #if TENSORFLOW_USE_ROCM kernelPm.addPass(mlir::createGpuKernelToRocdlPass()); + kernelPm.addPass(mlir::createReconcileUnrealizedCastsPass()); #elif GOOGLE_CUDA kernelPm.addPass(mlir::createGpuKernelToNvvmPass()); kernelPm.addPass(mlir::NVVM::createOptimizeForTargetPass()); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/BUILD index e582aa12f34299..cda62df3967a99 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/hlo_to_kernel/BUILD @@ -12,6 +12,7 @@ glob_lit_tests( "gpu", ], driver = "//tensorflow/compiler/mlir:run_lit.sh", + hermetic_cuda_data_dir = "%S/../../../../../../../../cuda_nvcc", test_file_exts = ["mlir"], ) diff --git a/tensorflow/compiler/mlir/tools/optimize/BUILD b/tensorflow/compiler/mlir/tools/optimize/BUILD new file mode 100644 index 00000000000000..7c67945754bcb5 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/optimize/BUILD @@ -0,0 +1,20 @@ +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +cc_library( + name = "quantization_utils", + srcs = ["quantization_utils.cc"], + hdrs = ["quantization_utils.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/kernels/internal:quantization_util", + "//tensorflow/compiler/mlir/lite/kernels/internal:runtime_shape", + ], +) diff --git a/tensorflow/compiler/mlir/tools/optimize/quantization_utils.cc b/tensorflow/compiler/mlir/tools/optimize/quantization_utils.cc new file mode 100644 index 00000000000000..d8da88ab60761c --- /dev/null +++ b/tensorflow/compiler/mlir/tools/optimize/quantization_utils.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/tools/optimize/quantization_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h" +#include "tensorflow/compiler/mlir/lite/kernels/internal/runtime_shape.h" + +namespace tflite_migration { +namespace optimize { +namespace utils { + +namespace { + +const int8_t kMinQuantizedValue8bit = -127; +const int8_t kMaxQuantizedValue8bit = 127; + +// const int8_t kMinQuantizedValue4bit = -7; +// const int8_t kMaxQuantizedValue4bit = 7; + +// The maximum number of dimensions supported in per-channel quantization. +constexpr int kPerChannelMaxDim = 4; +} // namespace + +template +std::vector SymmetricBiasQuantize(const float* data, + uint64_t num_elements, + const std::vector& scales) { + std::vector buffer(num_elements); + const BiasType kScale = std::numeric_limits::max(); + float scaling_factor_inv_per_layer = (scales[0] == 0) ? 0 : 1.0 / scales[0]; + + for (int32_t idx = 0; idx < num_elements; idx++) { + float scaling_factor_inv = + scales.size() == 1 ? scaling_factor_inv_per_layer + : ((scales[idx] == 0) ? 0 : 1.0 / scales[idx]); + const BiasType quantized_value = tflite_migration::SafeCast( + std::round(data[idx] * scaling_factor_inv)); + buffer[idx] = std::min(kScale, std::max(-kScale, quantized_value)); + } + return buffer; +} + +template std::vector SymmetricBiasQuantize( + const float* data, uint64_t num_elements, const std::vector& scales); + +template std::vector SymmetricBiasQuantize( + const float* data, uint64_t num_elements, const std::vector& scales); + +std::vector SymmetricQuantizeFloatsToInt16(const float* data, + uint64_t num_elements, + float scaling_factor) { + // Compute the inverse of scale. + const float scaling_factor_inv = + (scaling_factor == 0) ? 0 : 1.0 / scaling_factor; + std::vector buffer(num_elements); + const int32_t kScale = std::numeric_limits::max(); + + for (size_t i = 0; i < num_elements; i++) { + const int32_t quantized_value = + static_cast(std::round(data[i] * scaling_factor_inv)); + buffer[i] = std::min(kScale, std::max(-kScale, quantized_value)); + } + return buffer; +} + +void SymmetricPerChannelQuantizeValues(const float* const input, + const std::vector& scales_inv, + const std::vector& dimension, + int32_t channel_dim_index, + std::vector* output_value) { + using mlir::RuntimeShape; + // Quantize the values. + int indices[kPerChannelMaxDim]; + RuntimeShape unextended_tensor_dims(dimension.size(), dimension.data()); + RuntimeShape tensor_dims = + RuntimeShape::ExtendedShape(kPerChannelMaxDim, unextended_tensor_dims); + channel_dim_index += + kPerChannelMaxDim - unextended_tensor_dims.DimensionsCount(); + for (indices[0] = 0; indices[0] < tensor_dims.Dims(0); indices[0]++) { + for (indices[1] = 0; indices[1] < tensor_dims.Dims(1); indices[1]++) { + for (indices[2] = 0; indices[2] < tensor_dims.Dims(2); indices[2]++) { + for (indices[3] = 0; indices[3] < tensor_dims.Dims(3); indices[3]++) { + int channel_idx = indices[channel_dim_index]; + int index = Offset(tensor_dims, indices); + const float val = input[index]; + const int32_t quantized_value = + static_cast(std::round(val * scales_inv[channel_idx])); + output_value->at(index) = std::min( + kMaxQuantizedValue8bit, + std::max(kMinQuantizedValue8bit, quantized_value)); + } + } + } + } +} + +} // namespace utils +} // namespace optimize +} // namespace tflite_migration diff --git a/tensorflow/compiler/mlir/tools/optimize/quantization_utils.h b/tensorflow/compiler/mlir/tools/optimize/quantization_utils.h new file mode 100644 index 00000000000000..aa22d5469a7299 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/optimize/quantization_utils.h @@ -0,0 +1,45 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_OPTIMIZE_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_OPTIMIZE_QUANTIZATION_UTILS_H_ + +#include +#include + +namespace tflite_migration { +namespace optimize { +namespace utils { + +template +std::vector SymmetricBiasQuantize(const float* data, + uint64_t num_elements, + const std::vector& scales); + +std::vector SymmetricQuantizeFloatsToInt16(const float* data, + uint64_t num_elements, + float scaling_factor); + +// Quantize the values given an array of scales. +void SymmetricPerChannelQuantizeValues(const float* input, + const std::vector& scales_inv, + const std::vector& dimension, + int32_t channel_dim_index, + std::vector* output_value); + +} // namespace utils +} // namespace optimize +} // namespace tflite_migration + +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_OPTIMIZE_QUANTIZATION_UTILS_H_ diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4cb34827f0f4ab..69ba20048ee2a0 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1,6 +1,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") load("//tensorflow:tensorflow.default.bzl", "cuda_py_strict_test", "tf_cuda_cc_test") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") +load("//tensorflow/compiler/tests:build_combined_defs.bzl", "tf_xla_combined_py_test") load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites", "tf_xla_py_strict_test") load( "//tensorflow/core/platform:build_config_root.bzl", @@ -89,14 +90,34 @@ py_strict_test( ], ) +#LINT.IfChange(combined_tests) +# If you add a new tf_xla_py_strict_test please either add the test file to one of the combined test +# targets that matches in all tags and other settings or add a new combined test target. +tf_xla_combined_py_test( + name = "ops_test_mlir_false", + size = "medium", + package = "tensorflow.compiler.tests", + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + tests = [ + # go/keep-sorted start + ":adadelta_test_lib", + # go/keep-sorted end + ], +) +#LINT.ThenChange(:individual_tests) + +#LINT.IfChange(individual_tests) tf_xla_py_strict_test( name = "adadelta_test", size = "medium", srcs = ["adadelta_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], deps = [ ":xla_test", @@ -113,7 +134,6 @@ tf_xla_py_strict_test( name = "adagrad_test", size = "small", srcs = ["adagrad_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -133,7 +153,6 @@ tf_xla_py_strict_test( name = "adagrad_da_test", size = "small", srcs = ["adagrad_da_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -155,7 +174,6 @@ tf_xla_py_strict_test( name = "adam_test", size = "small", srcs = ["adam_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -185,7 +203,6 @@ tf_xla_py_strict_test( # copybara:uncomment_end # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -205,7 +222,6 @@ tf_xla_py_strict_test( name = "argminmax_test", size = "small", srcs = ["argminmax_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -226,7 +242,6 @@ tf_xla_py_strict_test( name = "binary_ops_test", size = "medium", srcs = ["binary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -255,7 +270,6 @@ tf_xla_py_strict_test( name = "complex_div_test", size = "medium", srcs = ["complex_div_test.py"], - enable_mlir_bridge = True, enabled_backends = [ "cpu", "gpu", @@ -282,7 +296,6 @@ tf_xla_py_strict_test( name = "bucketize_op_test", size = "small", srcs = ["bucketize_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -302,7 +315,6 @@ tf_xla_py_strict_test( name = "categorical_op_test", size = "small", srcs = ["categorical_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -325,7 +337,6 @@ tf_xla_py_strict_test( name = "cholesky_op_test", size = "medium", srcs = ["cholesky_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -352,7 +363,6 @@ tf_xla_py_strict_test( # #TODO(b/286470564): Remove once the bug is fixed. # disable_tpu_tfrt = True, # copybara:uncomment_end - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -382,7 +392,6 @@ tf_xla_py_strict_test( name = "self_adjoint_eig_op_test", size = "medium", srcs = ["self_adjoint_eig_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -403,7 +412,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["searchsorted_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -447,7 +455,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -468,7 +475,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["matrix_solve_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -489,7 +495,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["matrix_triangular_solve_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_oss", # TODO(b/295649328): fix failed nightly tests @@ -513,7 +518,6 @@ tf_xla_py_strict_test( name = "clustering_test", size = "small", srcs = ["clustering_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -534,7 +538,6 @@ tf_xla_py_strict_test( name = "concat_ops_test", size = "medium", srcs = ["concat_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -559,7 +562,6 @@ tf_xla_py_strict_test( name = "conv2d_test", size = "medium", srcs = ["conv2d_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -605,7 +607,6 @@ tf_xla_py_strict_test( name = "conv3d_test", size = "medium", srcs = ["conv3d_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -631,7 +632,6 @@ tf_xla_py_strict_test( name = "depthwise_conv_op_test", size = "medium", srcs = ["depthwise_conv_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -656,7 +656,6 @@ tf_xla_py_strict_test( name = "dynamic_slice_ops_test", size = "small", srcs = ["dynamic_slice_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -675,7 +674,6 @@ tf_xla_py_strict_test( name = "einsum_op_test", size = "medium", srcs = ["einsum_op_test.py"], - enable_mlir_bridge = True, enabled_backends = [ "cpu", "gpu", @@ -700,7 +698,6 @@ tf_xla_py_strict_test( name = "reshape_op_test", size = "small", srcs = ["reshape_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -754,7 +751,6 @@ tf_xla_py_strict_test( name = "eager_test", size = "medium", srcs = ["eager_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "multi_and_single_gpu", @@ -792,7 +788,6 @@ tf_xla_py_strict_test( name = "fifo_queue_test", size = "medium", srcs = ["fifo_queue_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -810,7 +805,6 @@ tf_xla_py_strict_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], - enable_mlir_bridge = False, python_version = "PY3", shard_count = 12, tags = [ @@ -834,7 +828,6 @@ tf_xla_py_strict_test( name = "slice_ops_test", size = "medium", srcs = ["slice_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -854,7 +847,6 @@ tf_xla_py_strict_test( name = "ftrl_test", size = "medium", srcs = ["ftrl_test.py"], - enable_mlir_bridge = False, python_version = "PY3", shard_count = 8, tags = [ @@ -877,7 +869,6 @@ tf_xla_py_strict_test( name = "ftrl_ops_test", size = "medium", srcs = ["ftrl_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -897,7 +888,6 @@ tf_xla_py_strict_test( name = "function_test", size = "small", srcs = ["function_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -918,7 +908,6 @@ tf_xla_py_strict_test( size = "small", timeout = "long", srcs = ["image_ops_test.py"], - enable_mlir_bridge = False, enabled_backends = [ "cpu", "gpu", @@ -950,7 +939,6 @@ tf_xla_py_strict_test( name = "listdiff_op_test", size = "small", srcs = ["listdiff_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_cuda_asan", # times out @@ -969,7 +957,6 @@ tf_xla_py_strict_test( name = "lrn_ops_test", size = "medium", srcs = ["lrn_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -991,7 +978,6 @@ tf_xla_py_strict_test( name = "manip_ops_test", size = "small", srcs = ["manip_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1011,7 +997,6 @@ tf_xla_py_strict_test( size = "medium", timeout = "long", srcs = ["matrix_band_part_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_aarch64", # TODO(b/315533266) @@ -1034,7 +1019,6 @@ tf_xla_py_strict_test( size = "medium", timeout = "long", srcs = ["matrix_diag_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 8, tags = [ @@ -1053,7 +1037,6 @@ tf_xla_py_strict_test( name = "momentum_test", size = "small", srcs = ["momentum_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1075,7 +1058,6 @@ tf_xla_py_strict_test( name = "nary_ops_test", size = "small", srcs = ["nary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1095,7 +1077,6 @@ tf_xla_py_strict_test( name = "nullary_ops_test", size = "small", srcs = ["nullary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1113,7 +1094,6 @@ tf_xla_py_strict_test( name = "pooling_ops_test", size = "medium", srcs = ["pooling_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1138,7 +1118,6 @@ tf_xla_py_strict_test( name = "pooling_ops_3d_test", size = "medium", srcs = ["pooling_ops_3d_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1160,7 +1139,6 @@ tf_xla_py_strict_test( name = "proximal_adagrad_test", size = "medium", srcs = ["proximal_adagrad_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1181,7 +1159,6 @@ tf_xla_py_strict_test( name = "proximal_gradient_descent_test", size = "medium", srcs = ["proximal_gradient_descent_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1207,7 +1184,6 @@ tf_xla_py_strict_test( "cpu", "cpu_ondemand", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1236,7 +1212,6 @@ tf_xla_py_strict_test( "cpu", "cpu_ondemand", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 1, tags = [ @@ -1279,7 +1254,6 @@ tf_xla_py_strict_test( name = "reduce_ops_test", size = "medium", srcs = ["reduce_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1302,7 +1276,6 @@ tf_xla_py_strict_test( name = "reduce_window_test", size = "small", srcs = ["reduce_window_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1322,7 +1295,6 @@ tf_xla_py_strict_test( name = "reverse_ops_test", size = "medium", srcs = ["reverse_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1341,7 +1313,6 @@ tf_xla_py_strict_test( name = "reverse_sequence_op_test", size = "medium", srcs = ["reverse_sequence_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1361,7 +1332,6 @@ tf_xla_py_strict_test( # name = "reverse_sequence_op_args_test", # size = "medium", # srcs = ["reverse_sequence_op_args_test.py"], -# enable_mlir_bridge = False, # main = "reverse_sequence_op_args_test.py", # python_version = "PY3", # tags = [ @@ -1384,7 +1354,6 @@ tf_xla_py_strict_test( name = "rmsprop_test", size = "small", srcs = ["rmsprop_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1404,7 +1373,6 @@ tf_xla_py_strict_test( name = "scan_ops_test", size = "medium", srcs = ["scan_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -1430,7 +1398,6 @@ tf_xla_py_strict_test( name = "segment_reduction_ops_test", size = "medium", srcs = ["segment_reduction_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1451,7 +1418,6 @@ tf_xla_py_strict_test( name = "spacetobatch_op_test", size = "medium", srcs = ["spacetobatch_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 3, tags = [ @@ -1472,7 +1438,6 @@ tf_xla_py_strict_test( name = "sparse_to_dense_op_test", size = "medium", srcs = ["sparse_to_dense_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1492,7 +1457,6 @@ tf_xla_py_strict_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "config-cuda-only", @@ -1520,7 +1484,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -1558,7 +1521,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -1597,8 +1559,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - # TODO(b/232442915): Enable MLIR. - enable_mlir_bridge = False, python_version = "PY3", shard_count = 20, tags = [ @@ -1629,7 +1589,6 @@ tf_xla_py_strict_test( srcs = ["tensor_array_ops_test.py"], # TensorArray ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "config-cuda-only", @@ -1669,7 +1628,6 @@ tf_xla_py_strict_test( # copybara:uncomment_end # TensorList ops are not implemented in the on-demand compilation model yet. disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1691,7 +1649,6 @@ tf_xla_py_strict_test( name = "ternary_ops_test", size = "medium", srcs = ["ternary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 8, tags = [ @@ -1715,7 +1672,6 @@ tf_xla_py_strict_test( name = "unary_ops_test", size = "medium", srcs = ["unary_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1745,7 +1701,6 @@ tf_xla_py_strict_test( name = "fused_batchnorm_test", size = "medium", srcs = ["fused_batchnorm_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1770,7 +1725,6 @@ tf_xla_py_strict_test( size = "small", timeout = "moderate", srcs = ["variable_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1805,7 +1759,6 @@ tf_xla_py_strict_test( # #TODO(b/291130193): Remove once the bug is fixed. # disable_tpu_tfrt = True, # copybara:uncomment_end - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1833,7 +1786,6 @@ tf_xla_py_strict_test( size = "small", srcs = ["case_test.py"], disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1901,7 +1853,6 @@ tf_xla_py_strict_test( name = "gather_nd_op_test", size = "medium", srcs = ["gather_nd_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1920,7 +1871,6 @@ tf_xla_py_strict_test( name = "scatter_nd_op_test", size = "medium", srcs = ["scatter_nd_op_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1940,7 +1890,6 @@ tf_xla_py_strict_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", shard_count = 50, # Times out in fastbuild mode. @@ -1971,7 +1920,6 @@ tf_xla_py_strict_test( name = "data_format_ops_test", size = "small", srcs = ["data_format_ops_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1990,7 +1938,6 @@ tf_xla_py_strict_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2008,722 +1955,383 @@ tf_xla_py_strict_test( ], ) -cuda_py_strict_test( - name = "xla_device_gpu_test", - size = "small", - srcs = ["xla_device_gpu_test.py"], +tf_xla_py_strict_test( + name = "fake_quant_ops_test", + size = "medium", + srcs = ["fake_quant_ops_test.py"], + python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, deps = [ - "//tensorflow/python/client:session", - "//tensorflow/python/eager:context", + ":xla_test", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/ops:array_ops_gen", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", ], ) -cuda_py_strict_test( - name = "jit_test", +tf_xla_py_strict_test( + name = "placeholder_test", + size = "small", + srcs = ["placeholder_test.py"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + deps = [ + ":xla_test", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:resource_variable_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:test", + ], +) + +tf_xla_py_strict_test( + name = "quantized_ops_test", size = "medium", - srcs = ["jit_test.py"], - #shard_count = 5, + srcs = ["quantized_ops_test.py"], + python_version = "PY3", tags = [ - "no_cuda_asan", # Times out. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, deps = [ - ":test_utils", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/client:session", - "//tensorflow/python/compiler/xla:compiler_py", + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:cond", - "//tensorflow/python/ops:control_flow_ops", - "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:bitwise_ops", "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:nn_ops", - "//tensorflow/python/ops:while_loop", - "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/platform:test", "//third_party/py/numpy", ], ) -cuda_py_strict_test( - name = "async_comp_test", +tf_xla_py_strict_test( + name = "xla_ops_test", size = "medium", - srcs = ["async_comp_test.py"], - shard_count = 1, + srcs = ["xla_ops_test.py"], + disabled_backends = [ + "gpu", + "gpu_a100", + "gpu_h100", + ], + python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python/client:session", + ":xla_test", + "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", "//tensorflow/python/framework:function", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor_shape", + "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:random_ops_util", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + "@local_xla//xla:xla_data_proto_py", ], ) -cuda_py_strict_test( - name = "dense_layer_test", - size = "large", - srcs = ["dense_layer_test.py"], +tf_xla_py_strict_test( + name = "xla_custom_call_ops_test", + size = "small", + srcs = ["xla_custom_call_ops_test.py"], + disabled_backends = [ + "gpu", + "gpu_a100", + "gpu_h100", + ], + python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, + use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ - ":test_utils", - "//tensorflow/core:protos_all_py", - "//tensorflow/python/compiler/xla:compiler_py", + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", - "//tensorflow/python/layers", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:variables", + "//tensorflow/python/framework:tensor_spec", + "//tensorflow/python/ops:random_ops", "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", ], ) -cc_library( - name = "randomized_tests_library", - testonly = 1, - srcs = ["randomized_tests.cc"], +tf_xla_py_strict_test( + name = "runtime_shape_check_test", + size = "small", + srcs = ["runtime_shape_check_test.py"], + disabled_backends = [ + "cpu", + "cpu_ondemand", + ], + python_version = "PY3", + tags = [ + "no_pip", + "notap", + ], + use_xla_device = False, deps = [ - "//tensorflow/compiler/jit", - "//tensorflow/compiler/jit:common", - "//tensorflow/compiler/jit:flags_headers", - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow_opensource", - "//tensorflow/core:test", - "//tensorflow/core:testlib", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:fixed_array", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:status", - "@local_xla//xla:xla_data_proto_cc", + ":xla_test", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", ], ) -tf_cuda_cc_test( - name = "randomized_tests", +tf_xla_py_strict_test( + name = "conv_node_name_test", size = "medium", - args = ["--tf_xla_test_use_mlir=false"], - shard_count = 20, - # This test is randomized, so only run it if explicitly requested. + srcs = ["conv_node_name_test.py"], + python_version = "PY3", + shard_count = 5, tags = [ - "manual", + "no_oss", # TODO(b/148108508): Re-enable this test in OSS. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", - ] + tf_cuda_tests_tags(), - deps = [":randomized_tests_library"], + ], + deps = [ + ":xla_test", + "//tensorflow/python/framework:ops", + "//tensorflow/python/layers", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/platform:test", + "//third_party/py/numpy", + ], ) -tf_cuda_cc_test( - name = "randomized_tests_mlir", +tf_xla_py_strict_test( + name = "tridiagonal_solve_ops_test", size = "medium", - args = ["--tf_xla_test_use_mlir=true"], - shard_count = 20, - # This test is randomized, so only run it if explicitly requested. + srcs = ["tridiagonal_solve_ops_test.py"], + python_version = "PY3", tags = [ - "manual", "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", - ] + tf_cuda_tests_tags(), - deps = [":randomized_tests_library"], + "optonly", + ], + deps = [ + ":xla_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:gradients", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops/linalg:linalg_impl", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], ) -# Create a deterministic version of randomized_tests_mlir with fixed seed. -# This can be used in presubmit checks as it is no longer randomized. -tf_cuda_cc_test( - name = "randomized_tests_mlir_seeded", +tf_xla_py_strict_test( + name = "tridiagonal_matmul_ops_test", size = "medium", - args = [ - "--tf_xla_random_seed=200839030", - "--tf_xla_test_use_mlir=true", - "--tf_xla_test_device=GPU:0", - ], - shard_count = 20, - tags = [ - "config-cuda-only", - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "requires-gpu-nvidia", - ] + tf_cuda_tests_tags(), - deps = [":randomized_tests_library"], -) - -# Create a deterministic version of randomized_tests with fixed seed. -# This can be used in presubmit checks as it is no longer randomized. -tf_cuda_cc_test( - name = "randomized_tests_seeded", - size = "medium", - args = [ - "--tf_xla_random_seed=200839030", - "--tf_xla_test_use_mlir=false", - "--tf_xla_test_device=GPU:0", - ], - shard_count = 20, - tags = [ - "config-cuda-only", - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "no_rocm", # ROCmSoftwarePlatform #958 - "noasan", # TODO(b/201651800) - "requires-gpu-nvidia", - ] + tf_cuda_tests_tags(), - deps = [":randomized_tests_library"], -) - -tf_cuda_cc_test( - name = "unary_ops_composition_test", - srcs = ["unary_ops_composition_test.cc"], - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - ] + tf_cuda_tests_tags(), - deps = [ - "//tensorflow/cc:cc_ops", - "//tensorflow/compiler/jit", - "//tensorflow/compiler/jit:flags", - "//tensorflow/compiler/jit:xla_kernel_creator", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - "//tensorflow/core/kernels:ops_testutil", - "@local_tsl//tsl/platform:status", - ], -) - -py_strict_library( - name = "lstm", - testonly = 1, - srcs = ["lstm.py"], - srcs_version = "PY3", - deps = [ - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:random_ops", - "//tensorflow/python/ops:variable_v1", - "@six_archive//:six", - ], -) - -cuda_py_strict_test( - name = "lstm_test", - srcs = ["lstm_test.py"], + srcs = ["tridiagonal_matmul_ops_test.py"], + python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], - xla_enable_strict_auto_jit = False, - xla_enabled = True, deps = [ - ":lstm", ":xla_test", + "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:gradients_impl", - "//tensorflow/python/ops:init_ops", + "//tensorflow/python/ops:array_ops_stack", + "//tensorflow/python/ops:gradient_checker_v2", "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:variables", - "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/ops:stateless_random_ops", + "//tensorflow/python/ops/linalg:linalg_impl", + "//tensorflow/python/platform:test", "//third_party/py/numpy", ], ) -# An example of ahead-of-time compilation using tfcompile. The -# lstm_layer_inference.pbtxt file was generated by running lstm_test -# --dump_graph_dir, and the config file was written by hand. -# -# Run the following to build a minimal benchmark of the computation on Android: -# $ bazel build -c opt --cxxopt='-std=c++11' --linkopt='-lm' \ -# --cpu=armeabi-v7a \ -# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ -# --crosstool_top=//external:android/crosstool \ -# //tensorflow/compiler/tests:lstm_layer_inference_benchmark - -# -# Currently the resulting binary size is ~190KB -tf_library( - name = "lstm_layer_inference", - testonly = 1, - config = "lstm_layer_inference.config.pbtxt", - cpp_class = "LSTMLayerInference", - graph = "lstm_layer_inference.pbtxt", - tags = ["manual"], - tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], -) - tf_xla_py_strict_test( - name = "fake_quant_ops_test", + name = "special_math_test", size = "medium", - srcs = ["fake_quant_ops_test.py"], - enable_mlir_bridge = True, - python_version = "PY3", + srcs = ["special_math_test.py"], + shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], deps = [ ":xla_test", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:array_ops_gen", - "//tensorflow/python/platform:test", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/ops:gradient_checker_v2", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:math_ops_gen", + "//tensorflow/python/ops:random_ops_gen", + "//tensorflow/python/platform:client_testlib", "//third_party/py/numpy", + "@absl_py//absl/flags", + "@absl_py//absl/testing:parameterized", ], ) tf_xla_py_strict_test( - name = "placeholder_test", - size = "small", - srcs = ["placeholder_test.py"], - enable_mlir_bridge = True, - python_version = "PY3", + name = "repeat_op_test", + size = "medium", + srcs = ["repeat_op_test.py"], + shard_count = 1, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], deps = [ ":xla_test", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:resource_variable_ops", - "//tensorflow/python/ops:variables", - "//tensorflow/python/platform:test", + "//tensorflow/python/platform:client_testlib", ], ) tf_xla_py_strict_test( - name = "quantized_ops_test", + name = "image_ops_jit_compile_test", size = "medium", - srcs = ["quantized_ops_test.py"], - enable_mlir_bridge = False, - python_version = "PY3", + srcs = ["image_ops_jit_compile_test.py"], + disabled_backends = [ + "cpu_ondemand", + ], + shard_count = 1, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "optonly", ], + use_xla_device = False, deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python/framework:constant_op", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:errors", "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:bitwise_ops", + "//tensorflow/python/ops:image_ops", "//tensorflow/python/ops:math_ops", - "//tensorflow/python/platform:test", - "//third_party/py/numpy", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", ], ) tf_xla_py_strict_test( - name = "xla_ops_test", + name = "ensure_shape_op_test", size = "medium", - srcs = ["xla_ops_test.py"], - disabled_backends = [ - "gpu", - "gpu_a100", - "gpu_h100", - ], - enable_mlir_bridge = True, + srcs = ["ensure_shape_op_test.py"], python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "no_rocm" + "optonly", ], deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/ops:gen_xla_ops", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python/eager:def_function", "//tensorflow/python/framework:dtypes", "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:function", - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_shape", - "//tensorflow/python/framework:test_lib", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:array_ops_stack", - "//tensorflow/python/ops:random_ops_util", - "//tensorflow/python/platform:test", - "//third_party/py/numpy", - "@absl_py//absl/testing:parameterized", - "@local_xla//xla:xla_data_proto_py", + "//tensorflow/python/ops:check_ops", + "//tensorflow/python/platform:client_testlib", ], ) tf_xla_py_strict_test( - name = "xla_custom_call_ops_test", + name = "where_op_test", size = "small", - srcs = ["xla_custom_call_ops_test.py"], - disabled_backends = [ + srcs = ["where_op_test.py"], + enabled_backends = [ + "cpu", "gpu", "gpu_a100", "gpu_h100", ], - enable_mlir_bridge = False, - python_version = "PY3", tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "notap", + "no_pip", + "optonly", ], - use_xla_device = False, # Uses tf.function(jit_compile=True) deps = [ ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:config", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:ops", - "//tensorflow/python/framework:tensor_spec", - "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/tpu:tpu_py", ], ) tf_xla_py_strict_test( - name = "runtime_shape_check_test", + name = "where_op_tpu_test", size = "small", - srcs = ["runtime_shape_check_test.py"], + srcs = ["where_op_test.py"], + args = [ + "--tpu_use_tfrt=true", + ], disabled_backends = [ "cpu", "cpu_ondemand", + "gpu", + "gpu_a100", + "gpu_h100", ], - enable_mlir_bridge = False, - python_version = "PY3", + main = "where_op_test.py", tags = [ "no_pip", - "notap", + "optonly", ], - use_xla_device = False, deps = [ ":xla_test", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:config", "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:ops", "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", "//tensorflow/python/platform:client_testlib", + "//tensorflow/python/tpu:tpu_py", ], ) tf_xla_py_strict_test( - name = "conv_node_name_test", - size = "medium", - srcs = ["conv_node_name_test.py"], - enable_mlir_bridge = True, + name = "const_arg_test", + size = "small", + srcs = ["const_arg_test.py"], python_version = "PY3", - shard_count = 5, tags = [ - "no_oss", # TODO(b/148108508): Re-enable this test in OSS. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], deps = [ ":xla_test", - "//tensorflow/python/framework:ops", - "//tensorflow/python/layers", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python/framework:dtypes", "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:control_flow_ops", "//tensorflow/python/platform:test", - "//third_party/py/numpy", - ], -) - -tf_xla_py_strict_test( - name = "tridiagonal_solve_ops_test", - size = "medium", - srcs = ["tridiagonal_solve_ops_test.py"], - enable_mlir_bridge = True, - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:test_lib", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:gradients", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops/linalg:linalg_impl", - "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", - ], -) - -tf_xla_py_strict_test( - name = "tridiagonal_matmul_ops_test", - size = "medium", - srcs = ["tridiagonal_matmul_ops_test.py"], - enable_mlir_bridge = True, - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:ops", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:array_ops_stack", - "//tensorflow/python/ops:gradient_checker_v2", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:stateless_random_ops", - "//tensorflow/python/ops/linalg:linalg_impl", - "//tensorflow/python/platform:test", - "//third_party/py/numpy", - ], -) - -tf_xla_py_strict_test( - name = "special_math_test", - size = "medium", - srcs = ["special_math_test.py"], - enable_mlir_bridge = True, - shard_count = 5, - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/ops:gradient_checker_v2", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:math_ops_gen", - "//tensorflow/python/ops:random_ops_gen", - "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", - "@absl_py//absl/flags", - "@absl_py//absl/testing:parameterized", - ], -) - -tf_xla_py_strict_test( - name = "repeat_op_test", - size = "medium", - srcs = ["repeat_op_test.py"], - enable_mlir_bridge = True, - shard_count = 1, - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/platform:client_testlib", - ], -) - -tf_xla_py_strict_test( - name = "image_ops_jit_compile_test", - size = "medium", - srcs = ["image_ops_jit_compile_test.py"], - disabled_backends = [ - "cpu_ondemand", - ], - enable_mlir_bridge = False, - shard_count = 1, - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly" - ], - use_xla_device = False, - deps = [ - ":xla_test", - "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python/eager:backprop", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/framework:ops", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:image_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/ops:variables", - "//tensorflow/python/platform:client_testlib", - ], -) - -tf_xla_py_strict_test( - name = "ensure_shape_op_test", - size = "medium", - srcs = ["ensure_shape_op_test.py"], - enable_mlir_bridge = False, - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:errors", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:check_ops", - "//tensorflow/python/platform:client_testlib", - ], -) - -tf_xla_py_strict_test( - name = "where_op_test", - size = "small", - srcs = ["where_op_test.py"], - enable_mlir_bridge = False, - enabled_backends = [ - "cpu", - "gpu", - "gpu_a100", - "gpu_h100", - ], - tags = [ - "no_pip", - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/framework:config", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/platform:client_testlib", - "//tensorflow/python/tpu:tpu_py", - ], -) - -tf_xla_py_strict_test( - name = "where_op_tpu_test", - size = "small", - srcs = ["where_op_test.py"], - args = [ - "--tpu_use_tfrt=true", - ], - disabled_backends = [ - "cpu", - "cpu_ondemand", - "gpu", - "gpu_a100", - "gpu_h100", - ], - enable_mlir_bridge = False, - main = "where_op_test.py", - tags = [ - "no_pip", - "optonly", - ], - deps = [ - ":xla_test", - "//tensorflow/python/framework:config", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/ops:math_ops", - "//tensorflow/python/platform:client_testlib", - "//tensorflow/python/tpu:tpu_py", - ], -) - -tf_xla_py_strict_test( - name = "const_arg_test", - size = "small", - srcs = ["const_arg_test.py"], - enable_mlir_bridge = False, - python_version = "PY3", - tags = [ - "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip - ], - deps = [ - ":xla_test", - "//tensorflow/compiler/tf2xla/python:xla", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/ops:array_ops", - "//tensorflow/python/platform:test", - ], -) - -cuda_py_strict_test( - name = "const_test", - size = "small", - srcs = ["const_test.py"], - python_version = "PY3", - xla_enable_strict_auto_jit = False, - xla_enabled = True, - deps = [ - "//tensorflow/python/eager:def_function", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/framework:test_lib", - "//tensorflow/python/platform:client_testlib", - "//third_party/py/numpy", - ], -) - -tpu_py_strict_test( - name = "giant_const_op_test", - srcs = [ - "giant_const_op_test.py", - ], - disable_experimental = True, - # TODO(b/188995810): Add an optimization in MLIR importer to not - # materialize giant splat constants. - disable_mlir_bridge = True, - python_version = "PY3", - tags = ["no_oss"], - deps = [ - "//tensorflow/python/distribute:tpu_strategy", - "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py", - "//tensorflow/python/eager:def_function", - "//tensorflow/python/eager:remote", - "//tensorflow/python/eager:test", - "//tensorflow/python/framework:config", - "//tensorflow/python/framework:constant_op", - "//tensorflow/python/framework:dtypes", - "//tensorflow/python/platform:flags", - "//third_party/py/numpy", ], ) @@ -2737,7 +2345,6 @@ tf_xla_py_strict_test( "gpu_a100", "gpu_h100", ], - enable_mlir_bridge = False, python_version = "PY3", shard_count = 10, tags = [ @@ -2762,7 +2369,6 @@ tpu_py_strict_test( name = "approx_topk_test", srcs = ["approx_topk_test.py"], disable_experimental = False, - disable_mlir_bridge = False, tags = ["no_oss"], deps = [ "//tensorflow/python/eager:backprop", @@ -2783,7 +2389,6 @@ tf_xla_py_strict_test( name = "xla_call_module_test", size = "small", srcs = ["xla_call_module_test.py"], - enable_mlir_bridge = False, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2814,7 +2419,6 @@ tf_xla_py_strict_test( srcs = ["xla_call_module_test.py"], # cpu_ondemand overrides the TF_XLA_FLAGS disabled_backends = ["cpu_ondemand"], - enable_mlir_bridge = False, env = {"TF_XLA_FLAGS": "--tf_xla_call_module_disabled_checks=platform"}, main = "xla_call_module_test.py", python_version = "PY3", @@ -2846,7 +2450,6 @@ tf_xla_py_strict_test( size = "small", srcs = ["xla_call_module_test.py"], disabled_backends = ["cpu_ondemand"], # cpu_ondemand overrides the TF_XLA_FLAGS - enable_mlir_bridge = False, env = {"TF_XLA_FLAGS": "--tf_xla_call_module_disabled_checks=shape_assertions"}, main = "xla_call_module_test.py", python_version = "PY3", @@ -2877,7 +2480,6 @@ tf_xla_py_strict_test( name = "bincount_op_test", size = "small", srcs = ["bincount_op_test.py"], - enable_mlir_bridge = False, python_version = "PY3", shard_count = 1, tags = [ @@ -2895,7 +2497,6 @@ tf_xla_py_strict_test( name = "unique_ops_test", size = "small", srcs = ["unique_ops_test.py"], - enable_mlir_bridge = False, enabled_backends = [ "cpu", "gpu", @@ -2922,7 +2523,6 @@ tpu_py_strict_test( size = "small", srcs = ["mean_op_test.py"], disable_experimental = False, - disable_mlir_bridge = False, tags = [ "notsan", # timesout ], @@ -2941,7 +2541,6 @@ tf_xla_py_strict_test( name = "xla_dump_to_test", size = "medium", srcs = ["xla_dump_to_test.py"], - enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2961,7 +2560,6 @@ tf_xla_py_strict_test( # name = "xla_dump_to_sponge_test", # size = "medium", # srcs = ["xla_dump_to_sponge_test.py"], -# enable_mlir_bridge = True, # python_version = "PY3", # tags = [ # "optonly", @@ -2975,3 +2573,327 @@ tf_xla_py_strict_test( # ], # ) # copybara:uncomment_end +#LINT.ThenChange(:combined_tests) + +cuda_py_strict_test( + name = "xla_device_gpu_test", + size = "small", + srcs = ["xla_device_gpu_test.py"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + "//tensorflow/python/client:session", + "//tensorflow/python/eager:context", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/platform:client_testlib", + ], +) + +cuda_py_strict_test( + name = "jit_test", + size = "medium", + srcs = ["jit_test.py"], + #shard_count = 5, + tags = [ + "no_cuda_asan", # Times out. + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + ":test_utils", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/client:session", + "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:cond", + "//tensorflow/python/ops:control_flow_ops", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:nn_ops", + "//tensorflow/python/ops:while_loop", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +cuda_py_strict_test( + name = "async_comp_test", + size = "medium", + srcs = ["async_comp_test.py"], + shard_count = 1, + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python/client:session", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:function", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/platform:client_testlib", + ], +) + +cuda_py_strict_test( + name = "dense_layer_test", + size = "large", + srcs = ["dense_layer_test.py"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + ":test_utils", + "//tensorflow/core:protos_all_py", + "//tensorflow/python/compiler/xla:compiler_py", + "//tensorflow/python/framework:ops", + "//tensorflow/python/layers", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +cc_library( + name = "randomized_tests_library", + testonly = 1, + srcs = ["randomized_tests.cc"], + deps = [ + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:common", + "//tensorflow/compiler/jit:flags_headers", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:test", + "//tensorflow/core:testlib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:status", + "@local_xla//xla:xla_data_proto_cc", + ], +) + +tf_cuda_cc_test( + name = "randomized_tests", + size = "medium", + args = ["--tf_xla_test_use_mlir=false"], + shard_count = 20, + # This test is randomized, so only run it if explicitly requested. + tags = [ + "manual", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", + ] + tf_cuda_tests_tags(), + deps = [":randomized_tests_library"], +) + +tf_cuda_cc_test( + name = "randomized_tests_mlir", + size = "medium", + args = ["--tf_xla_test_use_mlir=true"], + shard_count = 20, + # This test is randomized, so only run it if explicitly requested. + tags = [ + "manual", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "notap", + ] + tf_cuda_tests_tags(), + deps = [":randomized_tests_library"], +) + +# Create a deterministic version of randomized_tests_mlir with fixed seed. +# This can be used in presubmit checks as it is no longer randomized. +tf_cuda_cc_test( + name = "randomized_tests_mlir_seeded", + size = "medium", + args = [ + "--tf_xla_random_seed=200839030", + "--tf_xla_test_use_mlir=true", + "--tf_xla_test_device=GPU:0", + ], + shard_count = 20, + tags = [ + "config-cuda-only", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "no_rocm", # ROCmSoftwarePlatform #958 + "noasan", # TODO(b/201651800) + "requires-gpu-nvidia", + ] + tf_cuda_tests_tags(), + deps = [":randomized_tests_library"], +) + +# Create a deterministic version of randomized_tests with fixed seed. +# This can be used in presubmit checks as it is no longer randomized. +tf_cuda_cc_test( + name = "randomized_tests_seeded", + size = "medium", + args = [ + "--tf_xla_random_seed=200839030", + "--tf_xla_test_use_mlir=false", + "--tf_xla_test_device=GPU:0", + ], + shard_count = 20, + tags = [ + "config-cuda-only", + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + "no_rocm", # ROCmSoftwarePlatform #958 + "noasan", # TODO(b/201651800) + "requires-gpu-nvidia", + ] + tf_cuda_tests_tags(), + deps = [":randomized_tests_library"], +) + +tf_cuda_cc_test( + name = "unary_ops_composition_test", + srcs = ["unary_ops_composition_test.cc"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ] + tf_cuda_tests_tags(), + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/compiler/jit", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:xla_kernel_creator", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "@local_tsl//tsl/platform:status", + ], +) + +py_strict_library( + name = "lstm", + testonly = 1, + srcs = ["lstm.py"], + srcs_version = "PY3", + deps = [ + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:random_ops", + "//tensorflow/python/ops:variable_v1", + "@six_archive//:six", + ], +) + +cuda_py_strict_test( + name = "lstm_test", + srcs = ["lstm_test.py"], + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + ":lstm", + ":xla_test", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/ops:array_ops", + "//tensorflow/python/ops:gradients_impl", + "//tensorflow/python/ops:init_ops", + "//tensorflow/python/ops:math_ops", + "//tensorflow/python/ops:variables", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +# An example of ahead-of-time compilation using tfcompile. The +# lstm_layer_inference.pbtxt file was generated by running lstm_test +# --dump_graph_dir, and the config file was written by hand. +# +# Run the following to build a minimal benchmark of the computation on Android: +# $ bazel build -c opt --cxxopt='-std=c++11' --linkopt='-lm' \ +# --cpu=armeabi-v7a \ +# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ +# --crosstool_top=//external:android/crosstool \ +# //tensorflow/compiler/tests:lstm_layer_inference_benchmark + +# +# Currently the resulting binary size is ~190KB +tf_library( + name = "lstm_layer_inference", + testonly = 1, + config = "lstm_layer_inference.config.pbtxt", + cpp_class = "LSTMLayerInference", + graph = "lstm_layer_inference.pbtxt", + tags = ["manual"], + tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], +) + +cuda_py_strict_test( + name = "const_test", + size = "small", + srcs = ["const_test.py"], + python_version = "PY3", + xla_enable_strict_auto_jit = False, + xla_enabled = True, + deps = [ + "//tensorflow/python/eager:def_function", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/framework:test_lib", + "//tensorflow/python/platform:client_testlib", + "//third_party/py/numpy", + ], +) + +tpu_py_strict_test( + name = "giant_const_op_test", + srcs = [ + "giant_const_op_test.py", + ], + disable_experimental = True, + # TODO(b/188995810): Add an optimization in MLIR importer to not + # materialize giant splat constants. + python_version = "PY3", + tags = ["no_oss"], + deps = [ + "//tensorflow/python/distribute:tpu_strategy", + "//tensorflow/python/distribute/cluster_resolver:tpu_cluster_resolver_py", + "//tensorflow/python/eager:def_function", + "//tensorflow/python/eager:remote", + "//tensorflow/python/eager:test", + "//tensorflow/python/framework:config", + "//tensorflow/python/framework:constant_op", + "//tensorflow/python/framework:dtypes", + "//tensorflow/python/platform:flags", + "//third_party/py/numpy", + ], +) diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index b54c2e54fa3552..12fa6dd7d04bd2 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -392,9 +392,9 @@ def testNumericOps(self): for dtype in self.numeric_types: self._testBinary( math_ops.subtract, - np.array([1, 2, 100], dtype=dtype), - np.array([10, 20, -1], dtype=dtype), - expected=np.array([-9, -18, 101], dtype=dtype)) + np.array([1, 20, 100], dtype=dtype), + np.array([1, 2, 1], dtype=dtype), + expected=np.array([0, 18, 99], dtype=dtype)) self._testBinary( math_ops.subtract, dtype(5), @@ -402,9 +402,9 @@ def testNumericOps(self): expected=np.array([4, 3], dtype=dtype)) self._testBinary( math_ops.subtract, - np.array([[1], [2]], dtype=dtype), + np.array([[7], [10]], dtype=dtype), dtype(7), - expected=np.array([[-6], [-5]], dtype=dtype)) + expected=np.array([[0], [3]], dtype=dtype)) # min/max not supported for complex if dtype not in self.complex_types | {np.uint8, np.int8}: @@ -461,13 +461,13 @@ def testNumericOps(self): self._testBinary( nn_ops.bias_add, np.array([[1, 2], [3, 4]], dtype=dtype), - np.array([2, -1], dtype=dtype), - expected=np.array([[3, 1], [5, 3]], dtype=dtype)) + np.array([2, 0], dtype=dtype), + expected=np.array([[3, 2], [5, 4]], dtype=dtype)) self._testBinary( nn_ops.bias_add, np.array([[[[1, 2], [3, 4]]]], dtype=dtype), - np.array([2, -1], dtype=dtype), - expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + np.array([2, 0], dtype=dtype), + expected=np.array([[[[3, 2], [5, 4]]]], dtype=dtype)) if np.int64 in self.numeric_types: self._testBinary( @@ -998,8 +998,8 @@ def testFill(self): self._testBinary( array_ops.fill, np.array([], dtype=np.int32), - dtype(-42), - expected=dtype(-42)) + dtype(42), + expected=dtype(42)) self._testBinary( array_ops.fill, np.array([1, 2], dtype=np.int32), diff --git a/tensorflow/compiler/tests/build_combined_defs.bzl b/tensorflow/compiler/tests/build_combined_defs.bzl new file mode 100644 index 00000000000000..92f04ab6215c91 --- /dev/null +++ b/tensorflow/compiler/tests/build_combined_defs.bzl @@ -0,0 +1,63 @@ +"""Build rule for combining Tensorflow/XLA tests.""" + +load("//tensorflow:strict.default.bzl", "py_strict_test") +load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") + +def parse_label_name(label): + """Parse a label into just the name. + + Args: + label: string in relative or absolute form. + + Returns: + The name of the label. + """ + colon_split = label.split(":") + if len(colon_split) == 1: # no ":" in label + return label + return colon_split[-1] + +def tf_xla_combined_py_test(name = "", package = None, tests = [], **kwargs): + """Generates combined tf_xla_py_test targets, one per XLA backend. + + All srcs found in the list tests are combined into one new test which is then passed on to + tf_xla_py_test which creates a new target per XLA backend. + + Args: + name: Name of the target. + package: The package that all tests in tests belong to. + tests: The test targets to be combined and tested. Assumes all tests are in the same package. + **kwargs: keyword arguments passed onto the tf_xla_py_test rule. + """ + + test_file = name + ".py" + + # run the generator to create the combined test file containing all the tests in test_files + # redirecting the output of the generator to test_file. + native.genrule( + name = name + "_gen", + testonly = 1, + srcs = tests, + outs = [test_file], + cmd = """ +mkdir -p $(@D) && cat > $@ << EOF +from tensorflow.python.platform import test +%s + +if __name__ == "__main__": + test.main() +EOF + """ % "\n".join(["from %s.%s import *" % (package, parse_label_name(test)[:-4]) for test in tests]), + tools = [], + tags = ["generated_python_test=%s.%s" % (package, name)], + ) + + tf_xla_py_test( + name = name, + test_rule = py_strict_test, + srcs = [test_file], + deps = [ + "//tensorflow/python/platform:client_testlib", + ] + tests, + **kwargs + ) diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index ce6b626683e281..fb5cb0448e8224 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -1,5 +1,6 @@ """Build rules for Tensorflow/XLA testing.""" +load("//tensorflow:py.default.bzl", "py_library") load("//tensorflow:strict.default.bzl", "py_strict_test") load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow/compiler/tests:plugin.bzl", "plugins") @@ -73,6 +74,12 @@ def tf_xla_py_test( cpu_xla_device = "CPU" gpu_xla_device = "GPU" + py_library( + name = name + "_lib", + srcs = srcs, + deps = deps, + testonly = 1, + ) for backend in backends: test_name = "{}_{}".format(name, backend) backend_tags = ["tf_xla_{}".format(backend)] @@ -139,7 +146,7 @@ def tf_xla_py_test( args = backend_args, main = "{}.py".format(name) if main == None else main, data = data + backend_data, - deps = deps + backend_deps + extra_dep, + deps = deps + backend_deps + extra_dep + [name + "_lib"], tags = test_tags + extra_tag, exec_properties = tf_exec_properties({"tags": test_tags}), **kwargs diff --git a/tensorflow/compiler/tests/const_test.py b/tensorflow/compiler/tests/const_test.py index 4e11a436e850af..bb1f3e23a7306e 100644 --- a/tensorflow/compiler/tests/const_test.py +++ b/tensorflow/compiler/tests/const_test.py @@ -33,15 +33,33 @@ class ConstOpTest(test_util.TensorFlowTestCase): # @test_util.run_v2_only def testConst(self): types = { - dtypes.bool, dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, - dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64, - dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64, - dtypes.float8_e5m2, dtypes.float8_e4m3fn, + dtypes.bool, + dtypes.int8, + dtypes.int16, + dtypes.int32, + dtypes.int64, + dtypes.uint8, + dtypes.uint16, + dtypes.uint32, + dtypes.uint64, + dtypes.float16, + dtypes.bfloat16, + dtypes.float32, + dtypes.float64, + dtypes.float8_e5m2, + dtypes.float8_e4m3fn, } for dtype in types: with self.subTest(dtype=dtype): if dtype == dtypes.bool: values = [True, False] + elif dtype in [ + dtypes.uint8, + dtypes.uint16, + dtypes.uint32, + dtypes.uint64, + ]: + values = [0., 1., dtype.min, dtype.max] else: values = [0., 1., -1., dtype.min, dtype.max] if dtype.is_floating: diff --git a/tensorflow/compiler/tests/dynamic_slice_ops_test.py b/tensorflow/compiler/tests/dynamic_slice_ops_test.py index 7abf9a0bba1122..9f4221cfdebe11 100644 --- a/tensorflow/compiler/tests/dynamic_slice_ops_test.py +++ b/tensorflow/compiler/tests/dynamic_slice_ops_test.py @@ -50,10 +50,10 @@ def testUpdateSlice(self): self._assertOpOutputMatchesExpected( xla.dynamic_update_slice, [ np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), - np.array([-1, -2, -3], dtype=dtype), + np.array([11, 12, 13], dtype=dtype), np.array([6], dtype=np.int32) ], - expected=np.array([1, 2, 3, 4, 5, 6, -1, -2, -3, 10], dtype=dtype)) + expected=np.array([1, 2, 3, 4, 5, 6, 11, 12, 13, 10], dtype=dtype)) self._assertOpOutputMatchesExpected( xla.dynamic_update_slice, [ diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index bb7a2a73ca7e9a..fe42e3f3807d0a 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -68,6 +68,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/jit/defs.h" #include "tensorflow/compiler/jit/flags.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/device.h" @@ -97,7 +98,6 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 930fd21ab42c27..ded287593029ff 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -236,25 +236,25 @@ def testUnsortedSegmentSum2DIndices3DData(self): for dtype in self.numeric_types: data = np.array( [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ - 200, 201, 202 - ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], + 80, 81, 82 + ], [123, 124, 125]], [[103, 104, 105], [106, 107, 108]]], dtype=dtype) indices = np.array([[3, 5], [3, 1], [5, 0], [6, 2]], dtype=np.int32) num_segments = 8 y = self._unsortedSegmentSum(data, indices, num_segments) self.assertAllClose( np.array( - [[210, 211, 212], [110, 111, 112], [310, 311, 312], [ + [[123, 124, 125], [110, 111, 112], [106, 107, 108], [ 100, 102, 104 - ], [0, 0, 0.], [210, 212, 214], [300, 301, 302], [0, 0, 0]], + ], [0, 0, 0.], [90, 92, 94], [103, 104, 105], [0, 0, 0]], dtype=dtype), y) def testUnsortedSegmentSum1DIndices3DData(self): for dtype in self.numeric_types: data = np.array( [[[0, 1, 2], [10, 11, 12]], [[100, 101, 102], [110, 111, 112]], [[ - 200, 201, 202 - ], [210, 211, 212]], [[300, 301, 302], [310, 311, 312]]], + 120, 121, 122 + ], [123, 124, 125]], [[103, 104, 105], [106, 107, 108]]], dtype=dtype) indices = np.array([3, 0, 2, 5], dtype=np.int32) num_segments = 6 @@ -262,8 +262,8 @@ def testUnsortedSegmentSum1DIndices3DData(self): self.assertAllClose( np.array( [[[100, 101, 102.], [110, 111, 112]], [[0, 0, 0], [0, 0, 0]], - [[200, 201, 202], [210, 211, 212]], [[0, 1, 2.], [10, 11, 12]], - [[0, 0, 0], [0, 0, 0]], [[300, 301, 302], [310, 311, 312]]], + [[120, 121, 122], [123, 124, 125]], [[0, 1, 2.], [10, 11, 12]], + [[0, 0, 0], [0, 0, 0]], [[103, 104, 105], [106, 107, 108]]], dtype=dtype), y) def testUnsortedSegmentSumShapeError(self): diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index e4937d223165da..809db242ac4afe 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -185,21 +185,28 @@ def testSlice(self): np.array([[], [], []], dtype=dtype), np.array([1, 0], dtype=np.int32), np.array([2, 0], dtype=np.int32), - expected=np.array([[], []], dtype=dtype)) + expected=np.array([[], []], dtype=dtype), + ) self._testTernary( array_ops.slice, np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), np.array([0, 1], dtype=np.int32), np.array([2, 1], dtype=np.int32), - expected=np.array([[2], [5]], dtype=dtype)) + expected=np.array([[2], [5]], dtype=dtype), + ) def testClipByValue(self): - for dtype in self.numeric_types - self.complex_types: + for dtype in ( + self.numeric_types - self.complex_types - self.unsigned_int_types + ): test_cases = [ (np.array([2, 4, 5], dtype=dtype), dtype(7)), # (dtype(1), np.array([2, 4, 5], dtype=dtype)), # - (np.array([-2, 7, 7], dtype=dtype), np.array([-2, 9, 8], dtype=dtype)) + ( + np.array([-2, 7, 7], dtype=dtype), + np.array([-2, 9, 8], dtype=dtype), + ), ] x = np.array([-2, 10, 6], dtype=dtype) for lower, upper in test_cases: diff --git a/tensorflow/compiler/tests/unary_ops_composition_test.cc b/tensorflow/compiler/tests/unary_ops_composition_test.cc index 40a5300a8996e6..58878b8e0df78a 100644 --- a/tensorflow/compiler/tests/unary_ops_composition_test.cc +++ b/tensorflow/compiler/tests/unary_ops_composition_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/device_factory.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/util/port.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ee0967d2150e3d..99b997561b41c3 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -773,7 +773,7 @@ def testComplexOps(self): expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype])) def testIntOps(self): - for dtype in self.int_types: + for dtype in self.int_types - self.unsigned_int_types: self._assertOpOutputMatchesExpected( bitwise_ops.invert, np.array([0, -1, 1, 16, 42], dtype=dtype), @@ -923,7 +923,10 @@ def _testCast(self, src_type, dst_type): if src_type.is_integer: imin = np.iinfo(src_np_dtype).min imax = np.iinfo(src_np_dtype).max - src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype) + if src_type.is_unsigned: + src = np.array([imin, imax, 0, 1], dtype=src_np_dtype) + else: + src = np.array([imin, imax, 0, 1, -1], dtype=src_np_dtype) elif src_type in self.float_tf_types: if dst_type.is_integer: imin = np.iinfo(dst_np_dtype).min @@ -936,63 +939,75 @@ def _testCast(self, src_type, dst_type): eps = np.finfo(dst_np_dtype).eps src = np.array( [fmin, fmax, np.nan, eps, -eps, tiny, -tiny, np.inf, -np.inf], - dtype=src_np_dtype) + dtype=src_np_dtype, + ) dst = src.astype(dst_np_dtype) self._assertOpOutputMatchesExpected( lambda x, dst_type=dst_type: math_ops.cast(x, dst_type), src, - expected=dst) + expected=dst, + ) def testBitcast(self): self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int32), - np.array([1, 0x3f800000], np.int32), - expected=np.array([1, 0x3f800000], np.int32)) + np.array([1, 0x3F800000], np.int32), + expected=np.array([1, 0x3F800000], np.int32), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.float32), - np.array([1, 0x3f800000], np.int32), - expected=np.array([1e-45, 1.0], np.float32)) + np.array([1, 0x3F800000], np.int32), + expected=np.array([1e-45, 1.0], np.float32), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int32), np.array([1e-45, 1.0], np.float32), - expected=np.array([1, 0x3f800000], np.int32)) + expected=np.array([1, 0x3F800000], np.int32), + ) if np.int64 in self.numeric_types: self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.int64), - np.array([1, 0x100000003f800000], np.uint64), - expected=np.array([1, 0x100000003f800000], np.int64)) + np.array([1, 0x100000003F800000], np.uint64), + expected=np.array([1, 0x100000003F800000], np.int64), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.uint64), - np.array([1, 0x100000003f800000], np.int64), - expected=np.array([1, 0x100000003f800000], np.uint64)) + np.array([1, 0x100000003F800000], np.int64), + expected=np.array([1, 0x100000003F800000], np.uint64), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.float64), np.array( - [0, 0x3FF0000000000000, 0xc3af161421c8e000, 0x4032000000000007], + [0, 0x3FF0000000000000, 0xC3AF161421C8E000, 0x4032000000000007], np.uint64, ), expected=np.array( [0, 1.0, -1.12e+18, 18.000000000000024869], np.float64 ), - atol=0 + atol=0, ) def testBitcastInt8ToFloat(self): self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.float32), - np.array([[1, 0, 0, 0], [0xd0, 0x0f, 0x49, 0x40]], np.int8), - expected=np.array([1e-45, 3.14159], np.float32)) + np.array([[1, 0, 0, 0], [0xD0, 0x0F, 0x49, 0x40]]).astype(np.int8), + expected=np.array([1e-45, 3.14159], np.float32), + ) self._assertOpOutputMatchesExpected( lambda x: array_ops.bitcast(x, dtypes.np.int8), np.array([1e-45, 3.14159], np.float32), - expected=np.array([[1, 0, 0, 0], [0xd0, 0x0f, 0x49, 0x40]], np.int8)) + expected=np.array([[1, 0, 0, 0], [0xD0, 0x0F, 0x49, 0x40]]).astype( + np.int8 + ), + ) def testInvertPermutation(self): for np_dtype in [np.int32, np.int64]: self._assertOpOutputMatchesExpected( array_ops.invert_permutation, np.array([1, 2, 0], np_dtype), - expected=np.array([2, 0, 1], dtype=np_dtype)) + expected=np.array([2, 0, 1], dtype=np_dtype), + ) def testInvertPermutationTwiceIsNoop(self): @@ -1013,12 +1028,12 @@ def testRank(self): self._assertOpOutputMatchesExpected( rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( - rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1)) + rank_op, np.array([0, 1], dtype=dtype), expected=np.int32(1)) self._assertOpOutputMatchesExpected( - rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) + rank_op, np.array([[0, 1]], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( rank_op, - np.array([[-1], [1], [4]], dtype=dtype), + np.array([[0], [1], [4]], dtype=dtype), expected=np.int32(2)) def testShape(self): @@ -1032,15 +1047,15 @@ def testShape(self): expected=np.array([2, 0], dtype=np.int32)) self._assertOpOutputMatchesExpected( shape_op, - np.array([-1, 1], dtype=dtype), + np.array([0, 1], dtype=dtype), expected=np.array([2], dtype=np.int32)) self._assertOpOutputMatchesExpected( shape_op, - np.array([[-1, 1]], dtype=dtype), + np.array([[0, 1]], dtype=dtype), expected=np.array([1, 2], dtype=np.int32)) self._assertOpOutputMatchesExpected( shape_op, - np.array([[-1], [1], [4]], dtype=dtype), + np.array([[0], [1], [4]], dtype=dtype), expected=np.array([3, 1], dtype=np.int32)) def testSize(self): @@ -1051,12 +1066,12 @@ def testSize(self): self._assertOpOutputMatchesExpected( size_op, np.array([[], []], dtype=dtype), expected=np.int32(0)) self._assertOpOutputMatchesExpected( - size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2)) + size_op, np.array([0, 1], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( - size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2)) + size_op, np.array([[0, 1]], dtype=dtype), expected=np.int32(2)) self._assertOpOutputMatchesExpected( size_op, - np.array([[-1], [1], [4]], dtype=dtype), + np.array([[0], [1], [4]], dtype=dtype), expected=np.int32(3)) def testSizeWithInt64OutType(self): @@ -1067,7 +1082,7 @@ def size_op(x): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( size_op, - np.array([[-1], [1], [4]], dtype=dtype), + np.array([[0], [1], [4]], dtype=dtype), expected=np.int64(3)) def testUnpack(self): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index c1ac5b11268e25..75e1eb44941ac2 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -668,6 +668,10 @@ cc_library( hdrs = [ "xla_resource.h", ], + visibility = [ + ":internal", + "//learning/deepmind/tensorflow/tpufunc:__pkg__", + ], deps = [ ":common", ":sharding_util", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 9ba3dedf4a6f54..e8695e29d7bfb9 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -211,7 +211,6 @@ tf_cuda_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", "@local_tsl//tsl/platform:statusor", @@ -223,9 +222,11 @@ tf_cuda_library( "@local_xla//xla/service:custom_call_target_registry", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/stream_executor", + "@local_xla//xla/stream_executor:stream_finder", "@local_xla//xla/stream_executor/gpu:gpu_executor_header", "@local_xla//xla/stream_executor/gpu:gpu_stream_header", "@local_xla//xla/stream_executor/gpu:gpu_types_header", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], alwayslink = 1, ) diff --git a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc index fe5b5d1626944d..78da9df6dd3cf3 100644 --- a/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc +++ b/tensorflow/compiler/tf2xla/kernels/light_outside_compilation.cc @@ -56,6 +56,8 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" +#include "xla/stream_executor/stream_finder.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/common_runtime/process_state.h" @@ -73,7 +75,6 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/statusor.h" @@ -585,11 +586,8 @@ Status CallTfKernel(void* stream_handle, void** buffers, const char* opaque, }(); if (platform_status != nullptr) return *platform_status; - se::StreamExecutorConfig config; - config.gpu_stream = stream_handle; - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - platform->GetExecutor(config)); - se::Stream* stream = executor->FindAllocatedStream(stream_handle); + TF_ASSIGN_OR_RETURN(se::Stream * stream, + stream_executor::FindStream(platform, stream_handle)); if (!stream) { return xla::Internal("Stream not found for %p", stream_handle); } diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 39dd10d914eec9..31a6b689811e74 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -138,6 +138,7 @@ class ReshapeOp : public XlaOpKernel { std::vector dims_are_dynamic; const auto& dims = shape.dims(); dims_are_dynamic.reserve(dims); + output_dim_sizes.reserve(dims); for (int64_t i = 0; i < dims; ++i) { output_dim_sizes.push_back( xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {})); diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 26d3cff1ede132..09d6898f4287e4 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -405,6 +405,7 @@ class ZerosLikeOp : public XlaOpKernel { std::vector dynamic_dims; const xla::Shape& shape = list_shape.tuple_shapes(i); auto sub_element = xla::GetTupleElement(list, i); + dynamic_dims.reserve(shape.dimensions_size()); for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) { dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim)); } diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc index 73c6c34cb877fa..1845b9b5590520 100644 --- a/tensorflow/compiler/tf2xla/kernels/where_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc @@ -275,6 +275,7 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { // and then scatter iotas[out_idxs] into the output. std::vector iotas_to_concat; auto iota_shape = xla::ShapeUtil::MakeShape(S32, input_shape.dimensions()); + iotas_to_concat.reserve(iota_shape.rank()); for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) { iotas_to_concat.push_back( xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1})); diff --git a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc index 56139e2ed9dcba..be344f17ed7941 100644 --- a/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/mlir_xla_op_kernel.cc @@ -31,7 +31,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/xla_builder.h" #include "tensorflow/core/framework/device.h" -#include "tensorflow/core/framework/graph_debug_info.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/resource_base.h" @@ -154,7 +153,6 @@ Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) { core::ScopedUnref unref_ctx(ctx_res); // Compile the graph to HLO. - GraphDebugInfo debug_info; std::vector returns(1); auto build_hlo = [&](bool unconditionally_use_output_shapes) { return BuildHloFromGraph( @@ -162,7 +160,7 @@ Status MlirXlaOpKernel::ConstructXlaOp(XlaOpKernelContext* ctx) { unconditionally_use_output_shapes, mlir::SpanToArrayRef(xla_args), control_rets, device->device_type(), - *ctx->function_library()->GetFunctionLibraryDefinition(), debug_info); + *ctx->function_library()->GetFunctionLibraryDefinition()); }; // Some of the operations that come through here do not know how to set their diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 28d063b721ba57..fa3617a97402a6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -1393,6 +1393,7 @@ void ConvertConstantsToExpressions(xla::XlaBuilder* builder, // TODO(b/265059672): Remove when end-to-end stack trace handling is in place class DummyStackTrace : public AbstractStackTrace { absl::Span ToFrames() const override { return frames_; } + std::vector ToUncachedFrames() const override { return frames_; } StackFrame LastUserFrame() const override { return frames_.back(); } diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 787d67674a2c8e..f46bd78cf6960b 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -331,16 +331,6 @@ TEST(XlaJitCompiledCpuFunction, CanCompileWithAdditionalPlatform) { return nullptr; } - absl::StatusOr GetExecutor( - const se::StreamExecutorConfig& config) override { - return nullptr; - } - - absl::StatusOr> GetUncachedExecutor( - const se::StreamExecutorConfig& config) override { - return std::unique_ptr(nullptr); - } - private: string name_; }; diff --git a/tensorflow/core/common_runtime/BUILD b/tensorflow/core/common_runtime/BUILD index 11712d5266a4e1..5302d2036fa574 100644 --- a/tensorflow/core/common_runtime/BUILD +++ b/tensorflow/core/common_runtime/BUILD @@ -3443,10 +3443,10 @@ tf_cc_test( "//tensorflow/core/framework:function_testlib", "//tensorflow/core/framework:optimized_function_graph_proto_cc", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -3493,9 +3493,9 @@ tf_cc_test( "//tensorflow/core/kernels:function_ops", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -3513,7 +3513,7 @@ tf_cc_test( "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc index 9b7930ad1ef590..473be0c108896d 100644 --- a/tensorflow/core/common_runtime/accumulate_n_optimizer.cc +++ b/tensorflow/core/common_runtime/accumulate_n_optimizer.cc @@ -64,7 +64,7 @@ class AccumulateNV2RemovePass : public GraphOptimizationPass { } // Build up a todo list of ops to replace, *then* modify the graph - gtl::InlinedVector matches; + absl::InlinedVector matches; for (Node* n : g->op_nodes()) { if (n->type_string() == "AccumulateNV2") { matches.push_back(n); diff --git a/tensorflow/core/common_runtime/arg_ret_placement.cc b/tensorflow/core/common_runtime/arg_ret_placement.cc index a995564c8c2964..386c54849a254e 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement.cc +++ b/tensorflow/core/common_runtime/arg_ret_placement.cc @@ -255,7 +255,7 @@ Status SetAllocAttrsForArgs(const gtl::InlinedVector& nodes, /*weak_flag=*/false, nullptr, &alloc_attrs); } -Status WeakSetAllocAttrsForArgs(const gtl::InlinedVector& nodes, +Status WeakSetAllocAttrsForArgs(const absl::InlinedVector& nodes, const DataTypeVector& dtypes, std::vector& alloc_attrs) { return SetMemoryTypeHelper(nodes, dtypes, /*is_arg=*/true, diff --git a/tensorflow/core/common_runtime/arg_ret_placement.h b/tensorflow/core/common_runtime/arg_ret_placement.h index 4f00d18e3bb6ca..fd8a4858b83c8e 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement.h +++ b/tensorflow/core/common_runtime/arg_ret_placement.h @@ -93,7 +93,7 @@ Status SetAllocAttrsForRets(const gtl::InlinedVector& nodes, // ops) based on dtype. Logging of warnings if an int32 ret does not have // expected full_type information (i.e. if the source of the input to the ret // does not have expected full type information) can be enabled. -Status WeakSetAllocAttrsForRets(const gtl::InlinedVector& nodes, +Status WeakSetAllocAttrsForRets(const absl::InlinedVector& nodes, const DataTypeVector& dtypes, std::vector& alloc_attrs); diff --git a/tensorflow/core/common_runtime/arg_ret_placement_test.cc b/tensorflow/core/common_runtime/arg_ret_placement_test.cc index 8aea65709154ed..284702a4ecc3e2 100644 --- a/tensorflow/core/common_runtime/arg_ret_placement_test.cc +++ b/tensorflow/core/common_runtime/arg_ret_placement_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "tensorflow/cc/framework/scope.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/full_type.pb.h" #include "tensorflow/core/framework/function.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" @@ -205,7 +205,7 @@ TEST_F(FullTypeGraphUtilsTest, MemoryTypeRetWithFT) { } TEST_F(FullTypeGraphUtilsTest, AllowAttrRetWithFT) { - gtl::InlinedVector nodes; + absl::InlinedVector nodes; DataTypeVector dtypes; std::vector alloc_attrs; diff --git a/tensorflow/core/common_runtime/collective_test_util.cc b/tensorflow/core/common_runtime/collective_test_util.cc index 24c85b321ae0d9..18ef2ab824daf1 100644 --- a/tensorflow/core/common_runtime/collective_test_util.cc +++ b/tensorflow/core/common_runtime/collective_test_util.cc @@ -325,10 +325,11 @@ Status RunCollective(CollectiveTestEnv* test_env, CollectiveParams* col_params, op_params.step_id = kStepId; op_params.device = device; op_params.cancellation_manager = &cancellation_manager; - gtl::InlinedVector inputs; + absl::InlinedVector inputs; inputs.push_back(TensorValue(&input_buffer)); op_params.inputs = inputs; - gtl::InlinedVector input_aa({AllocatorAttributes()}); + absl::InlinedVector input_aa( + {AllocatorAttributes()}); op_params.input_alloc_attrs = input_aa; DeviceContext* dev_ctx = nullptr; auto* dev_info = device->tensorflow_accelerator_device_info(); diff --git a/tensorflow/core/common_runtime/collective_util.h b/tensorflow/core/common_runtime/collective_util.h index 01fb8b8c81cd2f..b53e779701afce 100644 --- a/tensorflow/core/common_runtime/collective_util.h +++ b/tensorflow/core/common_runtime/collective_util.h @@ -37,9 +37,9 @@ string SubdivPermDebugString(const CollectiveParams& col_params); class SubContext { public: OpKernelContext::Params sub_params_; - gtl::InlinedVector sub_inputs_; - gtl::InlinedVector sub_input_attr_; - gtl::InlinedVector sub_input_dc_; + absl::InlinedVector sub_inputs_; + absl::InlinedVector sub_input_attr_; + absl::InlinedVector sub_input_dc_; // Used only for Binary and Unary Ops for which we require // the calculation to be in-place on the first input. int forward_from_ = 0; diff --git a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc index a16b90cf0bf87c..6e78f3247cfa57 100644 --- a/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc +++ b/tensorflow/core/common_runtime/colocate_predecessor_trees_pass_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/cc/framework/scope.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/graph_def_builder_util.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/config/flag_defs.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index a4712a5c83a742..675ffc624c68e4 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -206,7 +206,7 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context, const Tensor* input, Tensor* output, int dev_to_dev_stream_index, StatusCallback done, bool sync_dst_compute) { - profiler::ScopedAnnotation annotation( + tsl::profiler::ScopedAnnotation annotation( [&] { return absl::StrCat("#edge_name=", edge_name, "#"); }); VLOG(4) << "Copy " << edge_name; diff --git a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc index d2b51ee6652580..1f8c8cd7c7622e 100644 --- a/tensorflow/core/common_runtime/eager/context_distributed_manager.cc +++ b/tensorflow/core/common_runtime/eager/context_distributed_manager.cc @@ -136,49 +136,6 @@ class XlaKeyValueStore : public xla::KeyValueStoreInterface { std::string key_prefix_; }; -// Remove LocalDeviceState objects from -// info->local_device_states that have unique hardware IDs -// (i.e. ignore duplicate virtual devices) and return them in a map. -static std::map> -GetUniqueDeviceStates(PjRtGpuClientCreationInfo* info) { - // Only consider each hardware device once. In test environments, one - // physical GPU (e.g. hardware_id 0) might be shared as virtual GPUs (e.g. - // local_id 0 and 1) by multiple workers (multiple processes on the same - // computer). If there is a need to not ignore these for an actual case, a - // possible solution is to add a flag to only enable the use of - // hardware_id_to_local_id for tests. - - auto input_states = std::move(info->local_device_states); - - absl::flat_hash_map hardware_id_to_local_id; - for (const auto& id_state : input_states) { - int local_id = id_state.second->local_device_id().value(); - int hardware_id = id_state.second->local_hardware_id().value(); - if (hardware_id_to_local_id.contains(hardware_id)) { - if (hardware_id_to_local_id[hardware_id] > local_id) { - // Use the device with the smallest local_id, ignore others. - hardware_id_to_local_id[hardware_id] = local_id; - } - } else { - hardware_id_to_local_id[hardware_id] = local_id; - } - } - std::map> local_device_states; - for (auto& id_state : input_states) { - int local_id = id_state.second->local_device_id().value(); - int hardware_id = id_state.second->local_hardware_id().value(); - if (hardware_id_to_local_id[hardware_id] != local_id) { - VLOG(1) << "For hardware_id=" << hardware_id - << ", ignoring redundant local_id=" << local_id - << ". local_id=" << hardware_id_to_local_id[hardware_id] - << " will be used instead."; - continue; - } - local_device_states.emplace(id_state.first, std::move(id_state.second)); - } - return local_device_states; -} - // Coordinate creation of a PjRt GPU client with distributed devices when there // are multiple threads (which typically occurs in test environments that use // multiple threads to simulate multiple workers). @@ -319,10 +276,9 @@ absl::Status CreateClientOnce( auto kv_store = std::make_shared(coordination_service_agent); - std::map> - unique_local_device_states; + std::map> local_device_states; if (use_creation_info) { - unique_local_device_states = GetUniqueDeviceStates(info); + local_device_states = std::move(info->local_device_states); } if (use_creation_info) { // Tell any other threads are waiting to call BuildDistributedDevices to @@ -330,7 +286,7 @@ absl::Status CreateClientOnce( creation_state->SetReady(); } auto device_topology_pair = BuildDistributedDevices( - platform_name, std::move(unique_local_device_states), node_id, num_nodes, + platform_name, std::move(local_device_states), node_id, num_nodes, gpu_run_options.get(), kv_store, /*enable_mock_nccl=*/false); if (!device_topology_pair.ok()) { if (use_creation_info) { diff --git a/tensorflow/core/common_runtime/eager/context_test.cc b/tensorflow/core/common_runtime/eager/context_test.cc index 1758e3139376f1..4c50c62c33ab5b 100644 --- a/tensorflow/core/common_runtime/eager/context_test.cc +++ b/tensorflow/core/common_runtime/eager/context_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/composite_device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context_distributed_manager.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/eager/eager_executor_test.cc b/tensorflow/core/common_runtime/eager/eager_executor_test.cc index e933a04324bd9f..1650dbf975866e 100644 --- a/tensorflow/core/common_runtime/eager/eager_executor_test.cc +++ b/tensorflow/core/common_runtime/eager/eager_executor_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/protobuf/error_codes.pb.h" diff --git a/tensorflow/core/common_runtime/eager/placement_utils_test.cc b/tensorflow/core/common_runtime/eager/placement_utils_test.cc index 8803d7471e29e8..6220cc95778d66 100644 --- a/tensorflow/core/common_runtime/eager/placement_utils_test.cc +++ b/tensorflow/core/common_runtime/eager/placement_utils_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/eager_operation.h" #include "tensorflow/core/common_runtime/eager/execute_node.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #define DEVICE_CPU0 "/job:localhost/replica:0/task:0/device:CPU:0" #define DEVICE_CPU0_TASK1 "/job:localhost/replica:0/task:1/device:CPU:0" diff --git a/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc b/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc index 1038ba93faa9c7..efa597feb838ce 100644 --- a/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc +++ b/tensorflow/core/common_runtime/eager/summary_optimizer_test.cc @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/common_runtime/entry.h b/tensorflow/core/common_runtime/entry.h index 9164cce3eae94c..82bf44eae816b9 100644 --- a/tensorflow/core/common_runtime/entry.h +++ b/tensorflow/core/common_runtime/entry.h @@ -134,7 +134,7 @@ struct Entry { }; // TODO(b/152925936): Re-evaluate this constant with current usage patterns. -typedef gtl::InlinedVector EntryVector; +typedef absl::InlinedVector EntryVector; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 48ec47636e30df..2054114de4d86d 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -142,8 +142,8 @@ struct KernelTimer { }; // TODO(b/152925936): Re-evaluate these constants with current usage patterns. -typedef gtl::InlinedVector TensorValueVec; -typedef gtl::InlinedVector AllocatorAttributeVec; +typedef absl::InlinedVector TensorValueVec; +typedef absl::InlinedVector AllocatorAttributeVec; class ExecutorImpl : public Executor { public: diff --git a/tensorflow/core/common_runtime/function_body.cc b/tensorflow/core/common_runtime/function_body.cc index 1ca6f6a535ceb6..60a6f41f1d8162 100644 --- a/tensorflow/core/common_runtime/function_body.cc +++ b/tensorflow/core/common_runtime/function_body.cc @@ -35,7 +35,7 @@ FunctionBody::FunctionBody(core::RefCountPtr&& record, this->arg_nodes.resize(arg_types.size()); this->ret_nodes.resize(ret_types.size()); for (Node* n : this->graph->op_nodes()) { - gtl::InlinedVector* node_vec; + absl::InlinedVector* node_vec; if (n->type_string() == FunctionLibraryDefinition::kRetOp || n->type_string() == FunctionLibraryDefinition::kDeviceRetOp) { node_vec = &this->ret_nodes; diff --git a/tensorflow/core/common_runtime/function_body.h b/tensorflow/core/common_runtime/function_body.h index 97d27f51099e31..959f9803227764 100644 --- a/tensorflow/core/common_runtime/function_body.h +++ b/tensorflow/core/common_runtime/function_body.h @@ -37,11 +37,11 @@ struct FunctionBody { DataTypeVector ret_types; // arg_nodes[i] contains the i'th function input. In other words, // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i. - gtl::InlinedVector arg_nodes; + absl::InlinedVector arg_nodes; // ret_nodes[i] contains the i'th function output. In other words, // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i. - gtl::InlinedVector ret_nodes; - gtl::InlinedVector control_ret_nodes; + absl::InlinedVector ret_nodes; + absl::InlinedVector control_ret_nodes; FunctionBody() {} FunctionBody(core::RefCountPtr&& record, diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 432167b8c23b5d..bb2a3f7f46478d 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -385,6 +385,7 @@ TEST_F(FunctionLibraryRuntimeTest, XTimesTwo) { TEST_F(FunctionLibraryRuntimeTest, InstantiationStackTraceCopying) { class DummyStackTrace : public AbstractStackTrace { absl::Span ToFrames() const override { return {}; } + std::vector ToUncachedFrames() const override { return {}; } std::string ToString(const TracePrintingOptions& opts) const override { return "DummyStackTrace"; diff --git a/tensorflow/core/common_runtime/function_utils.cc b/tensorflow/core/common_runtime/function_utils.cc index facd31481c05ed..56389623808262 100644 --- a/tensorflow/core/common_runtime/function_utils.cc +++ b/tensorflow/core/common_runtime/function_utils.cc @@ -162,7 +162,7 @@ bool RemoveIdentityNodes(Graph* g) { bool RemoveListArrayConverter(Graph* g) { VLOG(2) << "Removing list array converter"; - gtl::InlinedVector matches; + absl::InlinedVector matches; for (Node* n : g->nodes()) { if ((n->type_string() == "_ListToArray") || (n->type_string() == "_ArrayToList")) { diff --git a/tensorflow/core/common_runtime/gpu/BUILD b/tensorflow/core/common_runtime/gpu/BUILD index 100792f4266d91..24c6f24efeb290 100644 --- a/tensorflow/core/common_runtime/gpu/BUILD +++ b/tensorflow/core/common_runtime/gpu/BUILD @@ -1,4 +1,5 @@ load("@bazel_skylib//lib:selects.bzl", "selects") +load("@local_xla//xla/tsl:tsl.bzl", "if_hermetic_cuda_libs") load( "//tensorflow:tensorflow.bzl", "clean_dep", @@ -140,6 +141,19 @@ filegroup( visibility = ["//visibility:private"], ) +cc_library( + name = "gpu_runtime_hermetic_cuda_deps", + visibility = ["//visibility:public"], + deps = if_hermetic_cuda_libs([ + "@local_xla//xla/tsl/cuda:cudart", + "@local_xla//xla/tsl/cuda:cublas", + "@local_xla//xla/tsl/cuda:cufft", + "@local_xla//xla/tsl/cuda:cusolver", + "@local_xla//xla/tsl/cuda:cusparse", + "@local_xla//xla/tsl/cuda:cudnn", + ]), +) + tf_cuda_library( name = "gpu_runtime_impl", srcs = [ @@ -158,6 +172,7 @@ tf_cuda_library( "@local_config_cuda//cuda:cudnn_header", "@local_xla//xla/stream_executor/cuda:cuda_platform", "@local_xla//xla/stream_executor/gpu:gpu_stream", + ":gpu_runtime_hermetic_cuda_deps", ], defines = if_linux_x86_64(["TF_PLATFORM_LINUX_X86_64"]), features = ["-layering_check"], diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc index 1736dd54dd95e0..3b2480784ab187 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc @@ -25,13 +25,13 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/tests/test_macros.h" #include "xla/tsl/framework/device_id.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/gpu/gpu_process_state.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #ifdef TF_GPU_USE_PJRT #include "xla/pjrt/pjrt_client.h" diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc index 74ff893a5bef3d..179a0c89f73d9a 100644 --- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc @@ -28,13 +28,11 @@ namespace { TEST(PoolAllocatorTest, ZeroSizeBuffers) { se::Platform* platform = se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); - PoolAllocator pool( - 2 /*pool_size_limit*/, false /*auto_resize*/, - new DeviceHostAllocator( - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .value(), - 0 /*numa_node*/, {}, {}), - new NoopRounder, "pool"); + PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/, + new DeviceHostAllocator( + platform->ExecutorForDevice(/*ordinal=*/0).value(), + 0 /*numa_node*/, {}, {}), + new NoopRounder, "pool"); EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/)); pool.DeallocateRaw(nullptr); // Should not crash. @@ -47,13 +45,11 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) { TEST(PoolAllocatorTest, ZeroSizePool) { se::Platform* platform = se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); - PoolAllocator pool( - 0 /*pool_size_limit*/, false /*auto_resize*/, - new DeviceHostAllocator( - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .value(), - 0 /*numa_node*/, {}, {}), - new NoopRounder, "pool"); + PoolAllocator pool(0 /*pool_size_limit*/, false /*auto_resize*/, + new DeviceHostAllocator( + platform->ExecutorForDevice(/*ordinal=*/0).value(), + 0 /*numa_node*/, {}, {}), + new NoopRounder, "pool"); EXPECT_EQ(0, pool.get_from_pool_count()); EXPECT_EQ(0, pool.put_count()); @@ -81,13 +77,11 @@ TEST(PoolAllocatorTest, ZeroSizePool) { TEST(PoolAllocatorTest, Alignment) { se::Platform* platform = se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); - PoolAllocator pool( - 0 /*pool_size_limit*/, false /*auto_resize*/, - new DeviceHostAllocator( - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .value(), - 0 /*numa_node*/, {}, {}), - new NoopRounder, "pool"); + PoolAllocator pool(0 /*pool_size_limit*/, false /*auto_resize*/, + new DeviceHostAllocator( + platform->ExecutorForDevice(/*ordinal=*/0).value(), + 0 /*numa_node*/, {}, {}), + new NoopRounder, "pool"); for (int i = 0; i < 16; ++i) { size_t alignment = 1 << i; void* p = pool.AllocateRaw(alignment, 111); @@ -144,8 +138,8 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { se::Platform* platform = se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); DeviceHostAllocator* sub_allocator = new DeviceHostAllocator( - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)).value(), - 0 /*numa_node*/, {alloc_visitor}, {free_visitor}); + platform->ExecutorForDevice(/*ordinal=*/0).value(), 0 /*numa_node*/, + {alloc_visitor}, {free_visitor}); PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/, sub_allocator, new NoopRounder, "pool"); EXPECT_EQ(0, alloc_count); @@ -245,13 +239,11 @@ TEST(PoolAllocatorTest, Pow2Rounder) { TEST(PoolAllocatorTest, Name) { se::Platform* platform = se::PlatformManager::PlatformWithName(se::GpuPlatformName()).value(); - PoolAllocator pool( - 2 /*pool_size_limit*/, false /*auto_resize*/, - new DeviceHostAllocator( - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .value(), - 0 /*numa_node*/, {}, {}), - new NoopRounder, "pool"); + PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/, + new DeviceHostAllocator( + platform->ExecutorForDevice(/*ordinal=*/0).value(), + 0 /*numa_node*/, {}, {}), + new NoopRounder, "pool"); EXPECT_EQ("pool", pool.Name()); } diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h index dcc3bfb4335d95..1c8f6283c57c07 100644 --- a/tensorflow/core/common_runtime/gpu_device_context.h +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -29,14 +29,14 @@ namespace tensorflow { class GPUDeviceContext : public DeviceContext { public: // Does not take ownership of streams. - GPUDeviceContext(int stream_id, se::Stream* stream, + GPUDeviceContext( + int stream_id, se::Stream* stream, #if TENSORFLOW_USE_ROCM - se::Stream* nccl_stream, + se::Stream* nccl_stream, #endif - se::Stream* host_to_device_stream, - se::Stream* device_to_host_stream, - gtl::InlinedVector device_to_device_stream, - Allocator* host_memory_allocator) + se::Stream* host_to_device_stream, se::Stream* device_to_host_stream, + absl::InlinedVector device_to_device_stream, + Allocator* host_memory_allocator) : stream_id_(stream_id), stream_(stream), #if TENSORFLOW_USE_ROCM @@ -96,7 +96,7 @@ class GPUDeviceContext : public DeviceContext { // The stream to use for copying data from GPU to host. se::Stream* device_to_host_stream_; // Streams to use for copying data between GPUs. - gtl::InlinedVector device_to_device_stream_; + absl::InlinedVector device_to_device_stream_; // The allocator to use for allocating pinned host memory. // Not owned. Allocator* host_memory_allocator_; diff --git a/tensorflow/core/common_runtime/gradients.cc b/tensorflow/core/common_runtime/gradients.cc index 7c48847cb22149..b91d6986705fcc 100644 --- a/tensorflow/core/common_runtime/gradients.cc +++ b/tensorflow/core/common_runtime/gradients.cc @@ -345,7 +345,7 @@ Status SymbolicGradientBuilder::Compute() { InitBackprop(); // Backward propagation. - gtl::InlinedVector dy; + absl::InlinedVector dy; while (!ready_.empty()) { // n has collected all gradients. Node* n = ready_.front(); diff --git a/tensorflow/core/common_runtime/graph_constructor.cc b/tensorflow/core/common_runtime/graph_constructor.cc index 66109aee89eaa9..3705ede827e0f6 100644 --- a/tensorflow/core/common_runtime/graph_constructor.cc +++ b/tensorflow/core/common_runtime/graph_constructor.cc @@ -112,7 +112,7 @@ class GraphConstructor { : allow_internal_ops(false), expect_device_spec(false), propagate_device_spec(in.propagate_device_spec), - prefix(in.prefix.empty() || str_util::EndsWith(in.prefix, "/") + prefix(in.prefix.empty() || absl::EndsWith(in.prefix, "/") ? in.prefix : in.prefix + "/"), uniquify_names(in.uniquify_names), @@ -370,7 +370,7 @@ class GraphConstructor { // Mapping between index within node_defs_ and the index within node_defs_ of // all nodes it outputs to. - std::vector> outputs_; + std::vector> outputs_; // Used in the conversion from node_defs_ to g_ to represent the ith input // of a node. diff --git a/tensorflow/core/common_runtime/graph_view.cc b/tensorflow/core/common_runtime/graph_view.cc index 29458524bd5051..4bbd22c89dfe6f 100644 --- a/tensorflow/core/common_runtime/graph_view.cc +++ b/tensorflow/core/common_runtime/graph_view.cc @@ -157,7 +157,7 @@ char* GraphView::InitializeNode(char* ptr, const Node* n) { // a given output slot. For all but the last, we need to do a copy of the // Tensor when propagating results downstream in the graph, but for the // last one, we can just do a move of the Tensor object to propagate it. - gtl::InlinedVector last_indices(num_outputs, nullptr); + absl::InlinedVector last_indices(num_outputs, nullptr); EdgeInfo* dst_edge = item->output_edge_base(); for (auto e : n->out_edges()) { if (e->IsControlEdge()) continue; diff --git a/tensorflow/core/common_runtime/inspecting_placer.cc b/tensorflow/core/common_runtime/inspecting_placer.cc index a84cd700874d8c..8a0eb150dd497d 100644 --- a/tensorflow/core/common_runtime/inspecting_placer.cc +++ b/tensorflow/core/common_runtime/inspecting_placer.cc @@ -77,7 +77,7 @@ class ColocationGraphToIOColocationGroups { ColocationGraph* colocation_graph) : colocation_graph_(colocation_graph), next_group_id_(0) {} - void AssignGroups(const gtl::InlinedVector& nodes, + void AssignGroups(const absl::InlinedVector& nodes, std::vector* groups) { for (int i = 0; i < nodes.size(); ++i) { int root_id = colocation_graph_->FindAndUpdateRoot(nodes[i]->id()); diff --git a/tensorflow/core/common_runtime/int32_fulltype_test.cc b/tensorflow/core/common_runtime/int32_fulltype_test.cc index e6ead597aea23e..5d2c0e0b9bdb46 100644 --- a/tensorflow/core/common_runtime/int32_fulltype_test.cc +++ b/tensorflow/core/common_runtime/int32_fulltype_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/graph_def_builder_util.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/framework/full_type.pb.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc index b20d0057c727d0..63fd2f1b59c223 100644 --- a/tensorflow/core/common_runtime/local_device.cc +++ b/tensorflow/core/common_runtime/local_device.cc @@ -126,7 +126,7 @@ LocalDevice::LocalDevice(const SessionOptions& options, // computations. static mutex& global_tp_mu = *new mutex; static auto& global_tp_info TF_GUARDED_BY(global_tp_mu) = - *new gtl::InlinedVector; + *new absl::InlinedVector; mutex_lock l(global_tp_mu); if (options.config.experimental().use_numa_affinity()) { diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc index b57145c73167ff..31c1e40b431e7b 100644 --- a/tensorflow/core/common_runtime/lower_while_op_test.cc +++ b/tensorflow/core/common_runtime/lower_while_op_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/cc/ops/control_flow_ops_internal.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/common_runtime/graph_runner.h" #include "tensorflow/core/common_runtime/lower_functional_ops.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/common_runtime/next_pluggable_device/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/BUILD index 74f843c189bb38..4a36e8f67e0a68 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/BUILD @@ -329,7 +329,6 @@ tf_cc_test( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/time", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:errors", @@ -341,6 +340,7 @@ tf_cc_test( "@local_xla//xla/tsl/distributed_runtime:call_options", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_client", "@local_xla//xla/tsl/distributed_runtime/coordination:coordination_service_agent", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD index 7feb974ed430b7..7862391ec43c6a 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/BUILD @@ -169,8 +169,8 @@ tf_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc index b5a37df930e86e..02ea581b909bfd 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/synchronization/notification.h" #include "xla/tsl/framework/allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_helper.h" #include "tensorflow/core/common_runtime/next_pluggable_device/c/tf_rendezvous_c_api_internal.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc index f61be2ef6a6a39..5d62f8c58668c6 100644 --- a/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc +++ b/tensorflow/core/common_runtime/next_pluggable_device/c_plugin_coordination_service_agent_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/status.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" #include "tsl/protobuf/coordination_config.pb.h" diff --git a/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc b/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc index b2cf3d13e78766..52925d9d539a26 100644 --- a/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc +++ b/tensorflow/core/common_runtime/optimize_function_graph_utils_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session_options.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" diff --git a/tensorflow/core/common_runtime/optimized_function_graph_info.h b/tensorflow/core/common_runtime/optimized_function_graph_info.h index b2bd9af5bb1c5a..b15790dbeede36 100644 --- a/tensorflow/core/common_runtime/optimized_function_graph_info.h +++ b/tensorflow/core/common_runtime/optimized_function_graph_info.h @@ -71,9 +71,10 @@ struct OptimizedFunctionGraphInfo { OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo& info) = delete; OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo& info) = delete; - OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) = default; + OptimizedFunctionGraphInfo(OptimizedFunctionGraphInfo&& info) = + default; // NOLINT OptimizedFunctionGraphInfo& operator=(OptimizedFunctionGraphInfo&& info) = - default; + default; // NOLINT // Converts from the struct to OptimizedFunctionGraph proto. static OptimizedFunctionGraph ToProto(const OptimizedFunctionGraphInfo& info); diff --git a/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc b/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc index 800da5f50d6297..cab15e62819083 100644 --- a/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc +++ b/tensorflow/core/common_runtime/optimized_function_graph_info_test.cc @@ -23,11 +23,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/substitute.h" #include "third_party/protobuf/text_format.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/optimized_function_graph.pb.h" #include "tensorflow/core/graph/node_builder.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc index bb6cc17b5b1665..8829f6d9d270c0 100644 --- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc +++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc @@ -38,7 +38,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass { "Parallel concat removal should happen before partitioning and a " "graph should be available."); } - gtl::InlinedVector matches; + absl::InlinedVector matches; for (Node* n : g->op_nodes()) { if (n->type_string() == "ParallelConcat") { matches.push_back(n); diff --git a/tensorflow/core/common_runtime/replicate_constants_pass_test.cc b/tensorflow/core/common_runtime/replicate_constants_pass_test.cc index 346ba173886846..bf335df3bf54b8 100644 --- a/tensorflow/core/common_runtime/replicate_constants_pass_test.cc +++ b/tensorflow/core/common_runtime/replicate_constants_pass_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/math_ops.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/config/flag_defs.h" #include "tensorflow/core/config/flags.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index b2b22f29e9f5f0..06c353eb3669de 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -742,19 +742,29 @@ Status ShapeRefiner::RunShapeFn(const Node* node, // performing inference on the function body. auto const_tensor_map_copy = const_tensor_map_; const_tensor_map_.clear(); + VLOG(4) << "Running shape inference for function \"" + << function.name() << "\"."; Status function_inference_status = InferShapesForFunction( function_def, AttrSlice(&function.attr()), c); const_tensor_map_ = const_tensor_map_copy; + VLOG(4) << "Shape inference for function \"" << function.name() + << "\" returned status " << function_inference_status << "."; return function_inference_status; } } } if (op_reg_data->shape_inference_fn) { + VLOG(4) << "Running shape inference function for node \"" << node->name() + << "\" of type \"" << node->type_string() << "\"."; TF_RETURN_IF_ERROR(c->Run(op_reg_data->shape_inference_fn)); } else { + VLOG(4) << "Unknown shape inference function for node \"" << node->name() + << "\" of type \"" << node->type_string() << "\"."; TF_RETURN_IF_ERROR(c->Run(shape_inference::UnknownShape)); } + VLOG(4) << "Shape inference passed for node \"" << node->name() + << "\" of type \"" << node->type_string() << "\"."; return absl::OkStatus(); }; TF_RETURN_IF_ERROR(run_inference_lambda()); diff --git a/tensorflow/core/config/flag_defs.h b/tensorflow/core/config/flag_defs.h index a773fbb1b20c1c..23e9989a31edb7 100644 --- a/tensorflow/core/config/flag_defs.h +++ b/tensorflow/core/config/flag_defs.h @@ -64,6 +64,9 @@ class Flags { // TODO(b/341325107): Make this behavior the default and remove the flag. TF_DECLARE_FLAG(enable_function_pruning_before_inlining, false, "If true, functions will be pruned before inlining.") + TF_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs, false, + "If true, TF2XLA encapsulation will be skipped for non-TPU " + "graphs.") // LINT.ThenChange(//tensorflow/core/config/flags_api_wrapper.cc) }; diff --git a/tensorflow/core/config/flags_api_wrapper.cc b/tensorflow/core/config/flags_api_wrapper.cc index 096d48c5dc1720..060ede3846df23 100644 --- a/tensorflow/core/config/flags_api_wrapper.cc +++ b/tensorflow/core/config/flags_api_wrapper.cc @@ -55,5 +55,6 @@ PYBIND11_MODULE(flags_pybind, m) { TF_PY_DECLARE_FLAG(enable_colocation_key_propagation_in_while_op_lowering); TF_PY_DECLARE_FLAG(enable_tf2min_ici_weight) TF_PY_DECLARE_FLAG(enable_function_pruning_before_inlining) + TF_PY_DECLARE_FLAG(enable_skip_encapsulation_for_non_tpu_graphs) // LINT.ThenChange(//tensorflow/core/config/flag_defs.h) }; diff --git a/tensorflow/core/data/BUILD b/tensorflow/core/data/BUILD index 1ec7a6f8f819f9..748dfc17ce213e 100644 --- a/tensorflow/core/data/BUILD +++ b/tensorflow/core/data/BUILD @@ -218,7 +218,9 @@ cc_library( "//tensorflow/core:framework", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", diff --git a/tensorflow/core/data/captured_function.cc b/tensorflow/core/data/captured_function.cc index 6ddf987b3ef95e..2206cb09d7e9a5 100644 --- a/tensorflow/core/data/captured_function.cc +++ b/tensorflow/core/data/captured_function.cc @@ -943,8 +943,10 @@ Status InstantiatedCapturedFunction::RunInstantiated( } void InstantiatedCapturedFunction::RunAsync( - IteratorContext* ctx, std::vector&& args, std::vector* rets, - FunctionLibraryRuntime::DoneCallback done, + std::function)> runner, + CancellationManager* parent_cancellation_manager, + CollectiveExecutor* collective_executor, std::vector&& args, + std::vector* rets, FunctionLibraryRuntime::DoneCallback done, const std::shared_ptr& node) const { auto& info = captured_func_->short_circuit_info(); if (!info.indices.empty()) { @@ -952,7 +954,7 @@ void InstantiatedCapturedFunction::RunAsync( // potentially do a non-trivial amount of (e.g. copying) work, and we may // want to run that concurrently with the next invocation. Status s = RunShortCircuit(info, std::move(args), captured_func_, rets); - (*ctx->runner())( + runner( std::bind([s](FunctionLibraryRuntime::DoneCallback& done) { done(s); }, std::move(done))); return; @@ -971,18 +973,18 @@ void InstantiatedCapturedFunction::RunAsync( resource_mgr->Cleanup(name).IgnoreError(); }); f_opts.step_container = step_container; - f_opts.runner = ctx->runner(); + f_opts.runner = &runner; f_opts.create_rendezvous = ShouldCreateRendezvous(); auto cancellation_manager = - std::make_unique(ctx->cancellation_manager()); + std::make_unique(parent_cancellation_manager); f_opts.cancellation_manager = cancellation_manager.get(); - f_opts.collective_executor = ctx->collective_executor(); + f_opts.collective_executor = collective_executor; std::shared_ptr stats_collector; - if (node || ctx->stats_aggregator()) { + if (node) { stats_collector = std::make_shared(); } - const bool collect_usage = node && ctx->model(); + const bool collect_usage = node != nullptr; f_opts.stats_collector = stats_collector.get(); // Transfer ownership of the cancellation manager to `callback`. @@ -992,7 +994,6 @@ void InstantiatedCapturedFunction::RunAsync( [this, rets, step_container, raw_cancellation_manager, frame, node, collect_usage]( const FunctionLibraryRuntime::DoneCallback& done, - IteratorContext* ctx, const std::shared_ptr& stats_collector, // Begin unbound arguments. Status s) { @@ -1003,18 +1004,6 @@ void InstantiatedCapturedFunction::RunAsync( } delete frame; if (node) { - // TODO(b/129085499) Utilize the `node_name` which would be unique - // than the prefix for the function execution time statistics. - // prefix_with_func_name would then be node_name + func_name. - if (ctx->stats_aggregator()) { - string prefix_with_func_name = - strings::StrCat(node->name(), stats_utils::kDelimiter, - captured_func_->func().name()); - ctx->stats_aggregator()->AddToHistogram( - stats_utils::ExecutionTimeHistogramName(prefix_with_func_name), - {static_cast(stats_collector->processing_time())}, - node->num_elements()); - } node->add_processing_time(stats_collector->processing_time()); } if (collect_usage) { @@ -1025,7 +1014,7 @@ void InstantiatedCapturedFunction::RunAsync( node->record_stop(EnvTime::NowNanos()); } }, - std::move(done), ctx, std::move(stats_collector), std::placeholders::_1); + std::move(done), std::move(stats_collector), std::placeholders::_1); tsl::profiler::TraceMe activity( [&] { diff --git a/tensorflow/core/data/captured_function.h b/tensorflow/core/data/captured_function.h index e415c546f970ae..854d9fc22cad53 100644 --- a/tensorflow/core/data/captured_function.h +++ b/tensorflow/core/data/captured_function.h @@ -288,6 +288,18 @@ class InstantiatedCapturedFunction { void RunAsync(IteratorContext* ctx, std::vector&& args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done, + const std::shared_ptr& node) const { + RunAsync(*(ctx->runner()), ctx->cancellation_manager(), + ctx->collective_executor(), std::move(args), rets, done, node); + } + + // A version of `RunAsync` that does not take an `IteratorContext` but a + // runner, a cancellation manager, and a collective executor. + void RunAsync(std::function)> runner, + CancellationManager* parent_cancellation_manager, + CollectiveExecutor* collective_executor, + std::vector&& args, std::vector* rets, + FunctionLibraryRuntime::DoneCallback done, const std::shared_ptr& node) const; std::string func_name() const { return captured_func_->func().name(); } diff --git a/tensorflow/core/data/dataset_test_base.cc b/tensorflow/core/data/dataset_test_base.cc index 7e295e367285a3..e770b4fa9df02d 100644 --- a/tensorflow/core/data/dataset_test_base.cc +++ b/tensorflow/core/data/dataset_test_base.cc @@ -348,7 +348,7 @@ Status DatasetOpsTestBase::CreateOpKernel( Status DatasetOpsTestBase::CreateDatasetContext( OpKernel* const dateset_kernel, - gtl::InlinedVector* const inputs, + absl::InlinedVector* const inputs, std::unique_ptr* dataset_context_params, std::unique_ptr* dataset_context) { Status status = CheckOpKernelInput(*dateset_kernel, *inputs); @@ -515,13 +515,13 @@ Status DatasetOpsTestBase::RunFunction( } Status DatasetOpsTestBase::CreateOpKernelContext( - OpKernel* kernel, gtl::InlinedVector* inputs, + OpKernel* kernel, absl::InlinedVector* inputs, std::unique_ptr* context) { return CreateOpKernelContext(kernel, inputs, ¶ms_, context); } Status DatasetOpsTestBase::CreateOpKernelContext( - OpKernel* kernel, gtl::InlinedVector* inputs, + OpKernel* kernel, absl::InlinedVector* inputs, std::unique_ptr* context_params, std::unique_ptr* context) { auto params = std::make_unique(); @@ -565,7 +565,7 @@ Status DatasetOpsTestBase::CreateSerializationContext( } Status DatasetOpsTestBase::CheckOpKernelInput( - const OpKernel& kernel, const gtl::InlinedVector& inputs) { + const OpKernel& kernel, const absl::InlinedVector& inputs) { if (kernel.num_inputs() != inputs.size()) { return errors::InvalidArgument("The number of input elements should be ", kernel.num_inputs(), @@ -575,7 +575,7 @@ Status DatasetOpsTestBase::CheckOpKernelInput( } Status DatasetOpsTestBase::AddDatasetInput( - gtl::InlinedVector* inputs, DataTypeVector input_types, + absl::InlinedVector* inputs, DataTypeVector input_types, DataType dtype, const TensorShape& shape) { if (input_types.size() < inputs->size()) { return errors::InvalidArgument("Adding more inputs than types: ", @@ -862,7 +862,7 @@ Status DatasetOpsTestBase::RunDatasetOp( input_datasets.push_back(t.get()); created_tensors->push_back(std::move(t)); } - gtl::InlinedVector inputs; + absl::InlinedVector inputs; inputs.reserve(input_datasets.size()); for (auto input_dataset : input_datasets) { inputs.emplace_back(TensorValue(input_dataset)); @@ -985,7 +985,7 @@ Status DatasetOpsTestBase::MakeDatasetTensor( TF_RETURN_IF_ERROR(dataset_params.GetAttributes(&attributes)); auto input_tensors = dataset_params.GetInputTensors(); - gtl::InlinedVector inputs; + absl::InlinedVector inputs; inputs.reserve(input_datasets.size() + input_tensors.size()); for (auto input_dataset : input_datasets) { inputs.emplace_back(TensorValue(input_dataset)); @@ -1165,7 +1165,7 @@ std::vector TensorSliceDatasetParams::TensorSliceShapes( const std::vector& input_components) { std::vector shapes; for (const auto& component : input_components) { - gtl::InlinedVector partial_dim_sizes; + absl::InlinedVector partial_dim_sizes; for (int i = 1; i < component.dims(); ++i) { partial_dim_sizes.push_back(component.dim_size(i)); } diff --git a/tensorflow/core/data/dataset_test_base.h b/tensorflow/core/data/dataset_test_base.h index ec9805b806fe0d..e7278237d9f130 100644 --- a/tensorflow/core/data/dataset_test_base.h +++ b/tensorflow/core/data/dataset_test_base.h @@ -766,7 +766,7 @@ class DatasetOpsTestBase : public ::testing::Test { // Creates a new op kernel context. Status CreateDatasetContext( - OpKernel* dateset_kernel, gtl::InlinedVector* inputs, + OpKernel* dateset_kernel, absl::InlinedVector* inputs, std::unique_ptr* dataset_context_params, std::unique_ptr* dataset_context); @@ -798,16 +798,16 @@ class DatasetOpsTestBase : public ::testing::Test { // Checks that the size of `inputs` matches the requirement of the op kernel. Status CheckOpKernelInput(const OpKernel& kernel, - const gtl::InlinedVector& inputs); + const absl::InlinedVector& inputs); // Creates a new context for running the dataset operation. Status CreateOpKernelContext(OpKernel* kernel, - gtl::InlinedVector* inputs, + absl::InlinedVector* inputs, std::unique_ptr* context); // Creates a new context for running the dataset operation. Status CreateOpKernelContext(OpKernel* kernel, - gtl::InlinedVector* inputs, + absl::InlinedVector* inputs, std::unique_ptr* params, std::unique_ptr* context); @@ -856,7 +856,7 @@ class DatasetOpsTestBase : public ::testing::Test { // Adds an empty tensor with the specified dtype and shape to the input // vector. - Status AddDatasetInput(gtl::InlinedVector* inputs, + Status AddDatasetInput(absl::InlinedVector* inputs, DataTypeVector input_types, DataType dtype, const TensorShape& shape); diff --git a/tensorflow/core/data/dataset_utils.cc b/tensorflow/core/data/dataset_utils.cc index cc7ed17cdd767a..19345990f355f8 100644 --- a/tensorflow/core/data/dataset_utils.cc +++ b/tensorflow/core/data/dataset_utils.cc @@ -1018,7 +1018,7 @@ REGISTER_DATASET_EXPERIMENT("file_locality_v2", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("no_compression", RandomJobSamplePercentage<0>, AllTasks); -REGISTER_DATASET_EXPERIMENT("no_compression_v2", RandomJobSamplePercentage<50>, +REGISTER_DATASET_EXPERIMENT("no_compression_v2", RandomJobSamplePercentage<0>, AllTasks); REGISTER_DATASET_EXPERIMENT("inject_io_prefetch", RandomJobSamplePercentage<0>, AllTasks); diff --git a/tensorflow/core/data/dataset_utils_test.cc b/tensorflow/core/data/dataset_utils_test.cc index e581f6e3cbe3e8..2e107eb29b0778 100644 --- a/tensorflow/core/data/dataset_utils_test.cc +++ b/tensorflow/core/data/dataset_utils_test.cc @@ -359,11 +359,10 @@ TEST_P(GetExperimentsOptTest, DatasetUtils) { auto opt_ins = test_case.opt_ins; auto opt_outs = test_case.opt_outs; if (!opt_ins.empty()) { - setenv("TF_DATA_EXPERIMENT_OPT_IN", str_util::Join(opt_ins, ",").c_str(), - 1); + setenv("TF_DATA_EXPERIMENT_OPT_IN", absl::StrJoin(opt_ins, ",").c_str(), 1); } if (!opt_outs.empty()) { - setenv("TF_DATA_EXPERIMENT_OPT_OUT", str_util::Join(opt_outs, ",").c_str(), + setenv("TF_DATA_EXPERIMENT_OPT_OUT", absl::StrJoin(opt_outs, ",").c_str(), 1); } const std::string job_name = "job"; @@ -376,14 +375,14 @@ TEST_P(GetExperimentsOptTest, DatasetUtils) { for (const auto& experiment : test_case.expected_in) { EXPECT_TRUE(experiment_set.find(experiment) != experiment_set.end()) << "experiment=" << experiment << " opt_ins={" - << str_util::Join(opt_ins, ",") << "} opt_outs={" - << str_util::Join(opt_outs, ",") << "}"; + << absl::StrJoin(opt_ins, ",") << "} opt_outs={" + << absl::StrJoin(opt_outs, ",") << "}"; } for (const auto& experiment : test_case.expected_out) { EXPECT_TRUE(experiment_set.find(experiment) == experiment_set.end()) << "experiment=" << experiment << " opt_ins={" - << str_util::Join(opt_ins, ",") << "} opt_outs={" - << str_util::Join(opt_outs, ",") << "}"; + << absl::StrJoin(opt_ins, ",") << "} opt_outs={" + << absl::StrJoin(opt_outs, ",") << "}"; } if (!opt_ins.empty()) { diff --git a/tensorflow/core/data/global_shuffle_utils.cc b/tensorflow/core/data/global_shuffle_utils.cc index 132a35f1d10620..dc4256378a29a2 100644 --- a/tensorflow/core/data/global_shuffle_utils.cc +++ b/tensorflow/core/data/global_shuffle_utils.cc @@ -16,10 +16,13 @@ limitations under the License. #include #include +#include #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" @@ -29,6 +32,13 @@ limitations under the License. namespace tensorflow { namespace data { +namespace { + +constexpr absl::string_view kGlobalShuffleIteratorNextIndex = + "global_shuffle_iterator_next_index"; + +} + IteratorContextWithIndexMapper::IteratorContextWithIndexMapper( IteratorContext* ctx, const IteratorBase* iterator) : ctx_(ctx) { @@ -60,10 +70,22 @@ absl::Status GlobalShuffleIterator::GetNext(IteratorContext* ctx, } absl::MutexLock l(&mu_); - TF_ASSIGN_OR_RETURN(int64_t output_index, - ctx->index_mapper()(element_count_++)); + absl::StatusOr shuffled_index = + absl::NotFoundError("Default not found"); + + while (absl::IsNotFound(shuffled_index.status())) { + shuffled_index = ctx->index_mapper()(element_count_++); + } + + if (absl::IsOutOfRange(shuffled_index.status())) { + *end_of_sequence = true; + return absl::OkStatus(); + } + + TF_RETURN_IF_ERROR(shuffled_index.status()); + absl::Status status = - dataset_->Get(AnyContext(ctx), output_index, out_tensors); + dataset_->Get(AnyContext(ctx), shuffled_index.value(), out_tensors); if (absl::IsOutOfRange(status)) { *end_of_sequence = true; return absl::OkStatus(); @@ -73,7 +95,18 @@ absl::Status GlobalShuffleIterator::GetNext(IteratorContext* ctx, return absl::OkStatus(); } -absl::Status GlobalShuffleIterator::Restore(IteratorContext* ctx) { +absl::Status GlobalShuffleIterator::Save( + const std::string& parent_iterator_prefix, SerializationContext* ctx, + IteratorStateWriter* writer) { + absl::MutexLock l(&mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar( + parent_iterator_prefix, kGlobalShuffleIteratorNextIndex, element_count_)); + return absl::OkStatus(); +} + +absl::Status GlobalShuffleIterator::Restore( + const std::string& parent_iterator_prefix, IteratorContext* ctx, + IteratorStateReader* reader) { if (!ctx->restored_element_count().has_value()) { return absl::FailedPreconditionError(absl::StrCat( "Trying to restore random element count for dataset ", @@ -81,7 +114,9 @@ absl::Status GlobalShuffleIterator::Restore(IteratorContext* ctx) { } absl::MutexLock l(&mu_); - element_count_ = *(ctx->restored_element_count()); + TF_RETURN_IF_ERROR(reader->ReadScalar(parent_iterator_prefix, + kGlobalShuffleIteratorNextIndex, + &element_count_)); return absl::OkStatus(); } diff --git a/tensorflow/core/data/global_shuffle_utils.h b/tensorflow/core/data/global_shuffle_utils.h index 91b4fa085b8aab..c7513a0238dabb 100644 --- a/tensorflow/core/data/global_shuffle_utils.h +++ b/tensorflow/core/data/global_shuffle_utils.h @@ -75,9 +75,13 @@ class GlobalShuffleIterator { absl::Status GetNext(IteratorContext* ctx, std::vector* out_tensors, bool* end_of_sequence); + absl::Status Save(const std::string& parent_iterator_prefix, + SerializationContext* ctx, IteratorStateWriter* writer); + // Restores the element count. // REQUIRES: ctx->restored_element_count() != nullopt. - absl::Status Restore(IteratorContext* ctx); + absl::Status Restore(const std::string& parent_iterator_prefix, + IteratorContext* ctx, IteratorStateReader* reader); private: const DatasetBase* const dataset_; diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index fe265ba3b07949..8309b8cdd210a6 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -680,8 +680,6 @@ cc_library( # copybara:uncomment copts = ["-Wthread-safety-analysis"], deps = [ ":credentials_factory", - "//tensorflow/core:framework", - "//tensorflow/core/data:dataset_utils", ], ) @@ -758,8 +756,8 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -1055,6 +1053,7 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:platform_port", ] + tf_grpc_cc_dependencies(), ) diff --git a/tensorflow/core/data/service/auto_scaler_test.cc b/tensorflow/core/data/service/auto_scaler_test.cc index c04ea49d216bf6..299715d0c48a4e 100644 --- a/tensorflow/core/data/service/auto_scaler_test.cc +++ b/tensorflow/core/data/service/auto_scaler_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "absl/time/time.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/lib/monitoring/cell_reader.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" namespace tensorflow { diff --git a/tensorflow/core/data/service/client/BUILD b/tensorflow/core/data/service/client/BUILD index 16bd0efd808e78..60e2da30b5a8ac 100644 --- a/tensorflow/core/data/service/client/BUILD +++ b/tensorflow/core/data/service/client/BUILD @@ -57,7 +57,6 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", - "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:retrying_utils", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], @@ -120,10 +119,10 @@ tf_cc_test( "//tensorflow/core/data/service:dispatcher_client", "//tensorflow/core/data/service:test_cluster", "//tensorflow/core/data/service:test_util", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/protobuf:protos_all_cc", + "@local_xla//xla/tsl/lib/core:status_test_util", ] + tf_grpc_cc_dependencies() + tf_protos_profiler_service(), ) diff --git a/tensorflow/core/data/service/client/data_service_client.cc b/tensorflow/core/data/service/client/data_service_client.cc index dbd65ea37d7a61..a323a02b6096bc 100644 --- a/tensorflow/core/data/service/client/data_service_client.cc +++ b/tensorflow/core/data/service/client/data_service_client.cc @@ -53,7 +53,6 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme_encode.h" -#include "tsl/platform/host_info.h" #include "tsl/platform/retrying_utils.h" #include "tsl/protobuf/error_codes.pb.h" @@ -381,9 +380,7 @@ DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback( absl::StatusOr> DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) { if (params_.data_transfer_protocol == kLocalTransferProtocol || - // TODO(b/291994182): Use remote workers in unit tests. - (tsl::port::JobUid() != -1 && - LocalWorkers::Get(task_info.worker_address()) != nullptr)) { + ForceLocalProtocol(task_info.worker_address())) { DataTransferServerInfo info; info.set_protocol(kLocalTransferProtocol); info.set_address(task_info.worker_address()); diff --git a/tensorflow/core/data/service/client/data_service_client_test.cc b/tensorflow/core/data/service/client/data_service_client_test.cc index 8ec654b33eabde..09b1edede48d20 100644 --- a/tensorflow/core/data/service/client/data_service_client_test.cc +++ b/tensorflow/core/data/service/client/data_service_client_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/time/time.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/client/common.h" #include "tensorflow/core/data/service/common.h" #include "tensorflow/core/data/service/test_cluster.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/data_service.pb.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/data/service/client/utils_test.cc b/tensorflow/core/data/service/client/utils_test.cc index c3d945163a90a1..8729bff56cbcb2 100644 --- a/tensorflow/core/data/service/client/utils_test.cc +++ b/tensorflow/core/data/service/client/utils_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/dispatcher_client.h" #include "tensorflow/core/data/service/test_cluster.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/data_service.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/protobuf/error_codes.pb.h" diff --git a/tensorflow/core/data/service/dispatcher_state_test.cc b/tensorflow/core/data/service/dispatcher_state_test.cc index b1341be1e546b1..e561ecb4dd08c2 100644 --- a/tensorflow/core/data/service/dispatcher_state_test.cc +++ b/tensorflow/core/data/service/dispatcher_state_test.cc @@ -21,13 +21,13 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/journal.pb.h" #include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/data_service.pb.h" #include "tensorflow/core/protobuf/service_config.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" namespace tensorflow { diff --git a/tensorflow/core/data/service/py_utils.cc b/tensorflow/core/data/service/py_utils.cc index d14e1c9d1ed2cf..be5308df607f98 100644 --- a/tensorflow/core/data/service/py_utils.cc +++ b/tensorflow/core/data/service/py_utils.cc @@ -17,9 +17,7 @@ limitations under the License. #include -#include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/service/credentials_factory.h" -#include "tensorflow/core/framework/metrics.h" namespace tensorflow { namespace data { @@ -39,17 +37,5 @@ std::string DefaultProtocol() { return "grpc"; } -bool DisableCompressionAtRegistrationTime() { -#if defined(PLATFORM_GOOGLE) - if (!GetExperiments().contains("no_compression_v2")) { - return false; - } - metrics::RecordTFDataServiceCompressionAction( - "disabled_at_registration_time"); - return true; -#endif // PLATFORM_GOOGLE - return false; -} - } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/py_utils.h b/tensorflow/core/data/service/py_utils.h index 010c155022fee3..b0ea8928a3af4e 100644 --- a/tensorflow/core/data/service/py_utils.h +++ b/tensorflow/core/data/service/py_utils.h @@ -27,10 +27,6 @@ namespace data { // Returns the default protocol to use for tf.data service control flow. std::string DefaultProtocol(); -// Returns `true` if tf.data service compression is to be disabled at -// registration time. -bool DisableCompressionAtRegistrationTime(); - } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/snapshot/BUILD b/tensorflow/core/data/service/snapshot/BUILD index 40b5cbaa6873aa..35d28b0841e7d9 100644 --- a/tensorflow/core/data/service/snapshot/BUILD +++ b/tensorflow/core/data/service/snapshot/BUILD @@ -150,13 +150,13 @@ tf_cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc index 71e9bfc68ad55e..8974964c9b3a81 100644 --- a/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc +++ b/tensorflow/core/data/service/snapshot/distributed_snapshot_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/time/time.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/dispatcher_client.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/service/snapshot/test_utils.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/file_utils_test.cc b/tensorflow/core/data/service/snapshot/file_utils_test.cc index 9bf1e5257d6a95..9582cab18bc143 100644 --- a/tensorflow/core/data/service/snapshot/file_utils_test.cc +++ b/tensorflow/core/data/service/snapshot/file_utils_test.cc @@ -18,13 +18,13 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/dataset_test_base.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc index d55e7d1f8b0b82..43944c6a41b8f1 100644 --- a/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/parallel_tfrecord_writer_test.cc @@ -30,10 +30,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc index 0f0c9f9cc62840..1e019a1742651e 100644 --- a/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/prefetched_split_provider_test.cc @@ -31,13 +31,13 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/split_provider.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc index 4dc6f2342e2ea9..e40fd0ad918387 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_chunk_provider_test.cc @@ -26,13 +26,13 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_tensor_data.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc index 9e116536ad07e5..65b3c59e8ecba4 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_manager_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc index 96c7c3e21be098..5a6f8200b589ab 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider.cc @@ -49,8 +49,9 @@ constexpr char kNextSplitIndex[] = "next_split_index"; constexpr char kRepetitionIndex[] = "repetition_index"; absl::StatusOr GetRepetitionIndex(const std::string& split_file) { - tsl::StringPiece repetition_dir_path = tsl::io::Dirname(split_file); - tsl::StringPiece repetition_dir_name = tsl::io::Basename(repetition_dir_path); + absl::string_view repetition_dir_path = tsl::io::Dirname(split_file); + absl::string_view repetition_dir_name = + tsl::io::Basename(repetition_dir_path); return ParseRepetitionDirectoryName(repetition_dir_name); } } // namespace diff --git a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc index 2d6a9fdf702962..b9b9f3d3d4d8ed 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_split_provider_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/time/time.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/serialization_utils.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/dispatcher_client.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/protobuf/snapshot.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc index 85d401b5fdccd8..071c7e1f1c72a1 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_checkpoint_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/time/time.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/snapshot/path_utils.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/data/service/snapshot/test_utils.h" #include "tensorflow/core/data/service/task_runner.h" #include "tensorflow/core/data/service/test_util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/io/compression.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" diff --git a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc index 84243d8f67b3f4..f9183410648dd8 100644 --- a/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc +++ b/tensorflow/core/data/service/snapshot/snapshot_stream_writer_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/byte_size.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/snapshot/file_utils.h" @@ -35,7 +36,6 @@ limitations under the License. #include "tensorflow/core/data/snapshot_utils.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/io/compression.h" #include "tsl/lib/monitoring/cell_reader.h" #include "tsl/platform/env.h" diff --git a/tensorflow/core/data/service/split_provider_test.cc b/tensorflow/core/data/service/split_provider_test.cc index 08adc907058af2..d311db235b9dff 100644 --- a/tensorflow/core/data/service/split_provider_test.cc +++ b/tensorflow/core/data/service/split_provider_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/framework/dataset.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/tensorflow/core/data/service/worker_client.cc b/tensorflow/core/data/service/worker_client.cc index 871d549a729c3b..673bc59976c814 100644 --- a/tensorflow/core/data/service/worker_client.cc +++ b/tensorflow/core/data/service/worker_client.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/errors.h" +#include "tsl/platform/host_info.h" namespace tensorflow { namespace data { @@ -91,7 +92,7 @@ Status DataServiceWorkerClient::EnsureInitialized() { } std::string DataServiceWorkerClient::GetDataTransferProtocol() const { - if (LocalWorkers::Get(address_) != nullptr) { + if (ForceLocalProtocol(address_)) { return kLocalTransferProtocol; } return transfer_protocol_; @@ -275,5 +276,13 @@ class LocalTransferClientRegistrar { }; static LocalTransferClientRegistrar local_client_registrar; +bool ForceLocalProtocol(const std::string& worker_address) { + // TODO(b/291994182): Use remote workers in unit tests. + if (tsl::port::JobUid() == -1) { + return false; + } + return LocalWorkers::Get(worker_address) != nullptr; +} + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/service/worker_client.h b/tensorflow/core/data/service/worker_client.h index 0799ab72999044..014afdc6a98d1c 100644 --- a/tensorflow/core/data/service/worker_client.h +++ b/tensorflow/core/data/service/worker_client.h @@ -22,11 +22,8 @@ limitations under the License. #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/data_transfer.h" #include "tensorflow/core/data/service/worker.pb.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" -#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace data { @@ -85,6 +82,10 @@ CreateDataServiceWorkerClient( const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info, Allocator* allocator); +// If true, clients should use local protocol for data transfer (disregarding +// any other user-specified or runtime-defaulted protocol). +bool ForceLocalProtocol(const std::string& worker_address); + } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/data/snapshot_utils.cc b/tensorflow/core/data/snapshot_utils.cc index a1dd6179cc8e50..8874484c835af0 100644 --- a/tensorflow/core/data/snapshot_utils.cc +++ b/tensorflow/core/data/snapshot_utils.cc @@ -579,7 +579,7 @@ class Reader::NestedDataset : public DatasetBase { std::vector datasets) : DatasetBase(std::move(ctx)), datasets_(datasets) { dtypes_.push_back(DT_VARIANT); - gtl::InlinedVector element_dim_sizes; + absl::InlinedVector element_dim_sizes; element_dim_sizes.push_back(1); partial_shapes_.emplace_back(element_dim_sizes); } @@ -859,9 +859,9 @@ Status CustomReader::Initialize(Env* env) { } Status CustomReader::ReadTensors(std::vector* read_tensors) { - profiler::TraceMe activity( + tsl::profiler::TraceMe activity( [&]() { return absl::StrCat(kClassName, kSeparator, "ReadTensors"); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); if (version_ == 0 || compression_type_ != io::compression::kSnappy) { return ReadTensorsV0(read_tensors); } diff --git a/tensorflow/core/data/standalone_save_restore_test.cc b/tensorflow/core/data/standalone_save_restore_test.cc index fd163691d71f88..9798021302614f 100644 --- a/tensorflow/core/data/standalone_save_restore_test.cc +++ b/tensorflow/core/data/standalone_save_restore_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/test_util.h" #include "tensorflow/core/data/standalone.h" #include "tensorflow/core/framework/tensor.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/data/standalone_test.cc b/tensorflow/core/data/standalone_test.cc index 54f438b1cc2308..fac2a9eeb6e6a8 100644 --- a/tensorflow/core/data/standalone_test.cc +++ b/tensorflow/core/data/standalone_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index 6c8211537271f9..fb01f411380311 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -182,11 +182,8 @@ tf_cc_test( ":debug_grpc_testlib", ":debug_io_utils", ":debug_node_key", - ":debug_service_proto_cc", ":debugger_event_metadata_proto_cc", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -194,7 +191,6 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "//tensorflow/core/platform/default/build_config:platformlib", ], ) @@ -260,7 +256,6 @@ tf_cc_test( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc index 3eaf3651126528..87aea157cdb04c 100644 --- a/tensorflow/core/debug/debug_grpc_io_utils_test.cc +++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/core/debug/debug_graph_utils.h" #include "tensorflow/core/debug/debug_grpc_testlib.h" #include "tensorflow/core/debug/debug_io_utils.h" @@ -47,10 +49,10 @@ class GrpcDebugTest : public ::testing::Test { int64_t server_start_delay_micros) { server_data->port = testing::PickUnusedPortOrDie(); server_data->url = strings::StrCat("grpc://localhost:", server_data->port); - server_data->server.reset(new test::TestEventListenerImpl()); + server_data->server = std::make_unique(); - server_data->thread_pool.reset( - new thread::ThreadPool(Env::Default(), "test_server", 1)); + server_data->thread_pool = + std::make_unique(Env::Default(), "test_server", 1); server_data->thread_pool->Schedule( [server_data, server_start_delay_micros]() { Env::Default()->SleepForMicroseconds(server_start_delay_micros); diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h index 18009a33c69547..2a57df8d866331 100644 --- a/tensorflow/core/debug/debug_grpc_testlib.h +++ b/tensorflow/core/debug/debug_grpc_testlib.h @@ -39,7 +39,7 @@ class TestEventListenerImpl final : public grpc::EventListener::Service { ::grpc::Status SendEvents( ::grpc::ServerContext* context, ::grpc::ServerReaderWriter< ::tensorflow::EventReply, - ::tensorflow::Event>* stream); + ::tensorflow::Event>* stream) override; // Clear debug data (e.g., Tensors) received so far. void ClearReceivedDebugData(); diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc index dad4360c865e36..74d5758c306ef1 100644 --- a/tensorflow/core/debug/debug_io_utils_test.cc +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/debug/debug_io_utils.h" + #include +#include #include -#include "tensorflow/core/debug/debug_io_utils.h" - #include "tensorflow/core/debug/debug_callback_registry.h" #include "tensorflow/core/debug/debug_node_key.h" #include "tensorflow/core/debug/debugger_event_metadata.pb.h" @@ -40,7 +41,7 @@ class DebugIOUtilsTest : public ::testing::Test { void Initialize() { env_ = Env::Default(); - tensor_a_.reset(new Tensor(DT_FLOAT, TensorShape({2, 2}))); + tensor_a_ = std::make_unique(DT_FLOAT, TensorShape({2, 2})); tensor_a_->flat()(0) = 5.0; tensor_a_->flat()(1) = 3.0; tensor_a_->flat()(2) = -1.0; diff --git a/tensorflow/core/debug/debugger_state_impl.h b/tensorflow/core/debug/debugger_state_impl.h index 9de9bdc2f290c3..4114d68549e2f0 100644 --- a/tensorflow/core/debug/debugger_state_impl.h +++ b/tensorflow/core/debug/debugger_state_impl.h @@ -26,7 +26,7 @@ namespace tensorflow { class DebuggerState : public DebuggerStateInterface { public: DebuggerState(const DebugOptions& debug_options); - virtual ~DebuggerState(); + ~DebuggerState() override; // Publish metadata about the debugged Session::Run() call. // @@ -47,7 +47,7 @@ class DebugGraphDecorator : public DebugGraphDecoratorInterface { public: DebugGraphDecorator(const DebugOptions& debug_options) : debug_options_(debug_options) {} - virtual ~DebugGraphDecorator() {} + ~DebugGraphDecorator() override {} Status DecorateGraph(Graph* graph, Device* device) override; Status PublishGraph(const Graph& graph, const string& device_name) override; diff --git a/tensorflow/core/distributed_runtime/integration_test/BUILD b/tensorflow/core/distributed_runtime/integration_test/BUILD index 4927d6fd3cdf58..7408bcbfdc9f71 100644 --- a/tensorflow/core/distributed_runtime/integration_test/BUILD +++ b/tensorflow/core/distributed_runtime/integration_test/BUILD @@ -52,7 +52,7 @@ tf_cuda_cc_test( "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:env", "@com_google_absl//absl/time", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -168,6 +168,6 @@ tf_cuda_cc_test( "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/platform:env", "@com_google_absl//absl/time", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc index 4e0bd6f5ed4c08..356f0a08412fd9 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_coordination_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/blocking_counter.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc index ba48750cb44716..cffe93d297a8df 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_multi_client_test.cc @@ -185,8 +185,7 @@ TEST(CAPI, MultiClientSendRecv) { if (worker_id == 0) { TFE_TensorHandle* in = TestMatrixTensorHandle(ctx); const std::string& op_name = - tensorflow::str_util::StrContains(send_device, "GPU") ? "Send" - : "_HostSend"; + absl::StrContains(send_device, "GPU") ? "Send" : "_HostSend"; TFE_Op* sendop = SendOp(ctx, in, op_name, send_device, recv_device, send_device_incarnation); TFE_TensorHandle* retvals[1]; @@ -197,8 +196,7 @@ TEST(CAPI, MultiClientSendRecv) { TFE_DeleteTensorHandle(in); } else { const std::string& op_name = - tensorflow::str_util::StrContains(send_device, "GPU") ? "Recv" - : "_HostRecv"; + absl::StrContains(send_device, "GPU") ? "Recv" : "_HostRecv"; TFE_Op* recvop = RecvOp(ctx, op_name, send_device, recv_device, send_device_incarnation); TFE_TensorHandle* retvals[1]; diff --git a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc index dde7f65419b646..3d9ff3c459181f 100644 --- a/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc +++ b/tensorflow/core/distributed_runtime/integration_test/c_api_recoverable_jobs_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/c_api_test_util.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/platform/strcat.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/protobuf/cluster.pb.h" #include "tensorflow/core/protobuf/rewriter_config.pb.h" #include "tensorflow/core/protobuf/tensorflow_server.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/protobuf/coordination_config.pb.h" namespace tensorflow { diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 7777bc6725356d..2fcbc724b49884 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -1868,14 +1868,17 @@ Status MasterSession::CreateDebuggerState( DebuggerStateRegistry::CreateState(debug_options, debugger_state)); std::vector input_names; + input_names.reserve(req.num_feeds()); for (size_t i = 0; i < req.num_feeds(); ++i) { input_names.push_back(req.feed_name(i)); } std::vector output_names; + output_names.reserve(req.num_fetches()); for (size_t i = 0; i < req.num_fetches(); ++i) { output_names.push_back(req.fetch_name(i)); } std::vector target_names; + target_names.reserve(req.num_targets()); for (size_t i = 0; i < req.num_targets(); ++i) { target_names.push_back(req.target_name(i)); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index 56eb09ea4d20c7..c1026dc273136c 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h" #include "tensorflow/core/framework/graph.pb.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/protobuf/error_codes.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/port.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/core/example/CMakeLists.txt b/tensorflow/core/example/CMakeLists.txt new file mode 100644 index 00000000000000..2450c9eddd5107 --- /dev/null +++ b/tensorflow/core/example/CMakeLists.txt @@ -0,0 +1,50 @@ +# +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if (NOT TARGET protobuf::libprotobuf) + find_package(Protobuf REQUIRED) +endif() + +set(GEN_PROTO_DIR ${CMAKE_CURRENT_BINARY_DIR}/tensorflow/core/example) + +# Generate feature proto .h, .cc and lib. +list(APPEND feature_generated_files ${GEN_PROTO_DIR}/feature.pb.h ${GEN_PROTO_DIR}/feature.pb.cc) + +add_custom_command( + OUTPUT ${feature_generated_files} + COMMAND ${Protobuf_PROTOC_EXECUTABLE} + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} --proto_path=${TENSORFLOW_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/feature.proto + DEPENDS ${Protobuf_PROTOC_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/feature.proto +) + +set_source_files_properties(${feature_generated_files} PROPERTIES GENERATED TRUE) +add_library(feature_proto ${feature_generated_files}) +target_link_libraries(feature_proto protobuf::libprotobuf) +target_include_directories(feature_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) + +# Generate example proto .h, .cc and lib. +list(APPEND example_generated_files ${GEN_PROTO_DIR}/example.pb.h ${GEN_PROTO_DIR}/example.pb.cc) + +add_custom_command( + OUTPUT ${example_generated_files} + COMMAND ${Protobuf_PROTOC_EXECUTABLE} + ARGS --cpp_out=${CMAKE_CURRENT_BINARY_DIR} --proto_path=${TENSORFLOW_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/example.proto + DEPENDS ${Protobuf_PROTOC_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/example.proto ${feature_generated_files} +) + +set_source_files_properties(${example_generated_files} PROPERTIES GENERATED TRUE) +add_library(example_proto ${example_generated_files}) +target_link_libraries(example_proto feature_proto protobuf::libprotobuf) +target_include_directories(example_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) \ No newline at end of file diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index a0b39c7b45d312..526085c82e73bb 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -1948,7 +1948,7 @@ tf_cc_fuzz_test( deps = [ "//tensorflow/core:framework", "//tensorflow/security/fuzzing/cc/core/framework:tensor_shape_domains", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -1962,7 +1962,7 @@ tf_cc_fuzz_test( "//tensorflow/security/fuzzing/cc/core/framework:datatype_domains", "//tensorflow/security/fuzzing/cc/core/framework:tensor_domains", "//tensorflow/security/fuzzing/cc/core/framework:tensor_shape_domains", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/framework/allocator_test.cc b/tensorflow/core/framework/allocator_test.cc index 7e85b25a9df6f7..6557a4cec7598e 100644 --- a/tensorflow/core/framework/allocator_test.cc +++ b/tensorflow/core/framework/allocator_test.cc @@ -236,7 +236,7 @@ TEST(CPUAllocatorTest, ProfilerReporting) { // Get profiling results tensorflow::profiler::XSpace xspace; - EXPECT_EQ(OkStatus(), profiler->CollectData(&xspace)); + EXPECT_EQ(absl::OkStatus(), profiler->CollectData(&xspace)); // Validate the output const auto plane = ::tsl::profiler::FindPlaneWithName( diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc index 31bddbea68f93b..351ba293276456 100644 --- a/tensorflow/core/framework/attr_value_util.cc +++ b/tensorflow/core/framework/attr_value_util.cc @@ -446,7 +446,7 @@ Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) { } } - return OkStatus(); + return absl::OkStatus(); } bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) { @@ -530,7 +530,7 @@ void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; } DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, FIELD) DEFINE_SET_ATTR_VALUE_ONE(const string&, s) -DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice, s) +DEFINE_SET_ATTR_VALUE_LIST(absl::Span, s) DEFINE_SET_ATTR_VALUE_BOTH(const char*, s) DEFINE_SET_ATTR_VALUE_BOTH(int64_t, i) DEFINE_SET_ATTR_VALUE_BOTH(int32_t, i) @@ -545,7 +545,7 @@ void SetAttrValue(const tstring& value, AttrValue* out) { out->set_s(value.data(), value.size()); } -void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); for (const auto& v : value) { out->mutable_list()->add_s(v.data(), v.size()); @@ -556,7 +556,7 @@ void SetAttrValue(StringPiece value, AttrValue* out) { out->set_s(value.data(), value.size()); } -void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { out->mutable_list()->add_s(v.data(), v.size()); @@ -582,21 +582,21 @@ void SetAttrValue(const PartialTensorShape& value, AttrValue* out) { value.AsProto(out->mutable_shape()); } -void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { v.AsProto(out->mutable_list()->add_shape()); } } -void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { *out->mutable_list()->add_shape() = v; } } -void SetAttrValue(const gtl::ArraySlice value, +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { @@ -612,7 +612,7 @@ void SetAttrValue(const Tensor& value, AttrValue* out) { } } -void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { if (v.NumElements() > 1) { @@ -627,7 +627,7 @@ void SetAttrValue(const TensorProto& value, AttrValue* out) { *out->mutable_tensor() = value; } -void SetAttrValue(const gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(const absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { *out->mutable_list()->add_tensor() = v; @@ -638,7 +638,7 @@ void SetAttrValue(const NameAttrList& value, AttrValue* out) { *out->mutable_func() = value; } -void SetAttrValue(gtl::ArraySlice value, AttrValue* out) { +void SetAttrValue(absl::Span value, AttrValue* out) { out->mutable_list()->Clear(); // Create list() even if value empty. for (const auto& v : value) { *out->mutable_list()->add_func() = v; diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc index bf99f54baa6b45..996acd12d78b3b 100644 --- a/tensorflow/core/framework/collective.cc +++ b/tensorflow/core/framework/collective.cc @@ -207,7 +207,7 @@ Status CollectiveRegistry::Register(const string& collective_name, collective_name); } registry->emplace_back(collective_name, std::move(factory)); - return OkStatus(); + return absl::OkStatus(); } /*static*/ @@ -222,7 +222,7 @@ Status CollectiveRegistry::LookupHelper( } else { *implementation = reg_info.factory(); } - return OkStatus(); + return absl::OkStatus(); } } return errors::Internal( diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index be1bfc2581e2aa..b400203013b0b2 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -84,7 +84,7 @@ Status GetWindowedOutputSizeFromDimsV2( /*evenly_divisible=*/false, output_size)); break; } - return OkStatus(); + return absl::OkStatus(); } Status GetWindowedOutputSizeFromDims( @@ -112,7 +112,7 @@ Status UnchangedShape(shape_inference::InferenceContext* c) { if (handle_data != nullptr) { c->set_output_handle_shapes_and_types(0, *handle_data); } - return OkStatus(); + return absl::OkStatus(); } Status MatMulShape(shape_inference::InferenceContext* c) { @@ -135,7 +135,7 @@ Status MatMulShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged)); c->set_output(0, c->Matrix(output_rows, output_cols)); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -155,7 +155,7 @@ Status ValidateEinsumEllipsis(absl::string_view subscript, "Periods found outside of ellipsis in subscript: ", subscript); } *found_ellipsis = num_periods > 0; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -166,7 +166,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { // more latin alphabets and contains at most one ellipsis ('...'). string equation; TF_RETURN_IF_ERROR(c->GetAttr("equation", &equation)); - gtl::InlinedVector input_labels; + absl::InlinedVector input_labels; string output_labels; TF_RETURN_IF_ERROR( ValidateEinsumEquation(equation, &input_labels, &output_labels)); @@ -185,7 +185,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { // Validate input subscripts, build the label to dimension mapping and obtain // the broadcast shapes that map to ellipsis. absl::flat_hash_map label_to_dimension; - gtl::InlinedVector input_bcast_shapes(c->num_inputs()); + absl::InlinedVector input_bcast_shapes(c->num_inputs()); for (int i = 0, end = c->num_inputs(); i < end; ++i) { bool has_ellipsis = false; TF_RETURN_IF_ERROR(ValidateEinsumEllipsis(input_labels[i], &has_ellipsis)); @@ -276,7 +276,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { // unknown, then the output shape should have unknown rank. if (!c->RankKnown(output_bcast_shape)) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } } else { // If the output subscripts don't have ellipsis then make sure the output @@ -311,7 +311,7 @@ Status EinsumShape(shape_inference::InferenceContext* c) { output_dims.push_back(dimension_it->second); } c->set_output(0, c->MakeShape(output_dims)); - return OkStatus(); + return absl::OkStatus(); } Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { @@ -348,7 +348,7 @@ Status BatchMatMulV2Shape(shape_inference::InferenceContext* c) { output_batch_shape, c->Matrix(output_rows, output_cols), &output_shape)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status BatchMatMulShape(shape_inference::InferenceContext* c) { @@ -382,7 +382,7 @@ Status BatchMatMulShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR( c->Concatenate(batch_dims, c->Matrix(output_rows, output_cols), &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // -------------------------------------------------------------------------- @@ -407,7 +407,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) { // If rank unknown, return unknown shape. if (!c->RankKnown(input_shape)) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } // Output has the same shape as the input, and matches the length of @@ -443,7 +443,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) { } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status BiasAddGradShape(shape_inference::InferenceContext* c) { @@ -460,7 +460,7 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) { c->set_output(0, c->Vector(c->Dim(input_shape, -1))); } - return OkStatus(); + return absl::OkStatus(); } Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, @@ -479,7 +479,7 @@ Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, } } - return OkStatus(); + return absl::OkStatus(); } Status DatasetIteratorShape(shape_inference::InferenceContext* c) { @@ -499,7 +499,7 @@ Status DatasetIteratorShape(shape_inference::InferenceContext* c) { output_shapes[i], &output_shape_handle)); c->set_output(static_cast(i), output_shape_handle); } - return OkStatus(); + return absl::OkStatus(); } Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, @@ -524,12 +524,12 @@ Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, context->MakeDim(spatial[spatial_dim]); } *out = context->MakeShape(dims_actual); - return OkStatus(); + return absl::OkStatus(); } Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, DimensionHandle* batch_dim, - gtl::MutableArraySlice spatial_dims, + absl::Span spatial_dims, DimensionHandle* filter_dim, InferenceContext* context) { const int32_t rank = @@ -550,12 +550,12 @@ Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)), filter_dim)); } - return OkStatus(); + return absl::OkStatus(); } // vect_size must be provided if format is NCHW_VECT_C. Status ShapeFromDimensions(DimensionHandle batch_dim, - gtl::ArraySlice spatial_dims, + absl::Span spatial_dims, DimensionHandle filter_dim, TensorFormat format, absl::optional vect_size, InferenceContext* context, ShapeHandle* shape) { @@ -585,7 +585,7 @@ Status ShapeFromDimensions(DimensionHandle batch_dim, } *shape = context->MakeShape(out_dims); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -652,7 +652,7 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c, DimensionHandle batch_size_dim; DimensionHandle input_depth_dim; - gtl::InlinedVector input_spatial_dims(2); + absl::InlinedVector input_spatial_dims(2); TF_RETURN_IF_ERROR(DimensionsFromShape( conv_input_shape, data_format, &batch_size_dim, absl::MakeSpan(input_spatial_dims), &input_depth_dim, c)); @@ -760,7 +760,7 @@ Status Conv2DShapeImpl(shape_inference::InferenceContext* c, batch_size_dim, {output_rows, output_cols}, output_depth_dim, data_format, vect_size, c, &output_shape)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -777,7 +777,7 @@ Status ConvShape(shape_inference::InferenceContext* c) { if (input_rank == InferenceContext::kUnknownRank || filter_rank == InferenceContext::kUnknownRank) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } int batch_dims; @@ -981,7 +981,7 @@ Status ConvShape(shape_inference::InferenceContext* c) { output_shape = c->MakeShape(output_shape_vector); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } // Shape function for Conv2D-like operations that support explicit padding. @@ -1107,7 +1107,7 @@ Status Conv3DShape(shape_inference::InferenceContext* c) { output_cols, output_depth_dim}); } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { @@ -1130,7 +1130,7 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { DimensionHandle batch_size_dim; DimensionHandle output_grad_depth_dim; - gtl::InlinedVector output_grad_spatial_dims(2); + absl::InlinedVector output_grad_spatial_dims(2); TF_RETURN_IF_ERROR(DimensionsFromShape( output_grad_shape, data_format, &batch_size_dim, absl::MakeSpan(output_grad_spatial_dims), &output_grad_depth_dim, c)); @@ -1151,7 +1151,7 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { // input_grad_depth_dim from input_sizes; otherwise we compute it as // c->Dim(filter_shape,2). DimensionHandle input_grad_depth_dim; - gtl::InlinedVector specified_input_grad_spatial_dims(2); + absl::InlinedVector specified_input_grad_spatial_dims(2); int specified_input_grad_rank = c->Rank(specified_input_grad_shape); if (specified_input_grad_rank == 4) { DimensionHandle specified_batch_size_dim; @@ -1179,7 +1179,7 @@ Status Conv2DBackpropInputShape(shape_inference::InferenceContext* c) { batch_size_dim, specified_input_grad_spatial_dims, input_grad_depth_dim, data_format, /*vect_size=*/absl::nullopt, c, &input_grad_shape)); c->set_output(0, input_grad_shape); - return OkStatus(); + return absl::OkStatus(); } Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) { @@ -1198,7 +1198,7 @@ Status Conv2DBackpropFilterWithBiasShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh)); TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh)); c->set_output(0, sh); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -1320,7 +1320,7 @@ Status DepthwiseConv2DNativeShapeImpl(shape_inference::InferenceContext* c, c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } }; // namespace @@ -1400,7 +1400,7 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { {output_rows, output_cols}, depth_dim, &output_shape, c)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status AvgPoolGradShape(shape_inference::InferenceContext* c) { @@ -1408,7 +1408,7 @@ Status AvgPoolGradShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormShape(shape_inference::InferenceContext* c) { @@ -1450,13 +1450,13 @@ Status FusedBatchNormShape(shape_inference::InferenceContext* c) { c->set_output(2, vector_shape); c->set_output(3, vector_shape); c->set_output(4, vector_shape); - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormV3Shape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(FusedBatchNormShape(c)); c->set_output(5, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { @@ -1481,7 +1481,7 @@ Status FusedBatchNormExShape(shape_inference::InferenceContext* c) { "_FusedBatchNormEx channel dimension must be divisible by 4."); } - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { @@ -1522,7 +1522,7 @@ Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { c->set_output(2, c->Vector(channel_dim)); c->set_output(3, c->Vector(0)); c->set_output(4, c->Vector(0)); - return OkStatus(); + return absl::OkStatus(); } Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { @@ -1531,7 +1531,7 @@ Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { int num_side_inputs; TF_RETURN_IF_ERROR(c->GetAttr("num_side_inputs", &num_side_inputs)); if (num_side_inputs == 0) { - return OkStatus(); + return absl::OkStatus(); } string data_format_str; @@ -1558,7 +1558,7 @@ Status FusedBatchNormGradExShape(shape_inference::InferenceContext* c) { &side_input_backprop)); c->set_output(5, side_input_backprop); - return OkStatus(); + return absl::OkStatus(); } Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, @@ -1581,7 +1581,7 @@ Status ReadDiagIndex(InferenceContext* c, const Tensor* diag_index_tensor, num_elements, " elements."); } } - return OkStatus(); + return absl::OkStatus(); } Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { @@ -1594,7 +1594,7 @@ Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) || diag_index_tensor == nullptr) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } int32_t lower_diag_index = 0; int32_t upper_diag_index = 0; @@ -1634,7 +1634,7 @@ Status MatrixDiagPartV2Shape(shape_inference::InferenceContext* c) { } dims.push_back(c->MakeDim(max_diag_len)); c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { @@ -1651,7 +1651,7 @@ Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { if (!c->RankKnown(input_shape) || !c->FullyDefined(diag_index_shape) || diag_index_tensor == nullptr) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } int32_t lower_diag_index = 0; int32_t upper_diag_index = 0; @@ -1735,7 +1735,7 @@ Status MatrixDiagV2Shape(shape_inference::InferenceContext* c) { output_col_dim, &output_shape)); } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { @@ -1807,7 +1807,7 @@ Status MatrixSetDiagV2Shape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->Merge(input_shape, diag_shape, &output_shape)); } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, @@ -1903,7 +1903,7 @@ Status MaxPoolShapeImpl(shape_inference::InferenceContext* c, output_depth, &output_shape, c)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status MaxPoolShape(shape_inference::InferenceContext* c) { @@ -1954,7 +1954,7 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2); if (kernel_sizes_tensor == nullptr) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); auto kernel_sizes_vec = kernel_sizes_tensor->flat(); @@ -1964,7 +1964,7 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); if (strides_tensor == nullptr) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } strides.resize(strides_tensor->shape().num_elements()); auto strides_vec = strides_tensor->flat(); @@ -2017,7 +2017,7 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { output_depth, &output_shape, c)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status Pool3DShape(shape_inference::InferenceContext* c) { @@ -2099,7 +2099,7 @@ Status Pool3DShape(shape_inference::InferenceContext* c) { } c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status MaxPool3DGradShape(shape_inference::InferenceContext* c) { @@ -2111,14 +2111,14 @@ Status AvgPool3DGradShape(shape_inference::InferenceContext* c) { TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); } Status UnknownShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->UnknownShape()); } - return OkStatus(); + return absl::OkStatus(); } template @@ -2141,7 +2141,7 @@ Status ReductionShapeHelper(const Tensor* reduction_indices_t, true_indices->insert(wrapped_index); } - return OkStatus(); + return absl::OkStatus(); } Status ReductionShape(InferenceContext* c) { @@ -2167,7 +2167,7 @@ Status ReductionShape(InferenceContext* c) { if (keep_dims && c->RankKnown(input)) { // output rank matches input input if . c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); - return OkStatus(); + return absl::OkStatus(); } else { return shape_inference::UnknownShape(c); } @@ -2198,7 +2198,7 @@ Status ReductionShape(InferenceContext* c) { } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } Status ConcatShapeHelper(InferenceContext* c, int start_value_index, @@ -2220,7 +2220,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, } if (rank == InferenceContext::kUnknownRank) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } else if (rank == 0) { return errors::InvalidArgument( "Can't concatenate scalars (use tf.stack instead)"); @@ -2235,7 +2235,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, dims.reserve(rank); for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim()); c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } // Merge all the non-concat dims, and sum the concat dim to make an output @@ -2286,7 +2286,7 @@ Status ConcatShapeHelper(InferenceContext* c, int start_value_index, c->Concatenate(output_before, c->Vector(output_middle), &s)); TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s)); c->set_output(0, s); - return OkStatus(); + return absl::OkStatus(); } Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { @@ -2315,7 +2315,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, CHECK_NOTNULL(out); if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) { *out = c->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } const int32_t rank_x = c->Rank(shape_x); const int32_t rank_y = c->Rank(shape_y); @@ -2347,13 +2347,13 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, if (c->Value(dim_x) > 1) { if (!incompatible_shape_error) { *out = c->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } dims.push_back(dim_x); } else if (c->Value(dim_y) > 1) { if (!incompatible_shape_error) { *out = c->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } dims.push_back(dim_y); } else if (c->Value(dim_x) == 1) { @@ -2367,7 +2367,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, } else { if (!incompatible_shape_error) { *out = c->UnknownShape(); - return OkStatus(); + return absl::OkStatus(); } dims.push_back(c->UnknownDim()); } @@ -2386,7 +2386,7 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, if (!s.ok()) { if (!incompatible_shape_error) { *out = c->MakeShape({}); - return OkStatus(); + return absl::OkStatus(); } return s; } @@ -2395,14 +2395,14 @@ Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, } *out = c->MakeShape(dims); - return OkStatus(); + return absl::OkStatus(); } Status RandomShape(shape_inference::InferenceContext* c) { shape_inference::ShapeHandle out; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { @@ -2433,7 +2433,7 @@ Status SegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { out = c->UnknownShape(); } c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -2463,7 +2463,7 @@ Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -2507,7 +2507,7 @@ Status SliceShape(InferenceContext* c) { SliceHelper(c, begin_value, sizes_value, &dims)); } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } else { // In case `sizes` is not available (`sizes_value` is null), // we could try to use `MakeShapeFromShapeTensor` here. @@ -2529,18 +2529,18 @@ Status SliceShape(InferenceContext* c) { dims.emplace_back(c->Dim(sizes_value, i)); } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } // We might know the rank of the input. if (c->RankKnown(input)) { c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); - return OkStatus(); + return absl::OkStatus(); } else { return shape_inference::UnknownShape(c); } } - return OkStatus(); + return absl::OkStatus(); } Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, @@ -2581,7 +2581,7 @@ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, } } - return OkStatus(); + return absl::OkStatus(); } Status ValidateVariableResourceHandle( @@ -2601,7 +2601,7 @@ Status ValidateVariableResourceHandle( DataTypeString(value_dtype)); } } - return OkStatus(); + return absl::OkStatus(); } Status GatherNdShape(InferenceContext* c) { @@ -2620,7 +2620,7 @@ Status GatherNdShape(InferenceContext* c) { if (!c->RankKnown(params) || !c->ValueKnown(r_dim)) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } if (c->Value(r_dim) > c->Rank(params)) { @@ -2637,7 +2637,7 @@ Status GatherNdShape(InferenceContext* c) { ShapeHandle out; TF_RETURN_IF_ERROR(c->Concatenate(indices_slice, params_slice, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, @@ -2700,7 +2700,7 @@ Status ScatterNdShapeHelper(InferenceContext* c, ShapeHandle indices_shape, // This is called for tf.scatter_nd; output is a tensor with this shape. c->set_output(0, input_shape); } - return OkStatus(); + return absl::OkStatus(); } Status ExplicitShape(InferenceContext* c) { @@ -2709,7 +2709,7 @@ Status ExplicitShape(InferenceContext* c) { ShapeHandle output_shape; TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape)); c->set_output(0, output_shape); - return OkStatus(); + return absl::OkStatus(); } Status ExplicitShapes(InferenceContext* c) { @@ -2724,7 +2724,7 @@ Status ExplicitShapes(InferenceContext* c) { c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape)); c->set_output(i, output_shape); } - return OkStatus(); + return absl::OkStatus(); } Status SparseReduceShapeFn(InferenceContext* c) { @@ -2770,7 +2770,7 @@ Status SparseReduceShapeFn(InferenceContext* c) { } c->set_output(0, c->MakeShape(dims)); - return OkStatus(); + return absl::OkStatus(); } return UnknownShape(c); } @@ -2784,7 +2784,7 @@ Status QuantizedConv2DShape(InferenceContext* c) { TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); c->set_output(1, c->Scalar()); c->set_output(2, c->Scalar()); - return OkStatus(); + return absl::OkStatus(); } Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { @@ -2831,19 +2831,19 @@ Status FusedQuantizedConvShape(InferenceContext* c, int num_dims) { c->set_output(1, channel); c->set_output(2, channel); } - return OkStatus(); + return absl::OkStatus(); } Status FusedQuantizedConv2DShape(InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShapeImpl(c, true)); TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4)); - return OkStatus(); + return absl::OkStatus(); } Status FusedQuantizedDepthwiseConv2D(InferenceContext* c) { TF_RETURN_IF_ERROR(DepthwiseConv2DNativeShapeImpl(c, true)); TF_RETURN_IF_ERROR(FusedQuantizedConvShape(c, 4)); - return OkStatus(); + return absl::OkStatus(); } Status QuantizedAvgPoolShape(InferenceContext* c) { @@ -2853,7 +2853,7 @@ Status QuantizedAvgPoolShape(InferenceContext* c) { TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); c->set_output(1, c->Scalar()); c->set_output(2, c->Scalar()); - return OkStatus(); + return absl::OkStatus(); } Status QuantizeV2Shape(InferenceContext* c) { @@ -2879,7 +2879,7 @@ Status QuantizeV2Shape(InferenceContext* c) { } c->set_output(1, minmax); c->set_output(2, minmax); - return OkStatus(); + return absl::OkStatus(); } Status ReduceScatterShape(shape_inference::InferenceContext* c) { @@ -2887,7 +2887,7 @@ Status ReduceScatterShape(shape_inference::InferenceContext* c) { if (!c->RankKnown(in)) { // Input shape unknown, so set unknown output shape. c->set_output(0, in); - return OkStatus(); + return absl::OkStatus(); } shape_inference::ShapeHandle group_assignment_shape = c->input(1); @@ -2898,7 +2898,7 @@ Status ReduceScatterShape(shape_inference::InferenceContext* c) { const Tensor* scatter_dimension = c->input_tensor(2); if (!scatter_dimension) { c->set_output(0, c->UnknownShape()); - return OkStatus(); + return absl::OkStatus(); } int64_t scatter_dim; TF_RETURN_IF_ERROR(c->GetScalarFromTensor(scatter_dimension, &scatter_dim)); @@ -2919,7 +2919,7 @@ Status ReduceScatterShape(shape_inference::InferenceContext* c) { } } c->set_output(0, c->MakeShape(out_dims)); - return OkStatus(); + return absl::OkStatus(); } } // namespace shape_inference diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index ce65aa99d13706..f1d43d6c2abfd3 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -52,7 +52,7 @@ inline Status UnchangedShapeWithRank(shape_inference::InferenceContext* c, ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Transfers shape of input(0) to output(0), after asserting its rank >= . @@ -61,7 +61,7 @@ inline Status UnchangedShapeWithRankAtLeast( ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), rank, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Transfers shape of input(0) to output(0), after asserting its rank <= . @@ -70,18 +70,18 @@ inline Status UnchangedShapeWithRankAtMost(shape_inference::InferenceContext* c, ShapeHandle out; TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), rank, &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Shape function for use with ops no outputs. inline Status NoOutputs(shape_inference::InferenceContext* c) { - return OkStatus(); + return absl::OkStatus(); } // Shape function for ops that output a single scalar value. inline Status ScalarShape(shape_inference::InferenceContext* c) { c->set_output(0, c->Scalar()); - return OkStatus(); + return absl::OkStatus(); } // Shape function for binary ops where both inputs and the output match. @@ -89,7 +89,7 @@ inline Status MergeBothInputsShapeFn(InferenceContext* c) { ShapeHandle out; TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &out)); c->set_output(0, out); - return OkStatus(); + return absl::OkStatus(); } // Shape function for dataset iterators. @@ -240,7 +240,7 @@ inline Status BroadcastBinaryOpOutputShapeFn(InferenceContext* c, TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( c, c->input(0), c->input(1), true, &out)); c->set_output(output_index, out); - return OkStatus(); + return absl::OkStatus(); } // Shape function for binary operators that broadcast their inputs. diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 188e9813359e9a..4fd31ab201458e 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -211,7 +211,7 @@ static Status WrappedDatasetVariantDeviceCopy( const WrappedDatasetVariantWrapper& from, WrappedDatasetVariantWrapper* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { *to = WrappedDatasetVariantWrapper(from); - return OkStatus(); + return absl::OkStatus(); } #define REGISTER_OPTIONAL_COPY(DIRECTION) \ @@ -248,7 +248,7 @@ Status GraphDefBuilderWrapper::AddDataset( Status GraphDefBuilderWrapper::AddDataset( const DatasetBase* dataset, const std::vector>& inputs, - const std::vector>>& list_inputs, + const std::vector>>& list_inputs, const std::vector>& attrs, Node** output) { return AddDataset(dataset, inputs, list_inputs, attrs, @@ -258,7 +258,7 @@ Status GraphDefBuilderWrapper::AddDataset( Status GraphDefBuilderWrapper::AddDataset( const DatasetBase* dataset, const std::vector>& inputs, - const std::vector>>& list_inputs, + const std::vector>>& list_inputs, const std::vector>& attrs, bool use_dataset_name, Node** output) { auto& type_string = dataset->type_string(); @@ -320,7 +320,7 @@ Status GraphDefBuilderWrapper::AddDataset( return errors::Internal("AddDataset: Failed to build ", type_string, " op with error ", opts->StatusToString()); } - return OkStatus(); + return absl::OkStatus(); } Status GraphDefBuilderWrapper::AddFunction( @@ -329,7 +329,7 @@ Status GraphDefBuilderWrapper::AddFunction( if (b_->HasFunction(function_name)) { VLOG(1) << "Function with name " << function_name << "already exists in" << " the graph. It will not be added again."; - return OkStatus(); + return absl::OkStatus(); } const FunctionDef* f_def = lib_def.Find(function_name); if (f_def == nullptr) { @@ -363,7 +363,7 @@ Status GraphDefBuilderWrapper::AddFunction( for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) { TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second, lib_def)); } - return OkStatus(); + return absl::OkStatus(); } void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val, @@ -529,7 +529,7 @@ Status MemoryCheckpoint::Save(IteratorStateWriter* writer) const { auto [prefix, key] = id_registry_->Get(id); TF_RETURN_IF_ERROR(writer->WriteTensor(prefix, key, value)); } - return OkStatus(); + return absl::OkStatus(); } Status IteratorBase::InitializeBase(IteratorContext* ctx, @@ -551,7 +551,7 @@ Status IteratorBase::InitializeBase(IteratorContext* ctx, cleanup_fns_.push_back([this, model]() { model->RemoveNode(node_); }); } } - return OkStatus(); + return absl::OkStatus(); } Status GetCompressedElementFromVariantTensor( @@ -569,7 +569,7 @@ Status GetCompressedElementFromVariantTensor( "Tensor must be a `CompressedElement` object."); } *out_compressed_element = compressed_element; - return OkStatus(); + return absl::OkStatus(); } int64_t GetAllocatedBytes(const std::vector& element) { @@ -619,7 +619,7 @@ int64_t GetTotalBytes(const std::vector& element) { } std::string FullName(const std::string& prefix, const std::string& name) { - if (str_util::StrContains(name, kColon)) { + if (absl::StrContains(name, kColon)) { LOG(ERROR) << name << " should not contain " << kColon; } @@ -627,7 +627,7 @@ std::string FullName(const std::string& prefix, const std::string& name) { } Status ExtractIteratorPrefix(StringPiece key, string* prefix) { - if (!str_util::StartsWith(key, data::kFullNameRandomHex)) { + if (!absl::StartsWith(key, data::kFullNameRandomHex)) { return errors::InvalidArgument("Key: ", key, " was not generated using full_name."); } @@ -639,7 +639,7 @@ Status ExtractIteratorPrefix(StringPiece key, string* prefix) { string real_key = split_keys[1]; const int pos = real_key.rfind(kColon); *prefix = real_key.substr(0, pos); - return OkStatus(); + return absl::OkStatus(); } Status GetDatasetFromVariantTensor(const Tensor& tensor, @@ -658,7 +658,7 @@ Status GetDatasetFromVariantTensor(const Tensor& tensor, if (*out_dataset == nullptr) { return errors::Internal("Read uninitialized Dataset variant."); } - return OkStatus(); + return absl::OkStatus(); } Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) { @@ -668,7 +668,7 @@ Status StoreDatasetInVariantTensor(DatasetBase* dataset, Tensor* tensor) { "Dataset tensor must be a scalar of dtype DT_VARIANT."); } tensor->scalar()() = DatasetVariantWrapper(dataset); - return OkStatus(); + return absl::OkStatus(); } namespace internal { @@ -792,12 +792,12 @@ Status DatasetBase::ComputeNumSources() { } if (num_sources_ >= 0) { // Already computed. - return OkStatus(); + return absl::OkStatus(); } num_sources_ = 0; if (inputs.empty()) { num_sources_ = 1; - return OkStatus(); + return absl::OkStatus(); } for (const auto& input : inputs) { if (input->num_sources() < 0) { @@ -808,7 +808,7 @@ Status DatasetBase::ComputeNumSources() { } num_sources_ += input->num_sources(); } - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { @@ -826,7 +826,7 @@ Status DatasetBase::CheckRandomAccessCompatible(const int64 index) const { return errors::OutOfRange("Index out of range [0, ", cardinality, "):", index); } - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::Get(OpKernelContext* ctx, int64 index, @@ -859,7 +859,7 @@ Status DatasetBase::MergeOptionsFromInputs() { return s; } if (inputs.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Merge options from inputs sequentially before merging options from dataset. // Since the last options merged takes precedence, the options that may be set @@ -871,7 +871,7 @@ Status DatasetBase::MergeOptionsFromInputs() { } internal::MergeOptions(options_, &merged_options); options_ = merged_options; - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::MakeIterator( @@ -883,12 +883,12 @@ Status DatasetBase::MakeIterator( Status s = InputDatasets(&inputs); return inputs[0]->MakeIterator(ctx, parent, output_prefix, iterator); } - profiler::TraceMe traceme( + tsl::profiler::TraceMe traceme( [&] { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( strings::StrCat("MakeIterator::", type_string()), {}); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); *iterator = MakeIteratorInternal(output_prefix); Status s = (*iterator)->InitializeBase(ctx, parent); if (s.ok()) { @@ -995,7 +995,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddInputDataset( << " will not be optimized because the dataset does not implement " "the " "AsGraphDefInternal() method needed to apply optimizations."; - return OkStatus(); + return absl::OkStatus(); } } return status; @@ -1033,7 +1033,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddIdentity( *output = ops::UnaryOp("Identity", *input, builder()->opts().WithName(UniqueNodeName(name_prefix))); - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( @@ -1055,7 +1055,7 @@ Status DatasetBase::DatasetGraphDefBuilder::AddDatasetOrTensorHelper( opts.op_registry()); node_builder.Input(std::move(nodes)); *output = opts.FinalizeBuilder(&node_builder); - return OkStatus(); + return absl::OkStatus(); } Status DatasetBase::DatasetGraphDefBuilder::AddResourceHelper( @@ -1138,8 +1138,8 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx, "Iterator::GetNext", activity_watcher::ActivityCategory::kDatasetOp, std::move(attributes)); }); - profiler::TraceMe activity([&] { return BuildTraceMeName(); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMe activity([&] { return BuildTraceMeName(); }, + tsl::profiler::TraceMeLevel::kInfo); DVLOG(3) << prefix() << " GetNext enter"; auto model = ctx->model(); bool output_was_recording = @@ -1189,8 +1189,8 @@ Status DatasetBaseIterator::GetNext(IteratorContext* ctx, Status DatasetBaseIterator::Skip(IteratorContext* ctx, int num_to_skip, bool* end_of_sequence, int* num_skipped) { - profiler::TraceMe activity([&] { return BuildTraceMeName(); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMe activity([&] { return BuildTraceMeName(); }, + tsl::profiler::TraceMeLevel::kInfo); DVLOG(3) << prefix() << " Skip enter"; auto model = ctx->model(); bool output_was_recording = @@ -1232,7 +1232,7 @@ Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip, std::vector out_tensors; TF_RETURN_IF_ERROR(GetNextInternal(ctx, &out_tensors, end_of_sequence)); if (*end_of_sequence) { - return OkStatus(); + return absl::OkStatus(); } // RecordElement is used to count the number of element computed and // help calculate the CPU time spent on a given iterator to do the @@ -1244,7 +1244,7 @@ Status DatasetBaseIterator::SkipInternal(IteratorContext* ctx, int num_to_skip, RecordElement(ctx, &out_tensors); (*num_skipped)++; } - return OkStatus(); + return absl::OkStatus(); } void DatasetOpKernel::Compute(OpKernelContext* ctx) { @@ -1269,7 +1269,7 @@ void DatasetOpKernel::Compute(OpKernelContext* ctx) { string DatasetOpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const { - return profiler::TraceMeOp(name_view(), type_string_view()); + return tsl::profiler::TraceMeOp(name_view(), type_string_view()); } // static diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 9ebcd903961e8b..03470e6dd298f9 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -91,6 +91,14 @@ using TraceMeMetadata = std::vector>; // Maps the index of dataset elements to a globally shuffled index. See the // comment for IteratorContext::Params::index_mapper for more details. +// Notes: +// * `absl::OutOfRangeError` indicates the input index argument exceeds +// the cardinality of the dataset. +// * `absl::NotFoundError` indicates we should skip this element. +// This happens in the case we mix multiple datasets into one. For example, +// `dataset1.concatenate(dataset2)`. +// See go/tf-data-random-access-iterator and +// go/tf-data-random-access-iterator-for-concatenate for more info. using IndexMapperFn = std::function(size_t)>; constexpr char kTFDataFunction[] = "_tf_data_function"; @@ -223,7 +231,7 @@ class GraphDefBuilderWrapper { if (*output == nullptr) { return errors::Internal("AddScalar: Failed to build Const op."); } - return OkStatus(); + return absl::OkStatus(); } // Adds a Const node with vector value to the Graph. @@ -242,7 +250,7 @@ class GraphDefBuilderWrapper { if (*output == nullptr) { return errors::Internal("AddVector: Failed to build Const op."); } - return OkStatus(); + return absl::OkStatus(); } Status AddVector(const std::vector& val, Node** output) { @@ -255,7 +263,7 @@ class GraphDefBuilderWrapper { if (*output == nullptr) { return errors::Internal("AddVector: Failed to build Const op."); } - return OkStatus(); + return absl::OkStatus(); } // Adds a `Const` node for the given tensor value to the graph. @@ -268,7 +276,7 @@ class GraphDefBuilderWrapper { if (*output == nullptr) { return errors::Internal("AddTensor: Failed to build Const op."); } - return OkStatus(); + return absl::OkStatus(); } // Adds a `Placeholder` node for the given tensor value to the graph. @@ -282,7 +290,7 @@ class GraphDefBuilderWrapper { return errors::Internal( "AddPlaceholder: Failed to build Placeholder op."); } - return OkStatus(); + return absl::OkStatus(); } // Adds a node for the given dataset to the `Graph`. The value of @@ -311,13 +319,15 @@ class GraphDefBuilderWrapper { Status AddDataset( const DatasetBase* dataset, const std::vector>& inputs, - const std::vector>>& list_inputs, + const std::vector>>& + list_inputs, const std::vector>& attrs, Node** output); Status AddDataset( const DatasetBase* dataset, const std::vector>& inputs, - const std::vector>>& list_inputs, + const std::vector>>& + list_inputs, const std::vector>& attrs, bool use_dataset_name, Node** output); @@ -370,7 +380,7 @@ class GraphDefBuilderWrapper { TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name(), lib_def)); } } - return OkStatus(); + return absl::OkStatus(); } GraphDefBuilder* b_; @@ -493,7 +503,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { Status WriteScalar(StringPiece name, StringPiece key, int64_t val) override { auto id = id_registry_->Add(string(name), string(key)); int_values_[id] = val; - return OkStatus(); + return absl::OkStatus(); } Status WriteScalar(StringPiece key, const tstring& val) override { string prefix; @@ -504,7 +514,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { const tstring& val) override { auto id = id_registry_->Add(string(name), string(key)); str_values_[id] = val; - return OkStatus(); + return absl::OkStatus(); } Status WriteTensor(StringPiece key, const Tensor& val) override { string prefix; @@ -515,7 +525,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { const Tensor& val) override { auto id = id_registry_->Add(string(name), string(key)); tensor_values_[id] = val; - return OkStatus(); + return absl::OkStatus(); } // END implementation of `IteratorStateWriter` interface @@ -546,7 +556,7 @@ class MemoryCheckpoint final : public IteratorStateWriter { : is_root_(is_root), id_registry_(registry) {} void operator=(const MemoryCheckpoint&) = delete; - Status status_ = OkStatus(); + Status status_ = absl::OkStatus(); // Only set to true for the checkpoint in IteratorResource. // Root checkpoint does not track expired prefixes. const bool is_root_ = false; @@ -571,10 +581,10 @@ class SerializationContext { switch (params_.external_state_policy) { case ExternalStatePolicy::POLICY_WARN: LOG(WARNING) << s.ToString(); - return OkStatus(); + return absl::OkStatus(); case ExternalStatePolicy::POLICY_IGNORE: VLOG(2) << "Ignoring error status: " << s.ToString(); - return OkStatus(); + return absl::OkStatus(); case ExternalStatePolicy::POLICY_FAIL: return s; default: @@ -905,6 +915,10 @@ class IteratorContext { IndexMapperFn index_mapper() const { return params_.index_mapper; } + void set_restored_element_count(size_t element_count) { + params_.restored_element_count.emplace(element_count); + } + std::optional restored_element_count() const { return params_.restored_element_count; } @@ -1105,7 +1119,7 @@ class IteratorBase : public Checkpointable { // Performs initialization that needs to happen outside of a constructor to // properly propagate errors. - virtual Status Initialize(IteratorContext* ctx) { return OkStatus(); } + virtual Status Initialize(IteratorContext* ctx) { return absl::OkStatus(); } // Performs initialization of the base iterator. Status InitializeBase(IteratorContext* ctx, const IteratorBase* parent); @@ -1116,7 +1130,7 @@ class IteratorBase : public Checkpointable { TF_RETURN_IF_ERROR(SaveInternal(ctx, writer)); VLOG(1) << "Saved " << prefix() << " in " << (EnvTime::NowMicros() - start_us) << "us"; - return OkStatus(); + return absl::OkStatus(); } // Restores the state of this iterator. @@ -1126,7 +1140,7 @@ class IteratorBase : public Checkpointable { ctx->SaveCheckpoint(this); VLOG(1) << "Restored " << prefix() << " in " << (EnvTime::NowMicros() - start_us) << "us"; - return OkStatus(); + return absl::OkStatus(); } // Returns the total number of bytes buffered by the iterator across all nodes @@ -1146,7 +1160,7 @@ class IteratorBase : public Checkpointable { Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer, const std::unique_ptr& input) { if (ctx->symbolic_checkpoint()) { - return OkStatus(); + return absl::OkStatus(); } return input->Save(ctx, writer); } @@ -1314,7 +1328,7 @@ class DatasetBase : public core::RefCounted { TF_RETURN_IF_ERROR(it->Restore(&restore_ctx, reader)); ctx->MergeCheckpoint(restore_ctx.checkpoint()); *iterator = std::move(it); - return OkStatus(); + return absl::OkStatus(); } Status MakeIteratorFromCheckpoint( @@ -1687,7 +1701,7 @@ Status ParseScalarArgument(OpKernelContext* ctx, return errors::InvalidArgument(argument_name, " must be a scalar"); } *output = argument_t->scalar()(); - return OkStatus(); + return absl::OkStatus(); } template @@ -1704,7 +1718,7 @@ Status ParseVectorArgument(OpKernelContext* ctx, for (int i = 0; i < size; ++i) { output->push_back(argument_t->vec()(i)); } - return OkStatus(); + return absl::OkStatus(); } // Encapsulates the work required to plug a DatasetBase into the core TensorFlow diff --git a/tensorflow/core/framework/dataset_stateful_op_allowlist.h b/tensorflow/core/framework/dataset_stateful_op_allowlist.h index 5e8cdd4af32a19..b92acf5fb74972 100644 --- a/tensorflow/core/framework/dataset_stateful_op_allowlist.h +++ b/tensorflow/core/framework/dataset_stateful_op_allowlist.h @@ -27,12 +27,12 @@ class AllowlistedStatefulOpRegistry { public: Status Add(string op_name) { op_names_.insert(std::move(op_name)); - return OkStatus(); + return absl::OkStatus(); } Status Remove(string op_name) { op_names_.erase(op_name); - return OkStatus(); + return absl::OkStatus(); } bool Contains(const string& op_name) { return op_names_.count(op_name); } diff --git a/tensorflow/core/framework/dataset_test.cc b/tensorflow/core/framework/dataset_test.cc index a63255171a2df7..66213ea5721b13 100644 --- a/tensorflow/core/framework/dataset_test.cc +++ b/tensorflow/core/framework/dataset_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/framework/device.h b/tensorflow/core/framework/device.h index 6cdcd2efd90ab9..08231d55d3a160 100644 --- a/tensorflow/core/framework/device.h +++ b/tensorflow/core/framework/device.h @@ -142,7 +142,7 @@ class Device : public DeviceBase { // 'graph' supplies the partition of the graph assigned to this // device. virtual Status MaybeRewriteGraph(std::unique_ptr* /*graph*/) { - return OkStatus(); + return absl::OkStatus(); } // Sets `out_context` a new DeviceContext* for executing a graph, or nullptr @@ -153,7 +153,7 @@ class Device : public DeviceBase { // and should call Unref(). virtual Status TryGetDeviceContext(DeviceContext** out_context) { *out_context = nullptr; - return OkStatus(); + return absl::OkStatus(); } // Returns the op segment of this device. The caller can reuse op diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index c8fbf9e1635296..065707fde4b8c2 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -229,7 +229,7 @@ class DeviceBase { PerOpGpuDevice* /*device*/, DeviceContext* /*dc*/, Allocator* /*allocator*/) { - return OkStatus(); + return absl::OkStatus(); } // Unimplemented by default diff --git a/tensorflow/core/framework/device_factory.cc b/tensorflow/core/framework/device_factory.cc index 43ad12393ac9a3..e39d768a56c785 100644 --- a/tensorflow/core/framework/device_factory.cc +++ b/tensorflow/core/framework/device_factory.cc @@ -151,7 +151,7 @@ Status DeviceFactory::ListAllPhysicalDevices(std::vector* devices) { } } - return OkStatus(); + return absl::OkStatus(); } Status DeviceFactory::ListPluggablePhysicalDevices( @@ -163,7 +163,7 @@ Status DeviceFactory::ListPluggablePhysicalDevices( TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices)); } } - return OkStatus(); + return absl::OkStatus(); } Status DeviceFactory::GetAnyDeviceDetails( @@ -223,7 +223,7 @@ Status DeviceFactory::AddCpuDevices( return errors::NotFound("No CPU devices are available in this process"); } - return OkStatus(); + return absl::OkStatus(); } Status DeviceFactory::AddDevices( @@ -259,7 +259,7 @@ Status DeviceFactory::AddDevices( } } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr DeviceFactory::NewDevice(const string& type, diff --git a/tensorflow/core/framework/device_factory.h b/tensorflow/core/framework/device_factory.h index c238aebf475bd2..7957af3cbad869 100644 --- a/tensorflow/core/framework/device_factory.h +++ b/tensorflow/core/framework/device_factory.h @@ -85,7 +85,7 @@ class DeviceFactory { // into devices from ListPhysicalDevices. virtual Status GetDeviceDetails(int device_index, std::unordered_map* details) { - return OkStatus(); + return absl::OkStatus(); } // Most clients should call AddDevices() instead. diff --git a/tensorflow/core/framework/fake_input.cc b/tensorflow/core/framework/fake_input.cc index 415125c73b1bf3..bf7edef06ddae9 100644 --- a/tensorflow/core/framework/fake_input.cc +++ b/tensorflow/core/framework/fake_input.cc @@ -108,14 +108,14 @@ Status FakeInputImpl::AddInputToBuilder() { "': ", status.message()); } SourceList(dts); - return OkStatus(); + return absl::OkStatus(); } DataType dt; TF_RETURN_IF_ERROR(GetDataType(&dt)); builder_->Input(in_node_, 0, dt); } - return OkStatus(); + return absl::OkStatus(); } // static @@ -134,13 +134,13 @@ Status FakeInputImpl::GetN(int* n) const { arg_->name(), "': ", status.message()); } } - return OkStatus(); + return absl::OkStatus(); } Status FakeInputImpl::GetDataType(DataType* dt) const { if (dt_specified_) { *dt = dt_; - return OkStatus(); // Ignore is_ref field of arg_. + return absl::OkStatus(); // Ignore is_ref field of arg_. } else if (arg_->type() != DT_INVALID) { *dt = arg_->type(); } else if (!arg_->type_attr().empty()) { @@ -162,7 +162,7 @@ Status FakeInputImpl::GetDataType(DataType* dt) const { if (arg_->is_ref()) { *dt = MakeRefType(*dt); } - return OkStatus(); + return absl::OkStatus(); } void FakeInputImpl::NSources(int n, DataType dt) const { @@ -171,7 +171,7 @@ void FakeInputImpl::NSources(int n, DataType dt) const { for (int i = 0; i < n; ++i) { srcs.emplace_back(in_node_, i, dt); } - builder_->Input(gtl::ArraySlice(srcs)); + builder_->Input(absl::Span(srcs)); } void FakeInputImpl::SourceList(DataTypeSlice dts) const { @@ -180,7 +180,7 @@ void FakeInputImpl::SourceList(DataTypeSlice dts) const { for (size_t i = 0; i < dts.size(); ++i) { srcs.emplace_back(in_node_, i, dts[i]); } - builder_->Input(gtl::ArraySlice(srcs)); + builder_->Input(absl::Span(srcs)); } } // namespace diff --git a/tensorflow/core/framework/full_type_util.cc b/tensorflow/core/framework/full_type_util.cc index fcc6446b67ac4b..b76b1d52274095 100644 --- a/tensorflow/core/framework/full_type_util.cc +++ b/tensorflow/core/framework/full_type_util.cc @@ -41,7 +41,7 @@ OpTypeConstructor NoOp() { OpTypeConstructor NoOutputs() { return [](OpDef* op_def) { op_def->mutable_output_arg(); - return OkStatus(); + return absl::OkStatus(); }; } @@ -50,7 +50,7 @@ OpTypeConstructor Nullary(FullTypeId t) { FullTypeDef* tdef = op_def->mutable_output_arg(0)->mutable_experimental_full_type(); tdef->set_type_id(t); - return OkStatus(); + return absl::OkStatus(); }; } @@ -64,7 +64,7 @@ OpTypeConstructor Unary(FullTypeId t, const string& var_name) { arg->set_type_id(TFT_VAR); arg->set_s(var_name); - return OkStatus(); + return absl::OkStatus(); }; } @@ -77,7 +77,7 @@ OpTypeConstructor UnaryGeneric(FullTypeId t) { FullTypeDef* arg = tdef->add_args(); arg->set_type_id(TFT_ANY); - return OkStatus(); + return absl::OkStatus(); }; } @@ -92,7 +92,7 @@ OpTypeConstructor UnaryTensorContainer(FullTypeId t, FullTypeId dtype) { FullTypeDef* targ = arg->add_args(); targ->set_type_id(dtype); - return OkStatus(); + return absl::OkStatus(); }; } @@ -108,7 +108,7 @@ OpTypeConstructor UnaryTensorContainer(FullTypeId t, const string& var_name) { varg->set_type_id(TFT_VAR); varg->set_s(var_name); - return OkStatus(); + return absl::OkStatus(); }; } @@ -133,7 +133,7 @@ OpTypeConstructor VariadicTensorContainer(FullTypeId t, tvar->set_type_id(TFT_VAR); tvar->set_s(var_name); - return OkStatus(); + return absl::OkStatus(); }; } @@ -176,7 +176,7 @@ Status SubstituteVar(AttrMap& attrs, FullTypeDef& t) { attr->DebugString(), " for name ", var_name)); } t.clear_s(); - return OkStatus(); + return absl::OkStatus(); } Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { @@ -238,7 +238,7 @@ Status SubstituteForEach(AttrMap& attrs, FullTypeDef& t) { attr->DebugString(), "\nfor name ", var_name)); } t = result; - return OkStatus(); + return absl::OkStatus(); } Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) { @@ -257,7 +257,7 @@ Status SubstituteGeneric(AttrMap& attrs, FullTypeDef& t) { break; } } - return OkStatus(); + return absl::OkStatus(); } inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) { @@ -281,7 +281,7 @@ inline Status SubstituteFromAttrs(AttrMap& attrs, FullTypeDef& t) { default: return SubstituteGeneric(attrs, t); } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -312,7 +312,7 @@ Status SpecializeType(const AttrSlice& attrs, const OpDef& op_def, t.DebugString(), "\nfrom\n", attrs.SummarizeNode()); } - return OkStatus(); + return absl::OkStatus(); } const FullTypeDef& GetArgDefaultUnset(const FullTypeDef& t, int i) { diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 0b6bacd94af0d9..61cfee4198de94 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -90,7 +90,7 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, for (int i = 0; i < v->list().type_size(); ++i) { dtypes->push_back(v->list().type(i)); } - return OkStatus(); + return absl::OkStatus(); } *is_type_list = false; @@ -116,7 +116,7 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, dtype = v->type(); } dtypes->resize(num, dtype); - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -166,7 +166,7 @@ Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) { } #endif - return OkStatus(); + return absl::OkStatus(); } // A helper class for instantiating functions. This contains shared information @@ -229,7 +229,7 @@ class FunctionInstantiationHelper { result_.arg_types.push_back(dtypes[i]); ++arg_index; } - return OkStatus(); + return absl::OkStatus(); } Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, @@ -259,7 +259,7 @@ class FunctionInstantiationHelper { } start += dtypes.size(); } - return OkStatus(); + return absl::OkStatus(); } Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { @@ -363,7 +363,7 @@ class FunctionInstantiationHelper { *gnode->mutable_experimental_type() = fnode.experimental_type(); } - return OkStatus(); + return absl::OkStatus(); } Status AddReturnNode( @@ -406,7 +406,7 @@ class FunctionInstantiationHelper { AddAttr("index", (*ret_index)++, gnode); result_.ret_types.push_back(dtypes[i]); } - return OkStatus(); + return absl::OkStatus(); } // Adds the actual node inputs to the result graph by converting indexes to @@ -452,7 +452,7 @@ class FunctionInstantiationHelper { " name: "), name); } - return OkStatus(); + return absl::OkStatus(); } const NameInfoItem* GetItemOrNull(const string& name) const { @@ -644,7 +644,7 @@ string Print(const FunctionDef& fdef) { return out; } -string Print(gtl::ArraySlice nodes) { +string Print(absl::Span nodes) { std::vector arg; std::vector ret; std::vector body; @@ -738,7 +738,7 @@ Status AddDefaultAttrs(const string& op, } } } - return OkStatus(); + return absl::OkStatus(); } } // end namespace @@ -857,7 +857,7 @@ Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, // Adds the actual node inputs using the input indexes. helper.AddNodeInputs(); - return OkStatus(); + return absl::OkStatus(); } string DebugString(const FunctionDef& func_def) { return Print(func_def); } @@ -870,7 +870,7 @@ string DebugString(const GraphDef& instantiated_func_def) { return Print(ptrs); } -string DebugString(gtl::ArraySlice instantiated_func_nodes) { +string DebugString(absl::Span instantiated_func_nodes) { std::vector ptrs; for (const NodeDef& n : instantiated_func_nodes) { ptrs.push_back(&n); @@ -1147,7 +1147,7 @@ FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, FunctionCallFrame::~FunctionCallFrame() {} -Status FunctionCallFrame::SetArgs(gtl::ArraySlice args) { +Status FunctionCallFrame::SetArgs(absl::Span args) { // Input type checks. if (args.size() != arg_types_.size()) { return errors::InvalidArgument("Expects ", arg_types_.size(), @@ -1162,7 +1162,7 @@ Status FunctionCallFrame::SetArgs(gtl::ArraySlice args) { } args_[i] = args[i]; } - return OkStatus(); + return absl::OkStatus(); } Status FunctionCallFrame::GetRetvals(std::vector* rets) const { @@ -1176,7 +1176,7 @@ Status FunctionCallFrame::GetRetvals(std::vector* rets) const { return errors::Internal("Retval[", i, "] does not have value"); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionCallFrame::ConsumeRetvals(std::vector* rets, @@ -1192,7 +1192,7 @@ Status FunctionCallFrame::ConsumeRetvals(std::vector* rets, return errors::Internal("Retval[", i, "] does not have value"); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionCallFrame::GetArg(int index, const Tensor** val) { @@ -1201,7 +1201,7 @@ Status FunctionCallFrame::GetArg(int index, const Tensor** val) { args_.size(), ")"); } *val = &args_[index]; - return OkStatus(); + return absl::OkStatus(); } Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { @@ -1221,7 +1221,7 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { } else { return errors::Internal("Retval[", index, "] has already been set."); } - return OkStatus(); + return absl::OkStatus(); } FunctionRecord::FunctionRecord(const FunctionDef& fdef, @@ -1446,7 +1446,7 @@ Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, "exists."); } // Ignore duplicate FunctionDefs. - return OkStatus(); + return absl::OkStatus(); } const OpDef* op_def; if (default_registry_ @@ -1460,7 +1460,7 @@ Status FunctionLibraryDefinition::AddHelper(FunctionRecord* registration, registration->finalize(); records_.insert({registration->fdef().signature().name(), registration}); *added = true; - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::CopyFunctionDefFrom( @@ -1485,7 +1485,7 @@ Status FunctionLibraryDefinition::CopyFunctionDefFrom( "' because a different function with the same name already " "exists."); } else { - return OkStatus(); + return absl::OkStatus(); } } else if (other_record->finalized()) { bool added; @@ -1514,11 +1514,11 @@ Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, "'", *entry, "'"); } // Ignore duplicate GradientDefs - return OkStatus(); + return absl::OkStatus(); } *entry = grad.gradient_func(); *added = true; - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::AddLibrary( @@ -1567,7 +1567,7 @@ Status FunctionLibraryDefinition::AddLibrary( funcs_with_grads.push_back(grad.function_name()); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::AddLibrary( @@ -1625,7 +1625,7 @@ Status FunctionLibraryDefinition::AddLibrary( funcs_with_grads.push_back(grad.function_name()); } } - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::ReplaceFunction( @@ -1636,7 +1636,7 @@ Status FunctionLibraryDefinition::ReplaceFunction( TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); TF_RETURN_IF_ERROR(AddFunctionDefHelper( FunctionDef(fdef), StackTracesMap(stack_traces), &added)); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { @@ -1644,13 +1644,13 @@ Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { bool added; TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name())); TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added)); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::RemoveFunction(const string& func) { mutex_lock l(mu_); TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { @@ -1661,7 +1661,7 @@ Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { } iter->second->Unref(); records_.erase(iter); - return OkStatus(); + return absl::OkStatus(); } void FunctionLibraryDefinition::Clear() { @@ -1681,7 +1681,7 @@ Status FunctionLibraryDefinition::RemoveGradient(const string& func) { func, "'."); } func_grad_.erase(i); - return OkStatus(); + return absl::OkStatus(); } Status FunctionLibraryDefinition::Remove( @@ -1700,7 +1700,7 @@ Status FunctionLibraryDefinition::Remove( return s; } } - return OkStatus(); + return absl::OkStatus(); } string FunctionLibraryDefinition::FindGradient(const string& func) const { @@ -1718,7 +1718,7 @@ Status FunctionLibraryDefinition::LookUp( auto iter = records_.find(op); if (iter != records_.end()) { *op_reg_data = &iter->second->op_registration_data(); - return OkStatus(); + return absl::OkStatus(); } return default_registry_->LookUp(op, op_reg_data); } @@ -1796,7 +1796,7 @@ Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, const string& attr, T* value) const { const FunctionDef* fdef = GetAttrImpl(ndef); if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) { - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument("Attr ", attr, " is not defined."); } @@ -1837,7 +1837,7 @@ std::set ReachableFunctions(const FunctionLibraryDefinition& flib, // Functions might be reachable from the nested function calls, so we keep a // queue of functions that we have to check. - gtl::InlinedVector, 4> func_queue; + absl::InlinedVector, 4> func_queue; // Add reachable and not already processed functions to the functions queue. const auto add_to_func_queue = [&](const string& func_name) { @@ -2043,7 +2043,7 @@ void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( const string& name, - gtl::ArraySlice> attrs) { + absl::Span> attrs) { AttrValueWrapper ret; ret.proto.mutable_func()->set_name(name); for (const auto& a : attrs) { @@ -2081,11 +2081,11 @@ NodeDef FunctionDefHelper::Node::ToNodeDef() const { /* static */ FunctionDef FunctionDefHelper::Create( - const string& function_name, gtl::ArraySlice in_def, - gtl::ArraySlice out_def, gtl::ArraySlice attr_def, - gtl::ArraySlice node_def, - gtl::ArraySlice> ret_def, - gtl::ArraySlice> control_ret_def) { + const string& function_name, absl::Span in_def, + absl::Span out_def, absl::Span attr_def, + absl::Span node_def, + absl::Span> ret_def, + absl::Span> control_ret_def) { FunctionDef fdef; // Signature @@ -2131,20 +2131,20 @@ FunctionDef FunctionDefHelper::Create( /* static */ FunctionDef FunctionDefHelper::Create( - const string& function_name, gtl::ArraySlice in_def, - gtl::ArraySlice out_def, gtl::ArraySlice attr_def, - gtl::ArraySlice node_def, - gtl::ArraySlice> ret_def) { + const string& function_name, absl::Span in_def, + absl::Span out_def, absl::Span attr_def, + absl::Span node_def, + absl::Span> ret_def) { return Create(function_name, in_def, out_def, attr_def, node_def, ret_def, /*control_ret_def=*/{}); } /* static */ FunctionDef FunctionDefHelper::Define(const string& name, - gtl::ArraySlice arg_def, - gtl::ArraySlice ret_def, - gtl::ArraySlice attr_def, - gtl::ArraySlice node_def) { + absl::Span arg_def, + absl::Span ret_def, + absl::Span attr_def, + absl::Span node_def) { FunctionDef fdef; OpDefBuilder b(name); for (const auto& a : arg_def) b.Input(a); @@ -2209,10 +2209,10 @@ FunctionDef FunctionDefHelper::Define(const string& name, return fdef; } -FunctionDef FunctionDefHelper::Define(gtl::ArraySlice arg_def, - gtl::ArraySlice ret_def, - gtl::ArraySlice attr_def, - gtl::ArraySlice node_def) { +FunctionDef FunctionDefHelper::Define(absl::Span arg_def, + absl::Span ret_def, + absl::Span attr_def, + absl::Span node_def) { return Define("_", arg_def, ret_def, attr_def, node_def); } @@ -2238,7 +2238,7 @@ Status GetOpGradientCreator(const string& op, Creator* creator) { return errors::NotFound("No gradient defined for op: ", op); } *creator = iter->second; - return OkStatus(); + return absl::OkStatus(); } } // end namespace gradient diff --git a/tensorflow/core/framework/function_handle_cache.cc b/tensorflow/core/framework/function_handle_cache.cc index 446f8cefdc81ed..add92c44aff5bc 100644 --- a/tensorflow/core/framework/function_handle_cache.cc +++ b/tensorflow/core/framework/function_handle_cache.cc @@ -51,7 +51,7 @@ Status FunctionHandleCache::Instantiate( } else { *handle = h; } - return OkStatus(); + return absl::OkStatus(); } Status FunctionHandleCache::Clear() { @@ -60,7 +60,7 @@ Status FunctionHandleCache::Clear() { TF_RETURN_IF_ERROR(lib_->ReleaseHandle(entry.second)); } handles_.clear(); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/graph_def_util.cc b/tensorflow/core/framework/graph_def_util.cc index 418de4290a5c2d..8b9a8615bc6113 100644 --- a/tensorflow/core/framework/graph_def_util.cc +++ b/tensorflow/core/framework/graph_def_util.cc @@ -49,7 +49,7 @@ Status ValidateExternalGraphDefSyntax(const GraphDef& graph_def) { for (const NodeDef& node : graph_def.node()) { TF_RETURN_IF_ERROR(ValidateExternalNodeDefSyntax(node)); } - return OkStatus(); + return absl::OkStatus(); } Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, @@ -79,7 +79,7 @@ Status AddDefaultAttrsToGraphDef(GraphDef* graph_def, } } - return OkStatus(); + return absl::OkStatus(); } static Status RemoveNewDefaultAttrsFromNodeDef( @@ -124,7 +124,7 @@ static Status RemoveNewDefaultAttrsFromNodeDef( } } - return OkStatus(); + return absl::OkStatus(); } static bool IsFunction(const GraphDef& graph_def, const string& op_name) { @@ -161,7 +161,7 @@ Status RemoveNewDefaultAttrsFromGraphDef( } } - return OkStatus(); + return absl::OkStatus(); } void StripDefaultAttributes(const OpRegistryInterface& op_registry, @@ -261,7 +261,7 @@ Status StrippedOpListForGraph(const GraphDef& graph_def, stripped_op->CopyFrom(*op_def); RemoveDescriptionsFromOpDef(stripped_op); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc index 384d9cba6865a2..fcd48e3fc5e047 100644 --- a/tensorflow/core/framework/graph_to_functiondef.cc +++ b/tensorflow/core/framework/graph_to_functiondef.cc @@ -145,7 +145,7 @@ Status NodeNameMapping::UseOutputName(const string& name) { "' appears more than once in 'output_names' array."); } used_names_.emplace(name, 0); - return OkStatus(); + return absl::OkStatus(); } string NodeNameMapping::Lookup(const string& name) const { @@ -318,7 +318,7 @@ Status FillFunctionBody( func_attr_names.insert(func_attr_name); } } - return OkStatus(); + return absl::OkStatus(); } Status GraphToFunctionDefHelper( @@ -536,7 +536,7 @@ Status GraphToFunctionDefHelper( fdef->mutable_signature()->add_control_output(control_output); } - return OkStatus(); + return absl::OkStatus(); } Status GraphToFunctionDefHelper( @@ -560,7 +560,7 @@ Status GraphToFunctionDefHelper( (*args_or_retvals)[index].node->DebugString(), "\nNow we have:\n", node->DebugString()); } - return OkStatus(); + return absl::OkStatus(); }; std::vector body_nodes; @@ -599,7 +599,7 @@ Status GraphToFunctionDefHelper( "' node at index ", i); } } - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR(validate_args_retvals(inputs, "_Arg")); @@ -631,7 +631,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, copy_placeholder_attrs_from_nodes, body_nodes, inputs, outputs, output_names, control_outputs, control_output_names, description, /*allow_destructive_reads=*/false, fdef); - return OkStatus(); + return absl::OkStatus(); } Status GraphToFunctionDef( diff --git a/tensorflow/core/framework/graph_to_functiondef_test.cc b/tensorflow/core/framework/graph_to_functiondef_test.cc index e6c30171910402..f29295274dfbe2 100644 --- a/tensorflow/core/framework/graph_to_functiondef_test.cc +++ b/tensorflow/core/framework/graph_to_functiondef_test.cc @@ -229,7 +229,7 @@ TEST(GraphToFunctionDefTest, ArgAttrConstInput) { args_or_retvals->resize(index + 1); } (*args_or_retvals)[index].node = node; - return OkStatus(); + return absl::OkStatus(); }; for (Node* node : root.graph()->op_nodes()) { // Set const as the input node. diff --git a/tensorflow/core/framework/kernel_def_util.cc b/tensorflow/core/framework/kernel_def_util.cc index 69738eea671f52..d1f556bdaa9288 100644 --- a/tensorflow/core/framework/kernel_def_util.cc +++ b/tensorflow/core/framework/kernel_def_util.cc @@ -117,7 +117,7 @@ Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, if (attr_value->type() != DT_INVALID) { if (!InTypeList(attr_value->type(), constraint.allowed_values())) { - return OkStatus(); + return absl::OkStatus(); } } else { if (!AttrValueHasType(*attr_value, "list(type)").ok()) { @@ -133,13 +133,13 @@ Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs, for (int t : attr_value->list().type()) { if (!InTypeList(static_cast(t), constraint.allowed_values())) { - return OkStatus(); + return absl::OkStatus(); } } } } *match = true; - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/kernel_shape_util.cc b/tensorflow/core/framework/kernel_shape_util.cc index 071821ce4a56d6..f06a366f435e5f 100644 --- a/tensorflow/core/framework/kernel_shape_util.cc +++ b/tensorflow/core/framework/kernel_shape_util.cc @@ -63,7 +63,7 @@ Status GetWindowedOutputSizeVerbose(int64_t input_size, int64_t filter_size, ", effective_filter_size: ", effective_filter_size, ", stride: ", stride, "]"); } - return OkStatus(); + return absl::OkStatus(); } Status GetWindowedOutputSize(int64_t input_size, int64_t filter_size, @@ -93,6 +93,6 @@ Status Get3dOutputSizeV2(const std::array& input, input[i], window[i], dilations[i], strides[i], padding_type, &(*output_ptr)[i], &(*padding_ptr)[i])); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc index f1c3a4b3935605..d428f6d463ea51 100644 --- a/tensorflow/core/framework/load_library.cc +++ b/tensorflow/core/framework/load_library.cc @@ -66,7 +66,7 @@ Status LoadDynamicLibrary(const char* library_filename, void** result, if (seen_op_names.find(opdef.name()) == seen_op_names.end()) { // Over writing a registration of an op not in this custom op // library. Treat this as not an error. - return OkStatus(); + return absl::OkStatus(); } } if (s.ok()) { @@ -98,7 +98,7 @@ Status LoadDynamicLibrary(const char* library_filename, void** result, *len = str.length(); *result = library.handle; - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/local_rendezvous.cc b/tensorflow/core/framework/local_rendezvous.cc index 488e9251d8e913..910c8a92a744fb 100644 --- a/tensorflow/core/framework/local_rendezvous.cc +++ b/tensorflow/core/framework/local_rendezvous.cc @@ -191,7 +191,7 @@ Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, queue->push_back(new Item(std::move(rc_owner), send_args, val, is_dead, std::move(activity_scope))); bucket.mu.unlock(); - return OkStatus(); + return absl::OkStatus(); } DVLOG(2) << "Consume Recv Item (key:" << key.FullKey() << "). "; @@ -210,7 +210,8 @@ Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, bucket.mu.unlock(); DCHECK_EQ(item->type, Item::kRecv); - (*item->recv_state.waiter)(OkStatus(), send_args, item->args, val, is_dead); + (*item->recv_state.waiter)(absl::OkStatus(), send_args, item->args, val, + is_dead); { mutex_lock l(bucket.mu); bucket.pending_callback_counter--; @@ -220,7 +221,7 @@ Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key, } // Delete the item at last since it may unref and destruct the rendezvous. delete item; - return OkStatus(); + return absl::OkStatus(); } void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, @@ -367,7 +368,7 @@ void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key, bucket.mu.unlock(); DCHECK_EQ(item->type, Item::kSend); - done(OkStatus(), item->args, recv_args, *item->send_state.value, + done(absl::OkStatus(), item->args, recv_args, *item->send_state.value, item->send_state.is_dead); { mutex_lock l(bucket.mu); diff --git a/tensorflow/core/framework/lookup_interface.cc b/tensorflow/core/framework/lookup_interface.cc index b868faf03ef426..2dc224c3f5b6ea 100644 --- a/tensorflow/core/framework/lookup_interface.cc +++ b/tensorflow/core/framework/lookup_interface.cc @@ -27,7 +27,7 @@ Status LookupInterface::CheckKeyShape(const TensorShape& shape) { " must end with the table's key shape ", key_shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, @@ -40,7 +40,7 @@ Status LookupInterface::CheckKeyAndValueTypes(const Tensor& keys, return errors::InvalidArgument("Value must be type ", value_dtype(), " but got ", values.dtype()); } - return OkStatus(); + return absl::OkStatus(); } Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys, @@ -58,7 +58,7 @@ Status LookupInterface::CheckKeyAndValueTensorsHelper(const Tensor& keys, "Expected shape ", expected_value_shape.DebugString(), " for value, got ", values.shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } Status LookupInterface::CheckKeyAndValueTensorsForInsert(const Tensor& keys, @@ -95,7 +95,7 @@ Status LookupInterface::CheckFindArguments(const Tensor& key, fullsize_value_shape.DebugString(), " for default value, got ", default_value.shape().DebugString()); } - return OkStatus(); + return absl::OkStatus(); } } // namespace lookup diff --git a/tensorflow/core/framework/memory_types.cc b/tensorflow/core/framework/memory_types.cc index 52a6afd2632845..b983cf95d8ca4a 100644 --- a/tensorflow/core/framework/memory_types.cc +++ b/tensorflow/core/framework/memory_types.cc @@ -193,7 +193,7 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry, } } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index a34e274c48228f..1fc6622bebe170 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -375,7 +375,7 @@ Status ModelToProtoHelper(std::shared_ptr output, ModelProto* model) { to_serialize.push_back(input); } } - return OkStatus(); + return absl::OkStatus(); } // Recursively produces node tree rooted in `output` from the given model proto. @@ -398,7 +398,7 @@ Status ModelFromProtoHelper(ModelProto model, std::shared_ptr* output) { to_restore_inputs.push_back(input); } } - return OkStatus(); + return absl::OkStatus(); } // The first input of InterleaveMany corresponds to the input dataset whose @@ -555,7 +555,7 @@ class InterleaveMany : public Node { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::INTERLEAVE_MANY); - return OkStatus(); + return absl::OkStatus(); } }; @@ -778,7 +778,7 @@ class AsyncInterleaveMany : public Node { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_INTERLEAVE_MANY); - return OkStatus(); + return absl::OkStatus(); } }; @@ -871,7 +871,7 @@ class KnownRatio : public Node { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::KNOWN_RATIO); node_proto->set_ratio(ratio_); - return OkStatus(); + return absl::OkStatus(); } private: @@ -1250,7 +1250,7 @@ class UnknownRatio : public Node { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::UNKNOWN_RATIO); - return OkStatus(); + return absl::OkStatus(); } }; @@ -1304,7 +1304,7 @@ class Unknown : public Node { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::UNKNOWN); - return OkStatus(); + return absl::OkStatus(); } }; @@ -1347,7 +1347,7 @@ class AsyncKnownRatio : public AsyncRatio { parameter->set_value(parameter->state_value()); parameter->set_tunable(true); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -1390,7 +1390,7 @@ class AsyncUnknownRatio : public AsyncRatio { Status ToProto(ModelProto::Node* node_proto) const override { TF_RETURN_IF_ERROR(Node::ToProto(node_proto)); node_proto->set_node_class(NodeClass::ASYNC_UNKNOWN_RATIO); - return OkStatus(); + return absl::OkStatus(); } }; @@ -2168,7 +2168,7 @@ Status Node::ToProto(ModelProto::Node* node_proto) const { for (auto const& input : inputs_) { node_proto->add_inputs(input->id()); } - return OkStatus(); + return absl::OkStatus(); } Status Node::FromProtoHelper(ModelProto::Node node_proto, @@ -2218,7 +2218,7 @@ Status Node::FromProtoHelper(ModelProto::Node node_proto, mutex_lock l(node->mu_); node->UpdateProcessingTimeEma(); } - return OkStatus(); + return absl::OkStatus(); } Status Node::FromProto(ModelProto::Node node_proto, @@ -2567,7 +2567,7 @@ Status Model::OptimizeLoop(AutotuneAlgorithm algorithm, current_time_ms = EnvTime::NowMicros() / EnvTime::kMillisToMicros; } if (cancellation_manager->IsCancelled()) { - return OkStatus(); + return absl::OkStatus(); } } @@ -3194,7 +3194,7 @@ Status Model::ToProto(ModelProto* model_proto) { tf_shared_lock gap_lock(gap_mu_); *model_proto->mutable_gap_times() = {gap_times_usec_.begin(), gap_times_usec_.end()}; - return OkStatus(); + return absl::OkStatus(); } Status Model::FromProto(ModelProto model_proto, std::unique_ptr* model) { @@ -3204,7 +3204,7 @@ Status Model::FromProto(ModelProto model_proto, std::unique_ptr* model) { ModelFromProtoHelper(model_proto, &restored_model->output_)); restored_model->id_counter_ = model_proto.id_counter(); *model = std::move(restored_model); - return OkStatus(); + return absl::OkStatus(); } Status Model::Save(const string& fname, std::shared_ptr snapshot, @@ -3232,7 +3232,7 @@ Status Model::Load(const string& fname, std::unique_ptr* model, const OptimizationParams restored_optimization_params = model_proto.optimization_params(); *optimization_params = restored_optimization_params; - return OkStatus(); + return absl::OkStatus(); } std::string Model::DebugString() { diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc index fcf73e6970bb5c..86365b494217bd 100644 --- a/tensorflow/core/framework/node_def_builder.cc +++ b/tensorflow/core/framework/node_def_builder.cc @@ -106,7 +106,7 @@ NodeDefBuilder& NodeDefBuilder::Input(const NodeOut& src) { } // For inputs that take a list of tensors. -NodeDefBuilder& NodeDefBuilder::Input(gtl::ArraySlice src_list) { +NodeDefBuilder& NodeDefBuilder::Input(absl::Span src_list) { const OpDef::ArgDef* arg = NextArgDef(); if (arg != nullptr) ListInput(arg, src_list); return *this; @@ -134,7 +134,7 @@ void NodeDefBuilder::SingleInput(const OpDef::ArgDef* input_arg, } void NodeDefBuilder::ListInput(const OpDef::ArgDef* input_arg, - gtl::ArraySlice src_list) { + absl::Span src_list) { for (const auto& node_out : src_list) { AddInput(node_out.node, node_out.index); } @@ -262,7 +262,7 @@ Status NodeDefBuilder::Finalize(NodeDef* node_def, bool consume) { // Add default values for unspecified attrs. AddDefaultsToNodeDef(*op_def_, node_def); - return OkStatus(); + return absl::OkStatus(); } } @@ -311,21 +311,21 @@ ATTR(const PartialTensorShape&) ATTR(const Tensor&) ATTR(const TensorProto&) ATTR(const NameAttrList&) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) ATTR(const std::vector&) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) -ATTR(gtl::ArraySlice) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) +ATTR(absl::Span) #undef ATTR } // namespace tensorflow diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index d3af99893e7897..183a80ac18b1f5 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -185,7 +185,7 @@ const AttrValue* AttrSlice::FindByString(const string& attr_name) const { Status AttrSlice::CheckFind(StringPiece attr_name, const AttrValue* attr_value) const { if (attr_value != nullptr) { - return OkStatus(); + return absl::OkStatus(); } Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:"); // Skip AttachDef for internal attrs since it is a little bit @@ -402,7 +402,7 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, for (const auto& v : attr_value->list().type()) { value->push_back(static_cast(v)); } - return OkStatus(); + return absl::OkStatus(); } Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, @@ -411,7 +411,7 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor")); *value = &attr_value->tensor(); - return OkStatus(); + return absl::OkStatus(); } bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, @@ -434,7 +434,7 @@ Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func")); *value = &attr_value->func(); - return OkStatus(); + return absl::OkStatus(); } bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, @@ -523,7 +523,7 @@ Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs, (*sig)[i] = MakeRefType((*sig)[i]); } } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -537,7 +537,7 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, if (input_types_size > input_port) { const DataType dtype = input_types[input_port]; *input_type = dtype; - return OkStatus(); + return absl::OkStatus(); } } return errors::InvalidArgument("Input ", input_port, " not found for node ", @@ -549,7 +549,7 @@ Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, for (const auto& arg : op_def.input_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); } - return OkStatus(); + return absl::OkStatus(); } Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, @@ -561,7 +561,7 @@ Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, if (output_types_size > output_port) { const DataType dtype = output_types[output_port]; *output_type = dtype; - return OkStatus(); + return absl::OkStatus(); } } return errors::InvalidArgument("Output ", output_port, " not found for node ", @@ -573,7 +573,7 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, for (const auto& arg : op_def.output_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs)); } - return OkStatus(); + return absl::OkStatus(); } Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, @@ -581,7 +581,7 @@ Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def, for (const auto& arg : op_def.output_arg()) { TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs)); } - return OkStatus(); + return absl::OkStatus(); } Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, @@ -595,7 +595,7 @@ Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, DataTypeVector outputs; TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs)); *num_outputs = outputs.size(); - return OkStatus(); + return absl::OkStatus(); } int OpPortIdToArgId(const NodeDef& node, @@ -718,7 +718,7 @@ Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def)); } - return OkStatus(); + return absl::OkStatus(); } namespace { // Helpers for NameRangesForNode() @@ -739,7 +739,7 @@ Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def, "Argument '", arg_def.name(), "' incorrectly specified in op definition: ", SummarizeOpDef(op_def)); } - return OkStatus(); + return absl::OkStatus(); } Status NameRangesHelper(const AttrSlice& attrs, @@ -752,7 +752,7 @@ Status NameRangesHelper(const AttrSlice& attrs, (*result)[arg.name()] = std::make_pair(start, start + num); start += num; } - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -766,7 +766,7 @@ Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def, if (outputs != nullptr) { return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs); } - return OkStatus(); + return absl::OkStatus(); } void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) { @@ -866,10 +866,10 @@ const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix); Status ValidateOpInput(const string& input_name, bool* is_control_input) { *is_control_input = false; if (IsValidDataInputName(input_name)) { - return OkStatus(); + return absl::OkStatus(); } else if (IsValidControlInputName(input_name)) { *is_control_input = true; - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Illegal op input name '", input_name, "'"); } @@ -877,7 +877,7 @@ Status ValidateOpInput(const string& input_name, bool* is_control_input) { Status ValidateNodeName(const string& node_name) { if (IsValidNodeName(node_name)) { - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Illegal op name '", node_name, "'"); } @@ -903,7 +903,7 @@ Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) { } in_control_inputs = is_control_input; } - return OkStatus(); + return absl::OkStatus(); } Status AttachDef(const Status& status, const NodeDef& node_def, @@ -947,20 +947,20 @@ ADD_NODE_ATTR(const PartialTensorShape&) ADD_NODE_ATTR(const Tensor&) ADD_NODE_ATTR(const TensorProto&) ADD_NODE_ATTR(const NameAttrList&) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) ADD_NODE_ATTR(const std::vector&) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) -ADD_NODE_ATTR(gtl::ArraySlice) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) +ADD_NODE_ATTR(absl::Span) #undef ADD_NODE_ATTR void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) { @@ -990,7 +990,7 @@ Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix, attr.set_s(frame_name); } - return OkStatus(); + return absl::OkStatus(); } Status MaybeAddPrefixToColocationConstraints( @@ -998,7 +998,7 @@ Status MaybeAddPrefixToColocationConstraints( NodeDef* node_def) { auto attr = node_def->mutable_attr()->find(kColocationAttrName); if (attr == node_def->mutable_attr()->end()) { - return OkStatus(); + return absl::OkStatus(); } auto constraints_list = attr->second.mutable_list(); auto constraints_size = constraints_list->s_size(); @@ -1011,7 +1011,7 @@ Status MaybeAddPrefixToColocationConstraints( } } } - return OkStatus(); + return absl::OkStatus(); } Status MaybeUpdateColocationConstraintsWithMap( @@ -1019,7 +1019,7 @@ Status MaybeUpdateColocationConstraintsWithMap( NodeDef* node_def) { auto attr = node_def->mutable_attr()->find(kColocationAttrName); if (attr == node_def->mutable_attr()->end()) { - return OkStatus(); + return absl::OkStatus(); } auto constraints_list = attr->second.mutable_list(); auto constraints_size = constraints_list->s_size(); @@ -1032,7 +1032,7 @@ Status MaybeUpdateColocationConstraintsWithMap( } } } - return OkStatus(); + return absl::OkStatus(); } void ChangeToNoOp(NodeDef* node_def) { diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index fbba2c86892112..67bde1fc71e228 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -50,7 +50,7 @@ NodeDef ToNodeDef(NodeDefBuilder&& builder) { } void ExpectSuccess(const NodeDef& good, const OpDef& op_def) { - EXPECT_EQ(OkStatus(), ValidateNodeDef(good, op_def)) + EXPECT_EQ(absl::OkStatus(), ValidateNodeDef(good, op_def)) << "NodeDef: " << SummarizeNodeDef(good) << "; OpDef: " << SummarizeOpDef(op_def); } @@ -318,7 +318,7 @@ TEST(NodeDefUtilTest, Device) { } void ExpectValidSyntax(const NodeDef& good) { - EXPECT_EQ(OkStatus(), ValidateExternalNodeDefSyntax(good)) + EXPECT_EQ(absl::OkStatus(), ValidateExternalNodeDefSyntax(good)) << "NodeDef: " << SummarizeNodeDef(good); } diff --git a/tensorflow/core/framework/node_properties.cc b/tensorflow/core/framework/node_properties.cc index 23eda55c6da49b..4af538b3b2c1c5 100644 --- a/tensorflow/core/framework/node_properties.cc +++ b/tensorflow/core/framework/node_properties.cc @@ -33,7 +33,7 @@ Status NodeProperties::CreateFromNodeDef( props->reset(new NodeProperties(op_def, std::move(node_def), std::move(input_types), std::move(output_types))); - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/node_properties_test.cc b/tensorflow/core/framework/node_properties_test.cc index 258f413fba8c6e..5621137c7aba71 100644 --- a/tensorflow/core/framework/node_properties_test.cc +++ b/tensorflow/core/framework/node_properties_test.cc @@ -44,7 +44,7 @@ class MockOpRegistry : public OpRegistryInterface { const OpRegistrationData** op_reg_data) const override { if (op_type_name == "Foo") { *op_reg_data = &op_reg_; - return OkStatus(); + return absl::OkStatus(); } else { *op_reg_data = nullptr; return errors::InvalidArgument("Op type named ", op_type_name, diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index ccd5edcb3d37b5..3c3970506389f9 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -34,7 +34,7 @@ namespace tensorflow { Status DefaultValidator(const OpRegistryInterface& op_registry) { LOG(WARNING) << "No kernel validator registered with OpRegistry."; - return OkStatus(); + return absl::OkStatus(); } // OpRegistry ----------------------------------------------------------------- @@ -45,7 +45,7 @@ Status OpRegistryInterface::LookUpOpDef(const string& op_type_name, const OpRegistrationData* op_reg_data = nullptr; TF_RETURN_IF_ERROR(LookUp(op_type_name, &op_reg_data)); *op_def = &op_reg_data->op_def; - return OkStatus(); + return absl::OkStatus(); } OpRegistry::OpRegistry() @@ -78,7 +78,7 @@ Status OpNotFound(const string& op_type_name) { Status OpRegistry::LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const { - if ((*op_reg_data = LookUp(op_type_name))) return OkStatus(); + if ((*op_reg_data = LookUp(op_type_name))) return absl::OkStatus(); return OpNotFound(op_type_name); } @@ -155,7 +155,7 @@ Status OpRegistry::SetWatcher(const Watcher& watcher) { "Cannot over-write a valid watcher with another."); } watcher_ = watcher; - return OkStatus(); + return absl::OkStatus(); } void OpRegistry::Export(bool include_internal, OpList* ops) const { @@ -217,7 +217,7 @@ bool OpRegistry::MustCallDeferred() const { } Status OpRegistry::CallDeferred() const { - if (initialized_) return OkStatus(); + if (initialized_) return absl::OkStatus(); initialized_ = true; registry_.reserve(registry_.size() + deferred_.size()); for (const auto& op_data_factory : deferred_) { @@ -227,7 +227,7 @@ Status OpRegistry::CallDeferred() const { } } deferred_.clear(); - return OkStatus(); + return absl::OkStatus(); } Status OpRegistry::RegisterAlreadyLocked( @@ -278,7 +278,7 @@ const OpRegistrationData* OpListOpRegistry::LookUp( Status OpListOpRegistry::LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const { - if ((*op_reg_data = LookUp(op_type_name))) return OkStatus(); + if ((*op_reg_data = LookUp(op_type_name))) return absl::OkStatus(); return OpNotFound(op_type_name); } diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 71bc11acb1f8ea..83aa4d8e1974dd 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -492,7 +492,7 @@ void FinalizeDoc(const string& text, OpDef* op_def, // Trim trailing blank lines from the description. while (start_l < end_l && lines[end_l - 1].empty()) --end_l; string desc = absl::StrJoin( - gtl::ArraySlice(lines.data() + start_l, end_l - start_l), "\n"); + absl::Span(lines.data() + start_l, end_l - start_l), "\n"); if (!desc.empty()) op_def->set_description(desc); // name: description @@ -687,7 +687,7 @@ Status OpDefBuilder::Finalize(OpRegistrationData* op_reg_data) const { TF_RETURN_IF_ERROR(op_reg_data->type_ctor(op_def)); } - if (errors.empty()) return OkStatus(); + if (errors.empty()) return absl::OkStatus(); return errors::InvalidArgument(absl::StrJoin(errors, "\n")); } diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index fd6e284e5c1917..1da0aa726d64ca 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -45,7 +45,7 @@ Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { const AttrValue& allowed_values(attr.allowed_values()); for (auto allowed : allowed_values.list().type()) { if (dt == allowed) { - return OkStatus(); + return absl::OkStatus(); } } string allowed_str; @@ -65,7 +65,7 @@ Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { const AttrValue& allowed_values(attr.allowed_values()); for (const auto& allowed : allowed_values.list().s()) { if (str == allowed) { - return OkStatus(); + return absl::OkStatus(); } } string allowed_str; @@ -143,7 +143,7 @@ Status ValidateAttrValue(const AttrValue& attr_value, "Support for allowed_values not implemented for type ", attr.type()); } } - return OkStatus(); + return absl::OkStatus(); } const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { @@ -244,7 +244,7 @@ static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix); } - return OkStatus(); + return absl::OkStatus(); } bool IsValidOpName(StringPiece sp) { @@ -343,7 +343,7 @@ Status ValidateOpDef(const OpDef& op_def) { TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names)); } - return OkStatus(); + return absl::OkStatus(); } #undef VALIDATE @@ -372,7 +372,7 @@ Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) { } } } - return OkStatus(); + return absl::OkStatus(); } namespace { @@ -684,7 +684,7 @@ Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { " changed from ref to non-ref"); } - return OkStatus(); + return absl::OkStatus(); } Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, @@ -723,7 +723,7 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op, } } - return OkStatus(); + return absl::OkStatus(); } Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { @@ -752,7 +752,7 @@ Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) { } } - return OkStatus(); + return absl::OkStatus(); } void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) { diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc index 11a17486372f21..9151e1b0448fb2 100644 --- a/tensorflow/core/framework/op_gen_lib.cc +++ b/tensorflow/core/framework/op_gen_lib.cc @@ -55,7 +55,7 @@ string WordWrap(StringPiece prefix, StringPiece str, int width) { StringPiece to_append = str.substr(0, space); str.remove_prefix(space + 1); // Remove spaces at break. - while (str_util::EndsWith(to_append, " ")) { + while (absl::EndsWith(to_append, " ")) { to_append.remove_suffix(1); } while (absl::ConsumePrefix(&str, " ")) { @@ -466,7 +466,7 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) { strings::StrCat(description, "\n", new_api_def.description_suffix()); } base_api_def->set_description(description); - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -484,11 +484,11 @@ Status ApiDefMap::LoadFileList(Env* env, const std::vector& filenames) { for (const auto& filename : filenames) { TF_RETURN_IF_ERROR(LoadFile(env, filename)); } - return OkStatus(); + return absl::OkStatus(); } Status ApiDefMap::LoadFile(Env* env, const string& filename) { - if (filename.empty()) return OkStatus(); + if (filename.empty()) return absl::OkStatus(); string contents; TF_RETURN_IF_ERROR(ReadFileToString(env, filename, &contents)); Status status = LoadApiDef(contents); @@ -498,7 +498,7 @@ Status ApiDefMap::LoadFile(Env* env, const string& filename) { status, strings::StrCat("Error parsing ApiDef file ", filename, ": ", status.message())); } - return OkStatus(); + return absl::OkStatus(); } Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { @@ -514,7 +514,7 @@ Status ApiDefMap::LoadApiDef(const string& api_def_file_contents) { TF_RETURN_IF_ERROR(MergeApiDefs(&map_[api_def.graph_op_name()], api_def)); } } - return OkStatus(); + return absl::OkStatus(); } void ApiDefMap::UpdateDocs() { diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index f8b8f81b15a67a..cd9c83bebc626f 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -93,7 +93,7 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, " expected: ", DataTypeSliceString(expected_inputs), "->", DataTypeSliceString(expected_outputs)); } - return OkStatus(); + return absl::OkStatus(); } const absl::flat_hash_set* GetOpNodeDefsToLogFromEnv() { @@ -196,7 +196,7 @@ Status OpKernel::InputRange(StringPiece input_name, int* start, } else { *start = result->second.first; *stop = result->second.second; - return OkStatus(); + return absl::OkStatus(); } } @@ -208,7 +208,7 @@ Status OpKernel::OutputRange(StringPiece output_name, int* start, } else { *start = result->second.first; *stop = result->second.second; - return OkStatus(); + return absl::OkStatus(); } } @@ -235,12 +235,13 @@ string OpKernel::ShapeTraceString(const OpKernelContext& ctx) const { } string OpKernel::TraceString(const OpKernelContext& ctx, bool verbose) const { - string trace_string = profiler::TraceMeOp(name_view(), type_string_view()); + string trace_string = + tsl::profiler::TraceMeOp(name_view(), type_string_view()); if (verbose) { string shape = ShapeTraceString(ctx); if (!shape.empty()) { - trace_string = - profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}}); + trace_string = tsl::profiler::TraceMeEncode(std::move(trace_string), + {{"shape", shape}}); } } return trace_string; @@ -302,7 +303,7 @@ Status OpKernelConstruction::allocate_temp(DataType type, def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); } *out_temp = new_temp; - return OkStatus(); + return absl::OkStatus(); } Status OpKernelConstruction::allocate_temp(DataType type, @@ -327,7 +328,7 @@ Status OpKernelConstruction::allocate_temp(DataType type, def().name(), LogMemory::OP_KERNEL_CONSTRUCTION_STEP_ID, new_temp); } *out_temp = new_temp; - return OkStatus(); + return absl::OkStatus(); } // OpKernelContext ----------------------------------------------------------- @@ -411,7 +412,7 @@ Status OpKernelContext::input(StringPiece name, const Tensor** tensor) { "' when non-ref input was expected"); } *tensor = params_->inputs[index].tensor; - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { @@ -419,14 +420,14 @@ Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const { TF_RETURN_IF_ERROR(get_input_index(name, &index)); const TensorValue& value(params_->inputs[index]); *dtype = value.dtype(); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) { int index; TF_RETURN_IF_ERROR(get_input_index(name, &index)); *out_mutex = input_ref_mutex(index); - return OkStatus(); + return absl::OkStatus(); } absl::StatusOr OpKernelContext::get_input(int index) const { @@ -516,7 +517,7 @@ Status OpKernelContext::forward_input_to_output_with_shape( return errors::FailedPrecondition("OpKernel could not forward input '", input_name, "' to output '", output_name); } - return OkStatus(); + return absl::OkStatus(); } std::unique_ptr OpKernelContext::forward_input( @@ -588,7 +589,7 @@ std::unique_ptr OpKernelContext::forward_input( } Status OpKernelContext::forward_input_or_allocate_temp( - gtl::ArraySlice candidate_input_indices, DataType type, + absl::Span candidate_input_indices, DataType type, const TensorShape& shape, const AllocatorAttributes& allocator_attr, Tensor* out_temp) { for (int input_index : candidate_input_indices) { @@ -597,14 +598,14 @@ Status OpKernelContext::forward_input_or_allocate_temp( type, shape, DEVICE_MEMORY, allocator_attr); if (new_tensor != nullptr) { *out_temp = std::move(*new_tensor); - return OkStatus(); + return absl::OkStatus(); } } return allocate_temp(type, shape, out_temp, allocator_attr); } Status OpKernelContext::forward_input_or_allocate_output( - gtl::ArraySlice candidate_input_indices, int output_index, + absl::Span candidate_input_indices, int output_index, const TensorShape& output_shape, Tensor** output, int* forwarded_input) { for (int input_index : candidate_input_indices) { if (forward_input_to_output_with_shape(input_index, output_index, @@ -612,7 +613,7 @@ Status OpKernelContext::forward_input_or_allocate_output( if (forwarded_input != nullptr) { *forwarded_input = input_index; } - return OkStatus(); + return absl::OkStatus(); } } if (forwarded_input != nullptr) { @@ -622,13 +623,13 @@ Status OpKernelContext::forward_input_or_allocate_output( } Status OpKernelContext::forward_input_or_allocate_output( - gtl::ArraySlice candidate_input_names, StringPiece output_name, - const TensorShape& output_shape, Tensor** output) { + absl::Span candidate_input_names, + StringPiece output_name, const TensorShape& output_shape, Tensor** output) { for (const StringPiece& input_name : candidate_input_names) { if (forward_input_to_output_with_shape(input_name, output_name, output_shape, output) .ok()) { - return OkStatus(); + return absl::OkStatus(); } } return allocate_output(output_name, output_shape, output); @@ -662,7 +663,7 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor, tf_shared_lock l(*input_ref_mutex(index)); *tensor = *params_->inputs[index].tensor; } - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::replace_ref_input(StringPiece name, @@ -675,14 +676,14 @@ Status OpKernelContext::replace_ref_input(StringPiece name, "' when ref input was expected"); } replace_ref_input(index, tensor, lock_held); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::input_list(StringPiece name, OpInputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpInputList(this, start, stop); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::mutable_input_list(StringPiece name, @@ -690,14 +691,14 @@ Status OpKernelContext::mutable_input_list(StringPiece name, int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop)); *list = OpMutableInputList(this, start, stop); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::output_list(StringPiece name, OpOutputList* list) { int start, stop; TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop)); *list = OpOutputList(this, start, stop); - return OkStatus(); + return absl::OkStatus(); } void OpKernelContext::maybe_initialize_scope_id_set() { @@ -779,7 +780,7 @@ Status OpKernelContext::allocate_tensor( params_->step_id, new_tensor); } *out_tensor = std::move(new_tensor); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::allocate_output(int index, const TensorShape& shape, @@ -889,7 +890,7 @@ Status OpKernelContext::get_input_index(StringPiece name, "expected"); } *out_index = start; - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::get_output_index(StringPiece name, @@ -903,21 +904,21 @@ Status OpKernelContext::get_output_index(StringPiece name, "expected"); } *out_index = start; - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output(index, tensor); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output(index, std::move(tensor)); - return OkStatus(); + return absl::OkStatus(); } bool OpKernelContext::maybe_set_output_by_allocate_and_copy( @@ -1025,14 +1026,14 @@ Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu, int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); set_output_ref(index, mu, tensor_for_ref); - return OkStatus(); + return absl::OkStatus(); } Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) { int index; TF_RETURN_IF_ERROR(get_output_index(name, &index)); *tensor = mutable_output(index); - return OkStatus(); + return absl::OkStatus(); } bool OpKernelContext::ValidateInputsAreSameShape(OpKernel* op) { @@ -1200,7 +1201,7 @@ static Status IsProbablySafeToLoad(const string& path) { errmsg.append(absl::StrJoin(missing_features, ", ")); return errors::FailedPrecondition(errmsg); } - return OkStatus(); + return absl::OkStatus(); } void LoadDynamicKernelsInternal() { @@ -1453,7 +1454,7 @@ Status FindKernelRegistration( } } - return OkStatus(); + return absl::OkStatus(); } Status FindKernelRegistration(const DeviceType& device_type, @@ -1517,7 +1518,7 @@ Status FindKernelDef( } if (def != nullptr) *def = ®->def; if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name; - return OkStatus(); + return absl::OkStatus(); } Status FindKernelDef(const DeviceType& device_type, const NodeDef& node_def, @@ -1596,7 +1597,7 @@ Status SupportedDeviceTypesForNode( prioritized_device_types->push_back(std::make_pair(device_type, 0)); } } - return OkStatus(); + return absl::OkStatus(); } void LogAllRegisteredKernels() { @@ -1782,7 +1783,7 @@ Status ValidateKernelRegistrations(const OpRegistryInterface& op_registry) { } } } - return OkStatus(); + return absl::OkStatus(); } template <> diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc index a4373446481d93..bea1208053c5e2 100644 --- a/tensorflow/core/framework/op_kernel_test.cc +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -422,8 +422,8 @@ TEST_F(OpKernelTest, InputDtype) { Tensor a(DT_FLOAT, TensorShape({})); Tensor b(DT_INT32, TensorShape({})); Tensor c(DT_UINT8, TensorShape({})); - gtl::InlinedVector inputs{TensorValue(&a), TensorValue(&b), - TensorValue(&c)}; + absl::InlinedVector inputs{TensorValue(&a), TensorValue(&b), + TensorValue(&c)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); @@ -448,7 +448,7 @@ TEST_F(OpKernelTest, InputOnly) { EXPECT_TRUE(status.ok()); params.op_kernel = op.get(); Tensor a(DT_FLOAT, TensorShape({})); - gtl::InlinedVector inputs{TensorValue(&a)}; + absl::InlinedVector inputs{TensorValue(&a)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); @@ -475,8 +475,8 @@ TEST_F(OpKernelTest, RefInputs) { Tensor* a = new Tensor(DT_FLOAT, TensorShape({})); Tensor* b = new Tensor(DT_FLOAT, TensorShape({2})); mutex mu_a, mu_b; - gtl::InlinedVector inputs{TensorValue(&mu_a, a), - TensorValue(&mu_b, b)}; + absl::InlinedVector inputs{TensorValue(&mu_a, a), + TensorValue(&mu_b, b)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); @@ -502,7 +502,7 @@ TEST_F(OpKernelTest, AllocateOutput) { params.op_kernel = op.get(); Tensor a(DT_FLOAT, TensorShape({})); Tensor b(DT_INT32, TensorShape({})); - gtl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; + absl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); Tensor* output = nullptr; @@ -566,7 +566,7 @@ class ScopedAllocatorDevice : public DeviceBase { StatusCallback done) override { CHECK(input_tensor->NumElements() == output_tensor->NumElements()); tensor::DeepCopy(*input_tensor, output_tensor); - done(OkStatus()); + done(absl::OkStatus()); } // Return the count of calls to GetAllocator or GetScopedAllocator, depending @@ -641,7 +641,7 @@ TEST_F(OpKernelTest, TraceString) { params.op_kernel = op.get(); Tensor a(DT_FLOAT, TensorShape({4, 8})); - gtl::InlinedVector inputs{TensorValue(&a)}; + absl::InlinedVector inputs{TensorValue(&a)}; params.inputs = inputs; params.op_kernel = op.get(); @@ -1162,7 +1162,7 @@ void BM_TraceString(::testing::benchmark::State& state) { params.op_kernel = op.get(); Tensor a(DT_FLOAT, TensorShape({99000, 256})); Tensor b(DT_FLOAT, TensorShape({256, 256})); - gtl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; + absl::InlinedVector inputs{TensorValue(&a), TensorValue(&b)}; params.inputs = inputs; auto ctx = std::make_unique(¶ms); diff --git a/tensorflow/core/framework/op_registration_test.cc b/tensorflow/core/framework/op_registration_test.cc index af80036272a367..286a0db358702c 100644 --- a/tensorflow/core/framework/op_registration_test.cc +++ b/tensorflow/core/framework/op_registration_test.cc @@ -27,7 +27,7 @@ namespace { void Register(const string& op_name, OpRegistry* registry) { registry->Register([op_name](OpRegistrationData* op_reg_data) -> Status { op_reg_data->op_def.set_name(op_name); - return OkStatus(); + return absl::OkStatus(); }); } @@ -51,7 +51,7 @@ TEST(OpRegistrationTest, TestDuplicate) { TF_EXPECT_OK( registry->SetWatcher([](const Status& s, const OpDef& op_def) -> Status { EXPECT_TRUE(errors::IsAlreadyExists(s)); - return OkStatus(); + return absl::OkStatus(); })); Register("Foo", registry.get()); s = registry->ProcessRegistrations(); diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc index 42651c8c6dde6c..6af4d8973b3e1c 100644 --- a/tensorflow/core/framework/op_segment.cc +++ b/tensorflow/core/framework/op_segment.cc @@ -46,7 +46,7 @@ Status OpSegment::FindOrCreate(const string& session_handle, } *kernel = gtl::FindPtrOrNull(item->name_kernel, node_name); if (*kernel != nullptr) { - return OkStatus(); + return absl::OkStatus(); } } Status s = create_fn(kernel); @@ -68,7 +68,7 @@ Status OpSegment::FindOrCreate(const string& session_handle, *kernel = *p_kernel; } } - return OkStatus(); + return absl::OkStatus(); } void OpSegment::AddHold(const string& session_handle) { diff --git a/tensorflow/core/framework/ops_util.cc b/tensorflow/core/framework/ops_util.cc index b53fb3e6c2b70c..abe57812774933 100644 --- a/tensorflow/core/framework/ops_util.cc +++ b/tensorflow/core/framework/ops_util.cc @@ -59,7 +59,7 @@ Status GetBroadcastSize(const int index, const int in_size, const int ksize, if (*bindex + ksize > in_size) { *bsize = std::min((in_size - *bindex), ksize); } - return OkStatus(); + return absl::OkStatus(); } string SanitizeThreadSuffix(string suffix) { diff --git a/tensorflow/core/framework/partial_tensor_shape_test.cc b/tensorflow/core/framework/partial_tensor_shape_test.cc index e20a585ff3b9e7..77f81cc5a8a549 100644 --- a/tensorflow/core/framework/partial_tensor_shape_test.cc +++ b/tensorflow/core/framework/partial_tensor_shape_test.cc @@ -17,13 +17,13 @@ limitations under the License. #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { @@ -295,14 +295,14 @@ TEST(PartialTensorShapeTest, PartialShapeMergeWith) { const PartialTensorShape e; PartialTensorShape test; - EXPECT_EQ(OkStatus(), a.MergeWith(a, &test)); + EXPECT_EQ(absl::OkStatus(), a.MergeWith(a, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); EXPECT_EQ(test.dim_size(2), 1); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), a.MergeWith(b, &test)); + EXPECT_EQ(absl::OkStatus(), a.MergeWith(b, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), 1); EXPECT_EQ(test.dim_size(1), 0); @@ -312,28 +312,28 @@ TEST(PartialTensorShapeTest, PartialShapeMergeWith) { EXPECT_TRUE(errors::IsInvalidArgument(a.MergeWith(d, &test))); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), a.MergeWith(c, &test)); + EXPECT_EQ(absl::OkStatus(), a.MergeWith(c, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); EXPECT_EQ(test.dim_size(2), 1); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), c.MergeWith(a, &test)); + EXPECT_EQ(absl::OkStatus(), c.MergeWith(a, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); EXPECT_EQ(test.dim_size(2), 1); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), a.MergeWith(e, &test)); + EXPECT_EQ(absl::OkStatus(), a.MergeWith(e, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); EXPECT_EQ(test.dim_size(2), 1); test = PartialTensorShape(); - EXPECT_EQ(OkStatus(), e.MergeWith(a, &test)); + EXPECT_EQ(absl::OkStatus(), e.MergeWith(a, &test)); EXPECT_EQ(test.dims(), 3); EXPECT_EQ(test.dim_size(0), -1); EXPECT_EQ(test.dim_size(1), 0); diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc index 2bc23d0b8a6d30..2e433fb1359d5a 100644 --- a/tensorflow/core/framework/reader_base.cc +++ b/tensorflow/core/framework/reader_base.cc @@ -50,7 +50,7 @@ Status ReaderBase::ResetLocked() { work_finished_ = 0; num_records_produced_ = 0; work_.clear(); - return OkStatus(); + return absl::OkStatus(); } Status ReaderBase::SerializeState(tstring* state) { @@ -261,7 +261,7 @@ Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) { "Inconsistent work started vs. finished when restoring in ", name(), ": ", debug_string); } - return OkStatus(); + return absl::OkStatus(); } } // namespace tensorflow diff --git a/tensorflow/core/framework/reader_base.h b/tensorflow/core/framework/reader_base.h index 8f4e347e09aa99..644a5618f7564e 100644 --- a/tensorflow/core/framework/reader_base.h +++ b/tensorflow/core/framework/reader_base.h @@ -64,8 +64,8 @@ class ReaderBase : public ReaderInterface { bool* at_end); // Called when work starts / finishes. - virtual Status OnWorkStartedLocked() { return OkStatus(); } - virtual Status OnWorkFinishedLocked() { return OkStatus(); } + virtual Status OnWorkStartedLocked() { return absl::OkStatus(); } + virtual Status OnWorkFinishedLocked() { return absl::OkStatus(); } // Called to reset the Reader to a newly constructed state. virtual Status ResetLocked(); diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h index 36f59717e0e9db..1433a54e5e7d12 100644 --- a/tensorflow/core/framework/reader_op_kernel.h +++ b/tensorflow/core/framework/reader_op_kernel.h @@ -76,7 +76,7 @@ class ReaderOpKernel : public ResourceOpKernel { } std::function temp = nullptr; factory_.swap(temp); - return OkStatus(); + return absl::OkStatus(); } std::function factory_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/core/framework/rendezvous.cc b/tensorflow/core/framework/rendezvous.cc index efea3e2597c803..1792a1c1fed17d 100644 --- a/tensorflow/core/framework/rendezvous.cc +++ b/tensorflow/core/framework/rendezvous.cc @@ -109,7 +109,7 @@ Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) { out->src_device = StringPiece(parts[0].data(), parts[0].size()); out->dst_device = StringPiece(parts[2].data(), parts[2].size()); out->edge_name = StringPiece(parts[3].data(), parts[3].size()); - return OkStatus(); + return absl::OkStatus(); } return errors::InvalidArgument("Invalid rendezvous key: ", key); } diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc index 1212fadfc1bdc8..1c52e259ba55b1 100644 --- a/tensorflow/core/framework/rendezvous_test.cc +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -403,7 +403,7 @@ class DummyDeviceContext : public DeviceContext { void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, Tensor* output_tensor, StatusCallback done) const override { - done(OkStatus()); + done(absl::OkStatus()); } private: diff --git a/tensorflow/core/framework/resource_handle.cc b/tensorflow/core/framework/resource_handle.cc index bc6e459a6566e9..0fe49206846a5f 100644 --- a/tensorflow/core/framework/resource_handle.cc +++ b/tensorflow/core/framework/resource_handle.cc @@ -96,7 +96,7 @@ Status ResourceHandle::FromProto(const ResourceHandleProto& proto) { dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{dtype, shape}); } dtypes_and_shapes_ = std::move(dtypes_and_shapes); - return OkStatus(); + return absl::OkStatus(); } string ResourceHandle::SerializeAsString() const { @@ -157,7 +157,7 @@ Status ResourceHandle::ValidateType(const TypeIndex& type_index) const { port::Demangle(type_index.name()), "' (hash code ", type_index.hash_code(), ")"); } - return OkStatus(); + return absl::OkStatus(); } std::atomic ResourceHandle::current_id_; diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 872665170ae08a..a738f8d735addd 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -61,7 +61,7 @@ Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index, context->allocate_output(output_index, TensorShape({}), &handle)); handle->scalar()() = MakeResourceHandle(container, name, *context->device(), type_index); - return OkStatus(); + return absl::OkStatus(); } namespace internal { @@ -72,7 +72,7 @@ Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p) { "Trying to access resource ", p.name(), " located in device ", p.device(), " from device ", ctx->device()->attributes().name()); } - return OkStatus(); + return absl::OkStatus(); } } // end namespace internal @@ -84,7 +84,7 @@ Status ResourceMgr::InsertDebugTypeName(uint64 hash_code, return errors::AlreadyExists("Duplicate hash code found for type ", type_name); } - return OkStatus(); + return absl::OkStatus(); } const char* ResourceMgr::DebugTypeName(uint64 hash_code) const { @@ -219,7 +219,7 @@ Status ResourceMgr::DoCreate(const string& container_name, TypeIndex type, auto st = container->insert(std::move(key_and_value)); if (st.second) { TF_RETURN_IF_ERROR(InsertDebugTypeName(type.hash_code(), type.name())); - return OkStatus(); + return absl::OkStatus(); } return errors::AlreadyExists("Resource ", container_name, "/", name, "/", type.name()); @@ -259,7 +259,7 @@ Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code, type_name, " has been destroyed."); } *resource = ptr; - return OkStatus(); + return absl::OkStatus(); } Status ResourceMgr::PopResourceAndName(const string& container, @@ -279,7 +279,7 @@ Status ResourceMgr::PopResourceAndName(const string& container, } std::swap(resource_and_name, iter->second); b->erase(iter); - return OkStatus(); + return absl::OkStatus(); } Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code, @@ -297,7 +297,7 @@ Status ResourceMgr::DoDelete(const string& container, uint64 type_hash_code, "This indicates ref-counting ResourceHandle is exposed to weak " "ResourceHandle code paths."); } - return OkStatus(); + return absl::OkStatus(); } Status ResourceMgr::DoDelete(const string& container, TypeIndex type, @@ -315,7 +315,7 @@ Status ResourceMgr::Cleanup(const string& container) { tf_shared_lock l(mu_); if (!gtl::FindOrNull(containers_, container)) { // Nothing to cleanup. - return OkStatus(); + return absl::OkStatus(); } } Container* b = nullptr; @@ -324,14 +324,14 @@ Status ResourceMgr::Cleanup(const string& container) { auto iter = containers_.find(container); if (iter == containers_.end()) { // Nothing to cleanup, it's OK (concurrent cleanup). - return OkStatus(); + return absl::OkStatus(); } b = iter->second; containers_.erase(iter); } CHECK(b != nullptr); delete b; - return OkStatus(); + return absl::OkStatus(); } static bool IsValidContainerName(StringPiece s) { @@ -373,7 +373,7 @@ Status ContainerInfo::Init(ResourceMgr* rmgr, const NodeDef& ndef, static std::atomic counter(0); name_ = strings::StrCat("_", counter.fetch_add(1), "_", ndef.name()); } - return OkStatus(); + return absl::OkStatus(); } string ContainerInfo::DebugString() const { @@ -394,7 +394,7 @@ Status HandleFromInput(OpKernelContext* ctx, int input, return absl::InvalidArgumentError("Empty resource handle"); } *handle = tensor->flat()(0); - return OkStatus(); + return absl::OkStatus(); } Status HandleFromInput(OpKernelContext* ctx, StringPiece input, @@ -405,7 +405,7 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input, return absl::InvalidArgumentError("Empty resource handle"); } *handle = tensor->flat()(0); - return OkStatus(); + return absl::OkStatus(); } Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, @@ -414,7 +414,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, if (p.IsRefCounting()) { TF_ASSIGN_OR_RETURN(*value, p.GetResource()); (*value)->Ref(); - return OkStatus(); + return absl::OkStatus(); } return ctx->resource_manager()->Lookup(p, value); } @@ -422,7 +422,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); if (p.IsRefCounting()) { - return OkStatus(); + return absl::OkStatus(); } return ctx->resource_manager()->Delete(p); } diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index b13de22dd49e99..658ed31ebfea9f 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -680,7 +680,7 @@ Status ResourceMgr::LookupMany( (*resources)[i].reset(resource); } } - return OkStatus(); + return absl::OkStatus(); } // Simple wrapper to allow conditional dynamic / static casts. @@ -777,7 +777,7 @@ template Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) { TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p)); TF_RETURN_IF_ERROR(p.ValidateType()); - return OkStatus(); + return absl::OkStatus(); } } // namespace internal @@ -804,7 +804,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, TF_ASSIGN_OR_RETURN(*value, p.GetResource()); // Transfers out a new reference. (*value)->Ref(); - return OkStatus(); + return absl::OkStatus(); } return ctx->resource_manager()->Lookup(p.container(), @@ -825,7 +825,7 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, TF_RETURN_IF_ERROR(LookupResource(ctx, p, &raw_ptr)); value->reset(raw_ptr); - return OkStatus(); + return absl::OkStatus(); } // Similar to Lookup, but looks up multiple resources at once, with only a @@ -872,7 +872,7 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, TF_RETURN_IF_ERROR(LookupOrCreateResource(ctx, p, &raw_ptr, creator)); value->reset(raw_ptr); - return OkStatus(); + return absl::OkStatus(); } // Deletes the resource pointed by "p", using the resource manager in "ctx". @@ -883,7 +883,7 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { // NOTE(feyu): if we can convert all resources handle to ref-counting, then // DeleteResource can be removed. if (p.IsRefCounting()) { - return OkStatus(); + return absl::OkStatus(); } return ctx->resource_manager()->Delete(p.container(), p.name()); } diff --git a/tensorflow/core/framework/resource_mgr_test.cc b/tensorflow/core/framework/resource_mgr_test.cc index 5c079cb2ac7318..6b12270ab97528 100644 --- a/tensorflow/core/framework/resource_mgr_test.cc +++ b/tensorflow/core/framework/resource_mgr_test.cc @@ -73,7 +73,7 @@ string LookupOrCreate(ResourceMgr* rm, const string& container, T* r; TF_CHECK_OK(rm->LookupOrCreate(container, name, &r, [&label](T** ret) { *ret = new T(label); - return OkStatus(); + return absl::OkStatus(); })); const string ret = r->DebugString(); r->Unref(); @@ -240,7 +240,7 @@ TEST(ResourceMgrTest, CreateOrLookupRaceCondition) { Env::Default()->SleepForMicroseconds(1 * 1000 * 1000); atomic_int += 1; *ret = new Resource("label"); - return OkStatus(); + return absl::OkStatus(); })); r->Unref(); }); @@ -265,7 +265,7 @@ Status ComputePolicy(const string& attr_container, } TF_RETURN_IF_ERROR(cinfo.Init(&rmgr, ndef, use_node_name_as_default)); *result = cinfo.DebugString(); - return OkStatus(); + return absl::OkStatus(); } string Policy(const string& attr_container, const string& attr_shared_name, diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc index 932fb4aee942bb..8c0b32d352fe2d 100644 --- a/tensorflow/core/framework/run_handler_util.cc +++ b/tensorflow/core/framework/run_handler_util.cc @@ -73,7 +73,7 @@ std::vector ParamFromEnvWithDefault(const char* var_name, bool ParamFromEnvBoolWithDefault(const char* var_name, bool default_value) { const char* val = std::getenv(var_name); - return (val) ? str_util::Lowercase(val) == "true" : default_value; + return (val) ? absl::AsciiStrToLower(val) == "true" : default_value; } void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads, diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index d74366937210c9..71d856eaeebb6b 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -1288,6 +1288,10 @@ bool InferenceContext::RelaxHandleShapesAndMergeTypes( bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes( int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, output_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << output_handle_shapes_and_types_.size() << " inputs."; if (output_handle_shapes_and_types_[idx] == nullptr) { output_handle_shapes_and_types_[idx].reset( new std::vector(shapes_and_types)); @@ -1299,6 +1303,10 @@ bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes( bool InferenceContext::RelaxInputHandleShapesAndMergeTypes( int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, input_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << input_handle_shapes_and_types_.size() << " inputs."; if (input_handle_shapes_and_types_[idx] == nullptr) { input_handle_shapes_and_types_[idx].reset( new std::vector(shapes_and_types)); diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index f00dac88fd0388..6ed932e0c78189 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -683,6 +683,10 @@ class InferenceContext { void set_input_handle_shapes_and_types( int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, input_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << input_handle_shapes_and_types_.size() << " inputs."; input_handle_shapes_and_types_[idx] = absl::make_unique>(shapes_and_types); } @@ -690,17 +694,29 @@ class InferenceContext { // Returns the output handle shapes and types, for the resource tensor output // at index . Returns NULL if the shape and types were never set. const std::vector* output_handle_shapes_and_types(int idx) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, output_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << output_handle_shapes_and_types_.size() << " outputs."; return output_handle_shapes_and_types_[idx].get(); } // Returns the inputs handle shapes and types, for the resource tensor input // at index . Returns NULL if the shape and types were not available. const std::vector* input_handle_shapes_and_types(int idx) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, input_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << input_handle_shapes_and_types_.size() << " inputs."; return input_handle_shapes_and_types_[idx].get(); } void set_output_handle_shapes_and_types( int idx, const std::vector& shapes_and_types) { + CHECK_GE(idx, 0) << "idx must be non-negative. Got idx: " << idx << "."; + CHECK_LT(idx, output_handle_shapes_and_types_.size()) + << "Got idx: " << idx << " but only " + << output_handle_shapes_and_types_.size() << " inputs."; output_handle_shapes_and_types_[idx] = absl::make_unique>(shapes_and_types); } diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index 63c5156dd22664..34574b6e54ede1 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -138,7 +138,7 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, } // Verify the dimensions. - CHECK(absl::StartsWith(expected, "[") && str_util::EndsWith(expected, "]")) + CHECK(absl::StartsWith(expected, "[") && absl::EndsWith(expected, "]")) << expected; expected.remove_prefix(1); expected.remove_suffix(1); diff --git a/tensorflow/core/framework/tensor_fuzz.cc b/tensorflow/core/framework/tensor_fuzz.cc index 5665185f121923..49f91b021bf9fc 100644 --- a/tensorflow/core/framework/tensor_fuzz.cc +++ b/tensorflow/core/framework/tensor_fuzz.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "fuzztest/fuzztest.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/security/fuzzing/cc/core/framework/datatype_domains.h" #include "tensorflow/security/fuzzing/cc/core/framework/tensor_domains.h" #include "tensorflow/security/fuzzing/cc/core/framework/tensor_shape_domains.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow::fuzzing { namespace { diff --git a/tensorflow/core/framework/tensor_shape_fuzz.cc b/tensorflow/core/framework/tensor_shape_fuzz.cc index 7a0351ad0e3897..d14284e5530c96 100644 --- a/tensorflow/core/framework/tensor_shape_fuzz.cc +++ b/tensorflow/core/framework/tensor_shape_fuzz.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "fuzztest/fuzztest.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/security/fuzzing/cc/core/framework/tensor_shape_domains.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace fuzzing { diff --git a/tensorflow/core/framework/tensor_shape_test.cc b/tensorflow/core/framework/tensor_shape_test.cc index e55cefacdfbcc8..c13a16fd3c8004 100644 --- a/tensorflow/core/framework/tensor_shape_test.cc +++ b/tensorflow/core/framework/tensor_shape_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/random/simple_philox.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { class TensorShapeTestHelper { diff --git a/tensorflow/core/framework/typed_allocator.h b/tensorflow/core/framework/typed_allocator.h index 20e16358f2c4c3..6d89983b2fb575 100644 --- a/tensorflow/core/framework/typed_allocator.h +++ b/tensorflow/core/framework/typed_allocator.h @@ -56,7 +56,8 @@ class TypedAllocator { size_t num_elements) { if (ptr) { RunDtor(raw_allocator, ptr, num_elements); - raw_allocator->DeallocateRaw(ptr); + raw_allocator->DeallocateRaw(ptr, Allocator::kAllocatorAlignment, + sizeof(T) * num_elements); } } diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index 5795ee7c082e56..d1e42814d75f92 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -156,7 +156,7 @@ string DataTypeString(DataType dtype) { } bool DataTypeFromString(StringPiece sp, DataType* dt) { - if (str_util::EndsWith(sp, "_ref")) { + if (absl::EndsWith(sp, "_ref")) { sp.remove_suffix(4); DataType non_ref; if (DataTypeFromString(sp, &non_ref) && !IsRefType(non_ref)) { diff --git a/tensorflow/core/graph/graph_debug_info_builder.cc b/tensorflow/core/graph/graph_debug_info_builder.cc index 36d626b838534a..015494c181ed70 100644 --- a/tensorflow/core/graph/graph_debug_info_builder.cc +++ b/tensorflow/core/graph/graph_debug_info_builder.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/hash/hash.h" @@ -98,6 +99,10 @@ absl::Span FrozenStackTrace::ToFrames() const { return frames_; } +std::vector FrozenStackTrace::ToUncachedFrames() const { + return frames_; +} + StackFrame FrozenStackTrace::LastUserFrame() const { return frames_.back(); } std::vector FrozenStackTrace::GetUserFrames(int limit) const { diff --git a/tensorflow/core/graph/graph_debug_info_builder.h b/tensorflow/core/graph/graph_debug_info_builder.h index 086aa76521ddd4..b1c8fcef703c3a 100644 --- a/tensorflow/core/graph/graph_debug_info_builder.h +++ b/tensorflow/core/graph/graph_debug_info_builder.h @@ -51,6 +51,9 @@ class AbstractStackTrace { // The returned span is alive as long as the AbstractStackTrace is alive. virtual absl::Span ToFrames() const = 0; + // Returns the stack frames without caching any generated data. + virtual std::vector ToUncachedFrames() const = 0; + // Returns the last stack frame from user code, attempting to ignore the // framework code. Returns an empty frame if no such stack frame was found. virtual StackFrame LastUserFrame() const = 0; @@ -84,6 +87,8 @@ class FrozenStackTrace : public AbstractStackTrace { absl::Span ToFrames() const override; + std::vector ToUncachedFrames() const override; + StackFrame LastUserFrame() const override; std::vector GetUserFrames(int limit) const override; diff --git a/tensorflow/core/graph/graph_debug_info_builder_test.cc b/tensorflow/core/graph/graph_debug_info_builder_test.cc index f7e5a6c01f68ad..cbe4a8a8ae9287 100644 --- a/tensorflow/core/graph/graph_debug_info_builder_test.cc +++ b/tensorflow/core/graph/graph_debug_info_builder_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" #include "tensorflow/core/framework/graph_debug_info.pb.h" +#include "tensorflow/core/platform/stack_frame.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -38,6 +39,8 @@ class TestStackTrace : public AbstractStackTrace { absl::Span ToFrames() const override { return frames_; } + std::vector ToUncachedFrames() const override { return frames_; } + std::vector GetUserFrames(int limit) const override { return frames_; } @@ -219,5 +222,18 @@ TEST(StackTracesMapToGraphDebugInfoTest, RoundTripStackTraces) { } } +TEST(StackTracesTest, ToFrames) { + StackTracesMap map; + std::vector frames = { + StackFrame({"dummy_file_name", 10, "dummy_function_name"}), + StackFrame({"other_file_name", 20, "other_function_name"})}; + auto stack_trace = TestStackTrace(frames); + EXPECT_EQ(stack_trace.ToFrames().size(), 2); + auto uncached_frames = stack_trace.ToUncachedFrames(); + EXPECT_EQ(uncached_frames.size(), 2); + EXPECT_EQ(frames[0], uncached_frames[0]); + EXPECT_EQ(frames[1], uncached_frames[1]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc index 638a6a33f9395f..cf159922c51daa 100644 --- a/tensorflow/core/grappler/mutable_graph_view.cc +++ b/tensorflow/core/grappler/mutable_graph_view.cc @@ -386,8 +386,8 @@ void MutableGraphView::AddAndDedupFanouts(NodeDef* node) { fanouts()[output].emplace(node, Graph::kControlSlot); } else { max_input_port = pos; - max_regular_output_port()[output.node] = - std::max(max_regular_output_port()[output.node], output.port_id); + int& max_port = max_regular_output_port()[output.node]; + max_port = std::max(max_port, output.port_id); fanouts()[output].emplace(node, pos); } ++pos; diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index e0981fe90c8ae9..2bf4de1ba86033 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -444,7 +444,7 @@ bool IsQuantizedMatMul(const NodeDef& node) { } bool IsQueue(const NodeDef& node) { - return str_util::EndsWith(node.op(), "QueueV2"); + return absl::EndsWith(node.op(), "QueueV2"); } bool IsRandomShuffle(const NodeDef& node) { diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index bfed9693e0dcbe..6f867024bb9000 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -8,7 +8,6 @@ package( "//tensorflow/core/data:__pkg__", "//tensorflow/core/data/service:__pkg__", "//tensorflow/core/grappler/optimizers/data:__subpackages__", - "//tensorflow/core/kernels/data:__pkg__", "//tensorflow/core/kernels/data/experimental:__pkg__", ], licenses = ["notice"], @@ -1022,7 +1021,7 @@ tf_cc_test( "//tensorflow/core/platform:status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc index f9d4f063618811..a212e250510002 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc @@ -178,7 +178,8 @@ NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name, NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name, StringPiece num_parallel_calls_node_name, StringPiece function_name, - StringPiece deterministic) { + StringPiece deterministic, + bool use_unbounded_threadpool) { return test::function::NDef( name, "ParallelMapDatasetV2", {string(input_node_name), string(num_parallel_calls_node_name)}, @@ -188,6 +189,7 @@ NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name, {"output_shapes", absl::Span{}}, {"output_types", absl::Span{}}, {"deterministic", string(deterministic)}, + {"use_unbounded_threadpool", use_unbounded_threadpool}, }); } diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h index 7341329ac36030..c5823d1a38607c 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h @@ -89,7 +89,8 @@ NodeDef MakeParallelMapNode(StringPiece name, StringPiece input_node_name, NodeDef MakeParallelMapV2Node(StringPiece name, StringPiece input_node_name, StringPiece num_parallel_calls_node_name, StringPiece function_name, - StringPiece deterministic); + StringPiece deterministic, + bool use_unbounded_threadpool); // Creates a test NodeDef for ParseExampleDataset. NodeDef MakeParseExampleNode(StringPiece name, StringPiece input_node_name, diff --git a/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc index 54bf5fe97a4732..5cb93faf3365fa 100644 --- a/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc +++ b/tensorflow/core/grappler/optimizers/data/inject_io_prefetch_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/function_testlib.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace grappler { @@ -101,7 +101,7 @@ GraphDef EligibleMapCase() { {{"value", 1}, {"dtype", DT_INT32}}), graph_tests_utils::MakeParallelMapV2Node( "map_1", "io_1", "num_parallel_calls_1", "noop_1", - /*deterministic=*/"default"), + /*deterministic=*/"default", /*use_unbounded_threadpool=*/false), NDef("files_2", "Const", {}, {{"value", "file1file2"}, {"dtype", DT_STRING}}), @@ -114,7 +114,7 @@ GraphDef EligibleMapCase() { {{"value", 1}, {"dtype", DT_INT32}}), graph_tests_utils::MakeParallelMapV2Node( "map_2", "io_2", "num_parallel_calls_2", "noop_2", - /*deterministic=*/"default"), + /*deterministic=*/"default", /*use_unbounded_threadpool=*/false), NDef("zip", "ZipDataset", {"map_1", "map_2"}, {}), NDef("Sink", "Identity", {"zip"}, {})}, diff --git a/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc b/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc index 1b76fee5103640..1ff66f3dfec09c 100644 --- a/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc +++ b/tensorflow/core/grappler/optimizers/data/make_deterministic_test.cc @@ -85,7 +85,7 @@ TEST_P(SplitMapTest, SplitMapFunction) { } else { orig_map_node_def = graph_tests_utils::MakeParallelMapV2Node( "map", "range", "num_parallel_calls", "MyFunction", - deterministic ? "true" : "false"); + deterministic ? "true" : "false", /*use_unbounded_threadpool=*/false); } orig_map_node_def.add_input("^start"); AttrValue* attr_val = &(*orig_map_node_def.mutable_attr())["Targuments"]; @@ -321,7 +321,8 @@ TEST_P(MakeDeterministicTest, NoRewriteMap) { {{"value", 1}, {"dtype", DT_INT32}}), graph_tests_utils::MakeParallelMapV2Node( "map", "range", "num_parallel_calls", func_name, - deterministic ? "true" : "false")}, + deterministic ? "true" : "false", + /*use_unbounded_threadpool=*/false)}, // FunctionLib {test::function::XTimesTwo(), OuterXTimesTwo()}); @@ -387,7 +388,8 @@ TEST_P(MakeDeterministicTest, NoRewritePrefetch) { {{"value", Tensor(int64_t{1})}, {"dtype", DT_INT64}}), graph_tests_utils::MakeParallelMapV2Node( "map", "range", "num_parallel_calls", func_name, - deterministic ? "true" : "false"), + deterministic ? "true" : "false", + /*use_unbounded_threadpool=*/false), graph_tests_utils::MakePrefetchNode("prefetch", "map", "buffer_size")}, // FunctionLib {test::function::RandomUniform(), OuterRandomUniform()}); @@ -485,7 +487,7 @@ TEST_P(RewriteMapWithoutSplitTest, RewriteMapWithoutSplit) { NodeDef map_node_def = graph_tests_utils::MakeParallelMapV2Node( "map", "range", "num_parallel_calls", func_name, - deterministic ? "true" : "false"); + deterministic ? "true" : "false", /*use_unbounded_threadpool=*/false); map_node_def.add_input("^start"); // Rewrite occurs due to parallelism in map function @@ -587,7 +589,8 @@ TEST_P(MakeDeterministicTest, RewritePrefetch) { {{"value", Tensor(int64_t{1})}, {"dtype", DT_INT64}}), graph_tests_utils::MakeParallelMapV2Node( "map", "range", "num_parallel_calls", func_name, - deterministic ? "true" : "false"), + deterministic ? "true" : "false", + /*use_unbounded_threadpool=*/false), graph_tests_utils::MakePrefetchNode("prefetch", "map", "buffer_size")}, // FunctionLib {test::function::ReadResourceVariable(), OuterReadResourceVariable()}); diff --git a/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc b/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc index 207a9cdb447598..bf1542022d9d74 100644 --- a/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc +++ b/tensorflow/core/grappler/optimizers/data/make_sloppy_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -147,7 +146,7 @@ TEST(ChangeDefault, ParallelMap) { {{"value", 1}, {"dtype", DT_INT32}}), graph_tests_utils::MakeParallelMapV2Node( "map", "range", "num_parallel_calls", "XTimesTwo", - /*deterministic=*/"default")}, + /*deterministic=*/"default", /*use_unbounded_threadpool=*/false)}, // FunctionLib { test::function::XTimesTwo(), diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index 69943e81044728..091e94dc2305d0 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -124,6 +124,11 @@ Status MapAndBatchFusion::OptimizeAndCollectStats(Cluster* cluster, if (node2->op() != "MapDataset" && !IsParallelMap(*node2)) { continue; } + // Do not fuse ParallelMap node that uses the unbounded thread pool. + if (node2->attr().find("use_unbounded_threadpool") != node2->attr().end() && + node2->attr().at("use_unbounded_threadpool").b()) { + continue; + } // Use a more descriptive variable name now that we know the node type. NodeDef* map_node = node2; diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc index 74947cbb5e669b..077123ebf61184 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc @@ -402,6 +402,71 @@ TEST(MapAndBatchFusionTest, NoChange) { EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output)); } +TEST(MapAndBatchFusionTest, NoChange_UnboundedThreadpoolParallelMap) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + NodeDef *start_node = graph_utils::AddScalarConstNode(0, &graph); + NodeDef *stop_node = graph_utils::AddScalarConstNode(10, &graph); + NodeDef *step_node = graph_utils::AddScalarConstNode(1, &graph); + + std::vector range_inputs(3); + range_inputs[0] = start_node->name(); + range_inputs[1] = stop_node->name(); + range_inputs[2] = step_node->name(); + std::vector> range_attrs; + NodeDef *range_node = graph_utils::AddNode("", "RangeDataset", range_inputs, + range_attrs, &graph); + NodeDef *captured_input_node = + graph_utils::AddScalarConstNode("hello", &graph); + NodeDef *num_parallel_calls_node = + graph_utils::AddScalarConstNode(2, &graph); + + NodeDef *map_node; + { + std::vector map_inputs(3); + map_inputs[0] = range_node->name(); + map_inputs[1] = captured_input_node->name(); + map_inputs[2] = num_parallel_calls_node->name(); + std::vector> map_attrs(3); + AttrValue f_attr; + SetAttrValue("f", &f_attr); + map_attrs[0] = std::make_pair("f", f_attr); + AttrValue args_attr; + SetAttrValue("Targuments", &args_attr); + map_attrs[1] = std::make_pair("Targuments", args_attr); + AttrValue use_unbounded_threadpool_attr; + SetAttrValue(true, &use_unbounded_threadpool_attr); + map_attrs[2] = std::make_pair("use_unbounded_threadpool", + use_unbounded_threadpool_attr); + map_node = graph_utils::AddNode("", "ParallelMapDataset", map_inputs, + map_attrs, &graph); + } + + NodeDef *batch_size_node = + graph_utils::AddScalarConstNode(5, &graph); + NodeDef *batch_node; + { + std::vector batch_inputs(2); + batch_inputs[0] = map_node->name(); + batch_inputs[1] = batch_size_node->name(); + std::vector> batch_attrs(2); + AttrValue shapes_attr; + SetAttrValue("output_shapes", &shapes_attr); + batch_attrs[0] = std::make_pair("output_shapes", shapes_attr); + AttrValue types_attr; + SetAttrValue("output_types", &types_attr); + batch_attrs[1] = std::make_pair("output_types", types_attr); + batch_node = graph_utils::AddNode("", "BatchDataset", batch_inputs, + batch_attrs, &graph); + } + + MapAndBatchFusion optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_TRUE(graph_utils::Compare(*graph.graph(), output)); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index d2bf6a3ea27c0c..78e9eba0fdc07d 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -216,10 +216,22 @@ Status MapFusion::OptimizeAndCollectStats(Cluster* cluster, for (const NodeDef& node : sorted_old_graph.node()) { const NodeDef* map_node = get_map_node(node); if (!map_node) continue; + // Do not fuse ParallelMap node that uses the unbounded thread pool. + if (map_node->attr().find("use_unbounded_threadpool") != + map_node->attr().end() && + map_node->attr().at("use_unbounded_threadpool").b()) { + continue; + } const NodeDef* parent_map_node = get_map_node(*graph_utils::GetInputNode(*map_node, graph)); if (!parent_map_node) continue; + // Do not fuse ParallelMap node that uses the unbounded thread pool. + if (parent_map_node->attr().find("use_unbounded_threadpool") != + parent_map_node->attr().end() && + parent_map_node->attr().at("use_unbounded_threadpool").b()) { + continue; + } // TODO(b/148614504): Support fusing different types of map operations. if (parent_map_node->op() != map_node->op()) continue; diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc index c81191ecd823df..a773d9bcf1a1ff 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -28,7 +29,6 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" namespace tensorflow { @@ -88,9 +88,11 @@ TEST_P(AutotuneSetting, MapFusionTest) { NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), num_parallel_calls_node, MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(), - "XTimesTwo", "default"), + "XTimesTwo", "default", + /*use_unbounded_threadpool=*/false), MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(), - "XTimesTwo", "default")}, + "XTimesTwo", "default", + /*use_unbounded_threadpool=*/false)}, // FunctionLib { test::function::XTimesTwo(), @@ -171,9 +173,11 @@ TEST(MapFusionTest, FuseTwoParallelMapNodesIntoOne) { NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), num_parallel_calls_node, MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(), - "XTimesTwo", "default"), + "XTimesTwo", "default", + /*use_unbounded_threadpool=*/false), MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(), - "XTimesTwo", "default")}, + "XTimesTwo", "default", + /*use_unbounded_threadpool=*/false)}, // FunctionLib { test::function::XTimesTwo(), @@ -187,6 +191,36 @@ TEST(MapFusionTest, FuseTwoParallelMapNodesIntoOne) { EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output)); } +TEST(MapFusionTest, NoChange_UnboundedThreadpoolParallelMap) { + using test::function::NDef; + GrapplerItem item; + NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper( + "num_parallel_calls", DT_INT64, + [](TensorProto* proto) { proto->add_int64_val(-1); }); + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + num_parallel_calls_node, + MakeParallelMapV2Node("map1", "range", num_parallel_calls_node.name(), + "XTimesTwo", "default", + /*use_unbounded_threadpool=*/true), + MakeParallelMapV2Node("map2", "map1", num_parallel_calls_node.name(), + "XTimesTwo", "default", + /*use_unbounded_threadpool=*/false)}, + // FunctionLib + { + test::function::XTimesTwo(), + }); + + MapFusion optimizer; + GraphDef output; + TF_ASSERT_OK(OptimizeWithMapFusion(item, &output, true)); + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output)); + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map2", output)); +} + TEST(MapFusionTest, FusedNodesAndFunctionsAreNamedAfterOldNodesAndFunctions) { using test::function::NDef; NodeDef num_parallel_calls_node = CreateScalarConstNodeHelper( @@ -209,10 +243,11 @@ TEST(MapFusionTest, FusedNodesAndFunctionsAreNamedAfterOldNodesAndFunctions) { num_parallel_calls_node, MakeParallelMapV2Node(parent_map_node_name, "range", num_parallel_calls_node.name(), - parent_function_name, "default"), + parent_function_name, "default", + /*use_unbounded_threadpool=*/false), MakeParallelMapV2Node(map_node_name, parent_map_node_name, num_parallel_calls_node.name(), function_name, - "default")}, + "default", /*use_unbounded_threadpool=*/false)}, // FunctionLib {parent_fn, fn}); }; diff --git a/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc b/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc index 25d86a12ed8cdb..2060b0ed4c83e8 100644 --- a/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc +++ b/tensorflow/core/grappler/optimizers/data/remove_compression_map_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/graph.pb.h" @@ -25,7 +26,6 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/platform/status_matchers.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { @@ -74,7 +74,8 @@ TEST(RemoveCompressionMap, Success) { /*input_node_name=*/"RangeDataset/_3", /*num_parallel_calls_node_name=*/"Const/_4", /*function_name=*/"__inference_Dataset_map_lambda_10", - /*deterministic=*/"default"), + /*deterministic=*/"default", + /*use_unbounded_threadpool=*/false), NDef("dataset", // name "_Retval", // op diff --git a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc index 076357accef923..eba3fced1876c4 100644 --- a/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc +++ b/tensorflow/core/grappler/optimizers/data/seq_interleave_prefetch_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/platform/status.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace grappler { diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index 7a02b8283e752e..5ddff709e7435c 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -3085,7 +3085,7 @@ class XlaCpuJitDisableFusionTest : public RemapperTest { } Remapper optimizer(RewriterConfig::ON, RewriterConfig::NO_CONVERSION_ON_CPU, - /*xla_clustering_on=*/true); + /*xla_auto_clustering_on=*/true); GraphDef output; TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); diff --git a/tensorflow/core/grappler/utils/pattern_utils_test.cc b/tensorflow/core/grappler/utils/pattern_utils_test.cc index 6b8f6894d1e23b..22fe41b1b6feed 100644 --- a/tensorflow/core/grappler/utils/pattern_utils_test.cc +++ b/tensorflow/core/grappler/utils/pattern_utils_test.cc @@ -184,7 +184,7 @@ TEST_F(PatternMatcherTest, Tree) { bool all_indices_matched = true; for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin(); it++) { - auto label = str_util::StripPrefix(it->first, "my_"); + auto label = absl::StripPrefix(it->first, "my_"); int matched_node_idx = it->second; int expected_node_idx = graph_view.GetNode(label)->node_index(); if (matched_node_idx != expected_node_idx) { @@ -268,7 +268,7 @@ TEST_F(PatternMatcherTest, DAG) { bool all_indices_matched = true; for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin(); it++) { - auto label = str_util::StripPrefix(it->first, "my_"); + auto label = absl::StripPrefix(it->first, "my_"); int matched_node_idx = it->second; int expected_node_idx = graph_view.GetNode(label)->node_index(); if (matched_node_idx != expected_node_idx) { @@ -387,7 +387,7 @@ TEST_F(PatternMatcherTest, MatMulBiasAddGelu) { bool all_indices_matched = true; for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin(); it++) { - auto label = str_util::StripPrefix(it->first, "my_"); + auto label = absl::StripPrefix(it->first, "my_"); int matched_node_idx = it->second; int expected_node_idx = graph_view.GetNode(label)->node_index(); if (matched_node_idx != expected_node_idx) { @@ -561,7 +561,7 @@ TEST_F(PatternMatcherTest, CommutativeInputs) { bool all_indices_matched = true; for (auto it = matched_nodes_map.begin(); it != matched_nodes_map.begin(); it++) { - auto label = str_util::StripPrefix(it->first, "my_"); + auto label = absl::StripPrefix(it->first, "my_"); int matched_node_idx = it->second; int expected_node_idx = graph_view.GetNode(label)->node_index(); if (matched_node_idx != expected_node_idx) { diff --git a/tensorflow/core/ir/importexport/convert_attributes.cc b/tensorflow/core/ir/importexport/convert_attributes.cc index 24ab1d2c12cba2..dee8e7eb4c21d5 100644 --- a/tensorflow/core/ir/importexport/convert_attributes.cc +++ b/tensorflow/core/ir/importexport/convert_attributes.cc @@ -417,9 +417,10 @@ absl::StatusOr ConvertAttribute( default: return InvalidArgument("Unsupported attr kind in FullType"); } - - return FullTypeAttr::get(builder.getContext(), full_type.type_id(), args, - attr); + IntegerAttr type_id_attr = + mlir::IntegerAttr::get(mlir::IntegerType::get(builder.getContext(), 32), + static_cast(full_type.type_id())); + return FullTypeAttr::get(builder.getContext(), type_id_attr, args, attr); } absl::StatusOr ConvertAttribute( @@ -447,7 +448,8 @@ absl::StatusOr ConvertAttribute( mlir::debugString(full_type.getAttr())); } - ret.set_type_id(static_cast(full_type.getTypeId())); + ret.set_type_id( + static_cast(full_type.getTypeId().getInt())); return ret; } diff --git a/tensorflow/core/ir/types/attributes.td b/tensorflow/core/ir/types/attributes.td index c0af7de6f12b8e..3215c52212a90d 100644 --- a/tensorflow/core/ir/types/attributes.td +++ b/tensorflow/core/ir/types/attributes.td @@ -299,7 +299,7 @@ def TFType_FullTypeId : I32EnumAttr<"FullTypeId", "", [ I32EnumAttrCase<"TFT_LEGACY_VARIANT", 10203, "legacy_variant"> ]> { let cppNamespace = "::mlir::tf_type"; - string cppType = "int32_t"; + string cppType = "::mlir::IntegerAttr"; let genSpecializedAttr = 0; } @@ -320,7 +320,7 @@ def TFType_FullTypeAttr : AttrDef { let parameters = (ins TFType_FullTypeId:$type_id, TFType_FullTypeArgsAttr:$args, - TFType_FullTypeAttrAttr:$attr + "Attribute":$attr ); let mnemonic = "full_type"; let hasCustomAssemblyFormat = 1; diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc index 481c9f0f055204..db175cfa089936 100644 --- a/tensorflow/core/ir/types/dialect.cc +++ b/tensorflow/core/ir/types/dialect.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/DialectImplementation.h" // from @llvm-project @@ -259,8 +260,11 @@ FailureOr RawFullTypeAttrParser(AsmParser& parser) { // Parse variable 'attr' Attribute attr; parser.parseOptionalAttribute(attr); - return FullTypeAttr::get(parser.getContext(), static_cast(*type_id), - args, attr); + return FullTypeAttr::get( + parser.getContext(), + mlir::IntegerAttr::get(mlir::IntegerType::get(parser.getContext(), 32), + static_cast(*type_id)), + args, attr); } Attribute FullTypeAttr::parse(AsmParser& parser, Type odsType) { @@ -271,7 +275,8 @@ Attribute FullTypeAttr::parse(AsmParser& parser, Type odsType) { } static void RawFullTypeAttrPrint(FullTypeAttr tfattr, AsmPrinter& printer) { - printer << stringifyFullTypeId(tf_type::FullTypeId(tfattr.getTypeId())); + printer << stringifyFullTypeId( + tf_type::FullTypeId(tfattr.getTypeId().getInt())); if (!tfattr.getArgs().empty()) { printer << "<"; llvm::interleaveComma(tfattr.getArgs(), printer, [&](Attribute arg) { @@ -366,17 +371,10 @@ void ShapeAttr::print(AsmPrinter& os) const { os << "<"; if (hasRank()) { auto print_dim = [&](int64_t dim) { - if (dim != ShapedType::kDynamic) { - if (dim == 0) { - // In order to avoid the parseInteger below from confusing a dimension - // list with '0x' as hex integer, we use 00 for a 0 sized dimension. - os << "00"; - } else { - os << dim; - } - } else { + if (dim != ShapedType::kDynamic) + os << dim; + else os << "?"; - } }; llvm::interleave(getShape(), os, print_dim, "x"); } else { @@ -405,7 +403,7 @@ Attribute ShapeAttr::parse(AsmParser& parser, Type type) { llvm::SMLoc loc = parser.getCurrentLocation(); if (succeeded(parser.parseOptionalQuestion())) { shape.back() = ShapedType::kDynamic; - } else if (failed(parser.parseInteger(shape.back()))) { + } else if (failed(parser.parseDecimalInteger(shape.back()))) { parser.emitError(loc) << "expected an integer or `?` when parsing a tf.shape attribute"; return failure(); diff --git a/tensorflow/core/ir/types/dialect_test.cc b/tensorflow/core/ir/types/dialect_test.cc index 84a301a93ad9d5..4fb014dcb92403 100644 --- a/tensorflow/core/ir/types/dialect_test.cc +++ b/tensorflow/core/ir/types/dialect_test.cc @@ -62,7 +62,7 @@ TEST(TFTypesDialect, TestFuncAttrSubElement) { TEST(TFTypesDialect, ParsesDimensionListWithZero) { // Test that a dimension list with zero can be parsed. const char *const code = R"mlir( - "test.op"() {shape = #tf_type.shape<00x128>} : () -> () + "test.op"() {shape = #tf_type.shape<0x128>} : () -> () )mlir"; MLIRContext context; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 41cea6aa50402a..8995e1898a753a 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -713,6 +713,7 @@ cc_library( "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler", "//tensorflow/core/kernels/batching_util:batch_resource_base", "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs", + "//tensorflow/core/kernels/batching_util:batch_scheduler_utils", "//tensorflow/core/kernels/batching_util:bounded_executor", "//tensorflow/core/kernels/batching_util:concat_split_util", "//tensorflow/core/kernels/batching_util:periodic_function_dynamic", @@ -1659,7 +1660,7 @@ cc_library( tf_cc_test( name = "batch_kernels_test", - size = "medium", + size = "small", srcs = ["batch_kernels_test.cc"], features = ["-layering_check"], deps = [ @@ -4694,7 +4695,7 @@ tf_kernel_library( "spacetobatch_functor.h", "spacetobatch_functor_gpu.cu.cc", ], - visibility = [":friends"], + visibility = ["//visibility:private"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -4737,7 +4738,7 @@ tf_kernel_library( "spacetodepth_op.h", "spacetodepth_op_gpu.cu.cc", ], - visibility = [":friends"], + visibility = ["//visibility:private"], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc index 4a6bfe0bf046ff..31a40b1d5b9662 100644 --- a/tensorflow/core/kernels/aggregate_ops.cc +++ b/tensorflow/core/kernels/aggregate_ops.cc @@ -186,7 +186,7 @@ class AddNOp : public OpKernel { TF_RETURN_IF_ERROR( BinaryOpVariants(ctx, ADD_VARIANT_BINARY_OP, a, b, c)); temp_filled->at(lhs_ix) = true; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/batch_kernels.cc b/tensorflow/core/kernels/batch_kernels.cc index 8e1e97dc2565f9..bd93c1ec3a02a3 100644 --- a/tensorflow/core/kernels/batch_kernels.cc +++ b/tensorflow/core/kernels/batch_kernels.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/batch_resource_base.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" #include "tensorflow/core/kernels/batching_util/bounded_executor.h" #include "tensorflow/core/kernels/batching_util/concat_split_util.h" #include "tensorflow/core/kernels/batching_util/periodic_function.h" @@ -177,7 +178,8 @@ class BatchResource : public serving::BatchResourceBase { /*mixed_priority_batching_policy=*/ serving::MixedPriorityBatchingPolicy:: kLowPriorityPaddingWithMaxBatchSize, - enable_large_batch_splitting, resource); + enable_large_batch_splitting, + /*batch_padding_policy=*/"PAD_UP", resource); } static Status Create( @@ -190,7 +192,7 @@ class BatchResource : public serving::BatchResourceBase { int32_t low_priority_max_enqueued_batches, const std::vector& low_priority_allowed_batch_sizes, serving::MixedPriorityBatchingPolicy mixed_priority_batching_policy, - bool enable_large_batch_splitting, + bool enable_large_batch_splitting, absl::string_view batch_padding_policy, std::unique_ptr* resource) { BatcherT::Options batcher_options; batcher_options.num_batch_threads = num_batch_threads; @@ -203,8 +205,8 @@ class BatchResource : public serving::BatchResourceBase { num_batch_threads, max_execution_batch_size, batch_timeout_micros, max_enqueued_batches, allowed_batch_sizes, enable_large_batch_splitting, - /*disable_padding=*/false, low_priority_max_batch_size, - low_priority_batch_timeout_micros, + /*disable_padding=*/false, batch_padding_policy, + low_priority_max_batch_size, low_priority_batch_timeout_micros, low_priority_max_enqueued_batches, low_priority_allowed_batch_sizes, mixed_priority_batching_policy), allowed_batch_sizes)); @@ -439,7 +441,7 @@ void BatchFunctionKernel::ComputeAsync(OpKernelContext* c, DoneCallback done) { low_priority_batch_timeout_micros_, low_priority_max_enqueued_batches_, low_priority_allowed_batch_sizes_, mixed_priority_batching_policy, enable_large_batch_splitting_, - &new_resource)); + batch_padding_policy_, &new_resource)); if (session_metadata) { new_resource->set_session_metadata(*session_metadata); } diff --git a/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc b/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc index e3601cff28b527..7e66f9b26726f4 100644 --- a/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc +++ b/tensorflow/core/kernels/batch_kernels_auto_warmup_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/batch_kernels.h" - #include #include #include @@ -22,6 +20,7 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/device_factory.h" #include "tensorflow/core/framework/function.h" @@ -30,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/batch_kernel_test_util.h" +#include "tensorflow/core/kernels/batch_kernels.h" #include "tensorflow/core/kernels/batching_util/warmup.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/platform/env.h" @@ -37,7 +37,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/version.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/errors.h" #include "tsl/platform/refcount.h" diff --git a/tensorflow/core/kernels/batch_kernels_env_test.cc b/tensorflow/core/kernels/batch_kernels_env_test.cc index 508c0e8699763c..5a8bfec9f90d46 100644 --- a/tensorflow/core/kernels/batch_kernels_env_test.cc +++ b/tensorflow/core/kernels/batch_kernels_env_test.cc @@ -14,11 +14,11 @@ limitations under the License. ==============================================================================*/ #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/kernels/batch_kernel_test_util.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/kernels/batch_kernels_test.cc b/tensorflow/core/kernels/batch_kernels_test.cc index 62666c099518fd..9aaeb5ad5207c8 100644 --- a/tensorflow/core/kernels/batch_kernels_test.cc +++ b/tensorflow/core/kernels/batch_kernels_test.cc @@ -17,12 +17,15 @@ limitations under the License. #include #include +#include #include #include #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/device_factory.h" #include "tensorflow/core/framework/function.h" @@ -39,12 +42,13 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/version.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/criticality.h" #include "tsl/platform/errors.h" #include "tsl/platform/refcount.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" +#include "tsl/protobuf/error_codes.pb.h" namespace tensorflow { namespace { @@ -84,19 +88,13 @@ class SharedBatchFunctionTestState : public OpsTestBase { return absl::OkStatus(); }}); } -}; - -class BatchFunctionTestState : public SharedBatchFunctionTestState { - public: - // Init test fixture with a batch kernel instance. The caller guarantees that - // the device pointer is valid throughout the life of this class. - absl::Status Init(Device *device, bool enable_low_priority_queue, - absl::string_view mixed_priority_policy, - int64_t expected_batch_size) { - // Override the per-test/per-op device with a given device so that it can - // be shared between ops. - device_ = device; + protected: + // Create common batch function op for testing. + absl::StatusOr CreateBatchFunctionBuilder( + const std::vector &allowed_batch_sizes, int max_batch_size, + absl::string_view padding_policy, + const TensorShape &expected_output_shape) { NameAttrList f; f.set_name("ShapeEnforcingFunction"); FunctionDef func = FunctionDefHelper::Create( @@ -112,8 +110,7 @@ class BatchFunctionTestState : public SharedBatchFunctionTestState { {{{"o"}, "EnsureShape", {"x"}, - {{"T", DataType::DT_INT64}, - {"shape", TensorShape({expected_batch_size, 2})}}}}, + {{"T", DataType::DT_INT64}, {"shape", expected_output_shape}}}}, // ret_def {{"o", "o:output"}}); TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func)); @@ -121,13 +118,40 @@ class BatchFunctionTestState : public SharedBatchFunctionTestState { std::vector inputs( {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})}); - TF_RETURN_IF_ERROR(NodeDefBuilder("BatchTPUInput", "BatchFunction") - .Attr("max_batch_size", 8) - .Attr("num_batch_threads", 8) - .Attr("allowed_batch_sizes", {4, 8}) - .Attr("batch_timeout_micros", 1000000) - .Attr("max_enqueued_batches", 10) - .Attr("enable_large_batch_splitting", true) + return NodeDefBuilder(absl::StrCat("BatchTPUInput", padding_policy), + "BatchFunction") + .Attr("max_batch_size", max_batch_size) + .Attr("num_batch_threads", 8) + .Attr("allowed_batch_sizes", allowed_batch_sizes) + .Attr("batch_timeout_micros", 1000000) + .Attr("max_enqueued_batches", 10) + .Attr("enable_large_batch_splitting", true) + .Attr("batch_padding_policy", padding_policy) + .Attr("Tin", {DataType::DT_INT64}) + .Input(inputs) + .Attr("Tcaptured", std::vector{}) + .Input(std::vector{}) + .Attr("Tout", std::vector{DT_INT64}) + .Attr("f", f); + } +}; + +class BatchFunctionTestState : public SharedBatchFunctionTestState { + public: + // Init test fixture with a batch kernel instance. The caller guarantees that + // the device pointer is valid throughout the life of this class. + absl::Status Init(Device *device, bool enable_low_priority_queue, + absl::string_view mixed_priority_policy, + int64_t expected_batch_size) { + // Override the per-test/per-op device with a given device so that it can + // be shared between ops. + device_ = device; + + const TensorShape expected_output_shape({expected_batch_size, 2}); + TF_ASSIGN_OR_RETURN( + NodeDefBuilder builder, + CreateBatchFunctionBuilder({4, 8}, 8, "PAD_UP", expected_output_shape)); + TF_RETURN_IF_ERROR(builder .Attr("low_priority_max_batch_size", enable_low_priority_queue ? 8 : 0) .Attr("low_priority_batch_timeout_micros", @@ -139,14 +163,8 @@ class BatchFunctionTestState : public SharedBatchFunctionTestState { .Attr("low_priority_max_enqueued_batches", enable_low_priority_queue ? 2 : 0) .Attr("mixed_priority_policy", mixed_priority_policy) - .Attr("batch_padding_policy", "PAD_UP") - .Attr("Tin", {DataType::DT_INT64}) - .Input(inputs) - .Attr("Tcaptured", std::vector{}) - .Input(std::vector{}) - .Attr("Tout", std::vector{DT_INT64}) - .Attr("f", f) .Finalize(node_def())); + return OpsTestBase::InitOp(); } @@ -576,48 +594,13 @@ class BatchFunctionKernelParallelWarmupTestState // be shared between ops. device_ = cpu_device; - NameAttrList f; - f.set_name("BatchFunctionKernelParallelWarmupTestStateFunc"); - FunctionDef func = FunctionDefHelper::Create( - // function_name - f.name(), - // in_def - {"x:int64"}, - // out_def - {"o:int64"}, - // attr_def - {}, - // node_def - {{{"o"}, - "EnsureShape", - {"x"}, - {{"T", DataType::DT_INT64}, {"shape", TensorShape({2})}}}}, - // ret_def - {{"o", "o:output"}}); - TF_RETURN_IF_ERROR(flib_def_->AddFunctionDef(func)); - SharedBatchFunctionTestState::CreateFunctionLibraryRuntime(); + const TensorShape expected_output_shape({2}); + TF_ASSIGN_OR_RETURN( + NodeDefBuilder builder, + CreateBatchFunctionBuilder({2, 4, 8}, enable_splitting ? 16 : 8, + "PAD_UP", expected_output_shape)); + TF_RETURN_IF_ERROR(builder.Finalize(node_def())); - std::vector inputs( - {NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})}); - TF_RETURN_IF_ERROR(NodeDefBuilder("BatchTPUInput", "BatchFunction") - .Attr("max_batch_size", enable_splitting ? 16 : 8) - .Attr("num_batch_threads", 8) - .Attr("allowed_batch_sizes", {2, 4, 8}) - .Attr("batch_timeout_micros", 1000000) - .Attr("max_enqueued_batches", 10) - .Attr("enable_large_batch_splitting", true) - .Attr("low_priority_max_batch_size", 64) - .Attr("low_priority_batch_timeout_micros", 8000) - .Attr("low_priority_allowed_batch_sizes", {32, 64}) - .Attr("low_priority_max_enqueued_batches", 1000) - .Attr("batch_padding_policy", "PAD_UP") - .Attr("Tin", {DataType::DT_INT64}) - .Input(inputs) - .Attr("Tcaptured", std::vector{}) - .Input(std::vector{}) - .Attr("Tout", std::vector{DT_INT64}) - .Attr("f", f) - .Finalize(node_def())); return OpsTestBase::InitOp(); } @@ -688,5 +671,80 @@ INSTANTIATE_TEST_SUITE_P(BatchFunctionKernelParallelWarmupTestSuite, BatchFunctionKernelParallelWarmupTest, ::testing::Bool()); +class BatchFunctionKernelPaddingTestState + : public SharedBatchFunctionTestState { + public: + // Init test fixture with a batch kernel instance. + absl::Status Init(absl::string_view padding_policy, int expected_batch_size) { + static auto *const cpu_device = []() { + auto device = + DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0"); + return device.release(); + }(); + + // Override the per-test/per-op device with a global device so that it can + // be shared between ops. + device_ = cpu_device; + + const TensorShape expected_output_shape({expected_batch_size, 2}); + TF_RETURN_IF_ERROR(CreateBatchFunctionBuilder({4, 8}, 8, padding_policy, + expected_output_shape) + ->Finalize(node_def())); + + return OpsTestBase::InitOp(); + } + + void TestBody() override {} +}; + +class BatchFunctionKernelPaddingTest + : public ::testing::TestWithParam {}; + +TEST_P(BatchFunctionKernelPaddingTest, PadUp) { + SessionMetadata session_metadata; + session_metadata.set_name("test_model"); + session_metadata.set_version(123); + + // Send 5 requests in parallel and check that the given batch padding + // policy behaves as expected. + int64_t num_requests = 5; + int64_t expected_batch_size = 0; + std::string padding_policy = GetParam(); + if (padding_policy == "PAD_UP") { + expected_batch_size = 8; + } else if (padding_policy == "BATCH_DOWN") { + expected_batch_size = 4; + } else if (padding_policy == "MINIMIZE_TPU_COST_PER_REQUEST") { + expected_batch_size = 8; + } else { + FAIL() << "Unsupported padding policy: " << padding_policy; + } + + { + tsl::BlockingCounter blocking_counter(num_requests); + for (int i = 0; i < num_requests; ++i) { + Env::Default()->SchedClosure([&]() { + BatchFunctionKernelPaddingTestState test_state; + test_state.set_session_metadata(session_metadata); + TF_CHECK_OK(test_state.Init(padding_policy, expected_batch_size)); + test_state.AddInputFromList(TensorShape({1, 2}), {123, 456}); + TF_EXPECT_OK(test_state.RunOpKernel()); + + test::ExpectTensorEqual( + *test_state.GetOutput(0), + test::AsTensor({123, 456}, TensorShape({1, 2}))); + blocking_counter.DecrementCount(); + }); + } + + blocking_counter.Wait(); + } +} + +INSTANTIATE_TEST_SUITE_P(BatchFunctionKernelPaddingTestSuite, + BatchFunctionKernelPaddingTest, + ::testing::Values("PAD_UP", "BATCH_DOWN", + "MINIMIZE_TPU_COST_PER_REQUEST")); + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index a136eeb1fab768..41736329123967 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -37,8 +37,8 @@ cc_library( hdrs = ["batch_stats.h"], deps = [ "//tensorflow/core:framework_lite", + "//tensorflow/core:portable_gif_internal", "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/time", ], ) @@ -121,10 +121,10 @@ cc_library( "//tensorflow/core/lib/core:notification", "//tensorflow/core/lib/core:status", "//tensorflow/core/platform:thread_annotations", - "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:criticality", + "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -134,12 +134,12 @@ cc_library( hdrs = ["batch_scheduler.h"], deps = [ "//tensorflow/core:lib", - "//tensorflow/core/profiler/lib:traceme", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:criticality", + "@local_tsl//tsl/profiler/lib:traceme", ], ) @@ -148,8 +148,12 @@ cc_library( srcs = ["batch_scheduler_utils.cc"], hdrs = ["batch_scheduler_utils.h"], deps = [ + ":batch_scheduler_hdrs", + ":batch_stats", "//tensorflow/core:portable_gif_internal", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", ], ) @@ -183,7 +187,10 @@ tf_cc_test( name = "batch_scheduler_utils_test", srcs = ["batch_scheduler_utils_test.cc"], deps = [ + ":batch_scheduler_hdrs", ":batch_scheduler_utils", + ":batch_stats", + "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", ], ) @@ -195,6 +202,7 @@ cc_library( ":batch_input_task", ":batch_scheduler_hdrs", ":batch_scheduler_utils", + ":batch_stats", ":periodic_function_dynamic", "//tensorflow/core:framework_lite", "//tensorflow/core:lib", @@ -209,7 +217,6 @@ cc_library( "//tensorflow/core/profiler/lib:context_types_hdrs", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme_encode", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", @@ -225,13 +232,13 @@ cc_library( ":batch_input_task", ":batch_scheduler", ":batch_scheduler_utils", + ":batch_stats", ":periodic_function_dynamic", "//tensorflow/core:lib", "//tensorflow/core/profiler/lib:connected_traceme", "//tensorflow/core/profiler/lib:context_types_hdrs", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme_encode", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", @@ -246,6 +253,7 @@ tf_cc_test( srcs = ["shared_batch_scheduler_test.cc"], deps = [ ":batch_scheduler", + ":batch_scheduler_utils", ":fake_clock_env", ":shared_batch_scheduler", "//tensorflow/core:lib", @@ -481,18 +489,30 @@ tf_cc_test( srcs = ["batch_resource_base_test.cc"], deps = [ ":batch_resource_base", + ":batch_scheduler_hdrs", + ":batch_scheduler_utils", ":batch_stats", + ":shared_batch_scheduler", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core/common_runtime:cost_constants", "//tensorflow/core/common_runtime:cost_measurement", "//tensorflow/core/common_runtime:cost_measurement_registry", "//tensorflow/core/common_runtime:no_op_cost_measurement", "//tensorflow/core/common_runtime:request_cost", "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/kernels:batch_kernels", + "//tensorflow/core/lib/monitoring:cell_reader", + "//tensorflow/core/platform:notification", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:criticality", + "@local_tsl//tsl/platform:status", ], ) diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index e74fbbbb308248..81cbe417123074 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -233,6 +233,16 @@ void RecordBatchParamMaxBatchSize(int64_t max_batch_size, cell->GetCell(model_name, op_name)->Set(max_batch_size); } +void RecordBatchParamPaddingPolicy(const string& batch_padding_policy, + const string& model_name, + const string& op_name) { + static auto* cell = monitoring::Gauge::New( + "/tensorflow/serving/batching/configured_batch_padding_policy", + "The value of BatchFunction.batch_padding_policy attribute.", + "model_name", "op_name"); + cell->GetCell(model_name, op_name)->Set(batch_padding_policy); +} + void RecordBatchParamMaxEnqueuedBatches(int64_t max_enqueued_batches, const string& model_name, const string& op_name) { @@ -406,6 +416,9 @@ Status BatchResourceBase::RegisterInput( RecordBatchParamMaxEnqueuedBatches( batcher_queue_options_.max_enqueued_batches, GetModelName(context), context->op_kernel().name()); + RecordBatchParamPaddingPolicy( + this->batcher_queue_options_.batch_padding_policy, + GetModelName(context), context->op_kernel().name()); } else if (adaptive_batcher_) { RecordBatchParamBatchTimeoutMicros( adaptive_batcher_queue_options_.batch_timeout_micros, @@ -472,8 +485,10 @@ Status BatchResourceBase::RegisterInput( } BatcherQueueT* batcher_queue; - TF_RETURN_IF_ERROR( - LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue)); + TF_RETURN_IF_ERROR(LookupOrCreateBatcherQueue( + /* queue_name= */ batcher_queue_name, + /* model_name= */ GetModelName(context), + /* op_name= */ context->op_kernel().name(), /* queue= */ &batcher_queue)); if (!session_metadata().name().empty()) { absl::MutexLock lock(&outstanding_batch_mu_); @@ -500,7 +515,9 @@ BatchResourceBase::GetBatcherQueueOptions( return GetBatcherQueueOptions( num_batch_threads, max_batch_size, batch_timeout_micros, max_enqueued_batches, allowed_batch_sizes, enable_large_batch_splitting, - disable_padding, /*low_priority_max_batch_size=*/0, + disable_padding, + /*batch_padding_policy=*/kPadUpPolicy, + /*low_priority_max_batch_size=*/0, /*low_priority_batch_timeout_micros=*/0, /*low_priority_max_enqueued_batches=*/0, /*low_priority_allowed_batch_sizes=*/{}, @@ -514,7 +531,7 @@ BatchResourceBase::GetBatcherQueueOptions( int32_t batch_timeout_micros, int32_t max_enqueued_batches, const std::vector& allowed_batch_sizes, bool enable_large_batch_splitting, bool disable_padding, - int32_t low_priority_max_batch_size, + absl::string_view batch_padding_policy, int32_t low_priority_max_batch_size, int32_t low_priority_batch_timeout_micros, int32_t low_priority_max_enqueued_batches, const std::vector& low_priority_allowed_batch_sizes, @@ -523,6 +540,8 @@ BatchResourceBase::GetBatcherQueueOptions( batcher_queue_options.input_batch_size_limit = max_batch_size; batcher_queue_options.max_enqueued_batches = max_enqueued_batches; batcher_queue_options.batch_timeout_micros = batch_timeout_micros; + batcher_queue_options.batch_padding_policy = + std::string(batch_padding_policy); if (low_priority_max_batch_size > 0) { batcher_queue_options.enable_priority_queue = true; } @@ -1172,9 +1191,9 @@ void BatchResourceBase::ProcessBatchCallBack( } } -// Looks up the batcher queue for 'queue_name'. If it didn't previously exist, -// creates it. Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, + const string& model_name, + const string& op_name, BatcherQueueT** queue) { mutex_lock l(batcher_queues_mu_); @@ -1186,8 +1205,12 @@ Status BatchResourceBase::LookupOrCreateBatcherQueue(const string& queue_name, std::unique_ptr new_queue; if (batcher_) { + BatcherT::QueueOptions batcher_queue_options = batcher_queue_options_; + batcher_queue_options.model_batch_stats = &GlobalBatchStatsRegistry().model( + /* model_name= */ model_name, /* op_name= */ op_name); + TF_RETURN_IF_ERROR(batcher_->AddQueue( - batcher_queue_options_, + batcher_queue_options, absl::bind_front(&BatchResourceBase::ProcessBatchCallBack, this), &new_queue)); } else if (adaptive_batcher_) { @@ -1241,16 +1264,12 @@ void BatchResourceBase::SplitBatchCostsAndRecordMetrics( absl::StrCat(cost_type, kNoSmearSuffix), total_cost / processed_size * batch.size()); + // Register batch stats for in-process use. if (cost_type == kTpuCostName) { - // Get the model stats object for the current model name and op name. - ModelBatchStats& model_stats = GlobalBatchStats().model( + ModelBatchStats& model_stats = GlobalBatchStatsRegistry().model( /* model_name= */ model_name, /* op_name= */ op_name); - - // Register TPU cost for in-process use. model_stats.batch_size(processed_size).tpu_cost().Register(total_cost); - - // Register cumulative size of processed non-padding jobs for in-process - // use. + // batch.size() is the size of the original batch before padding. model_stats.RegisterProcessedSize(batch.size()); } diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.h b/tensorflow/core/kernels/batching_util/batch_resource_base.h index e8b3926cca92c4..c50b29f3d1b3ed 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.h +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/blocking_counter.h" #include "tensorflow/core/common_runtime/cost_measurement_registry.h" #include "tensorflow/core/common_runtime/request_cost.h" @@ -34,6 +35,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" #include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/threadsafe_status.h" #include "tensorflow/core/platform/context.h" @@ -52,6 +54,7 @@ struct BatchResourceOptions { int32_t batch_timeout_micros; int32_t max_enqueued_batches; std::vector allowed_batch_sizes; + std::string batch_padding_policy{kPadUpPolicy}; int32_t low_priority_max_batch_size; int32_t low_priority_batch_timeout_micros; int32_t low_priority_max_enqueued_batches; @@ -213,6 +216,7 @@ class BatchResourceBase : public ResourceBase { int32_t batch_timeout_micros, int32_t max_enqueued_batches, const std::vector& allowed_batch_sizes, bool enable_large_batch_splitting, bool disable_padding, + absl::string_view batch_padding_policy, int32_t low_priority_max_batch_size, int32_t low_priority_batch_timeout_micros, int32_t low_priority_max_enqueued_batches, @@ -332,9 +336,14 @@ class BatchResourceBase : public ResourceBase { static Status EmitIndexTensor(OpKernelContext* context, const BatchT& batch, int output_index); - // Looks up the batcher queue for 'queue_name'. If it did't previously exist, + // Looks up the batcher queue for 'queue_name'. If it didn't previously exist, // creates it. + // + // The model_name and op_name are the names of the current model and + // operation, respectively. Status LookupOrCreateBatcherQueue(const string& queue_name, + const string& model_name, + const string& op_name, BatcherQueueT** queue); SessionMetadata session_metadata_; diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc index 7a635f1798098c..fa4fad932cfca0 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc @@ -16,22 +16,40 @@ limitations under the License. #include "tensorflow/core/kernels/batching_util/batch_resource_base.h" #include +#include #include +#include #include #include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "tensorflow/core/common_runtime/cost_constants.h" #include "tensorflow/core/common_runtime/cost_measurement.h" #include "tensorflow/core/common_runtime/cost_measurement_registry.h" #include "tensorflow/core/common_runtime/request_cost.h" +#include "tensorflow/core/framework/device.h" +#include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" #include "tensorflow/core/kernels/batching_util/batch_stats.h" +#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" +#include "tensorflow/core/lib/monitoring/cell_reader.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/public/version.h" #include "tsl/platform/criticality.h" +#include "tsl/platform/status.h" namespace tensorflow { namespace serving { @@ -331,7 +349,7 @@ TEST(SplitBatchCostsAndRecordMetricsTest, UpdatesGlobalBatchStats) { /* model_name= */ kModelName, /* op_name= */ "op_name", batch_cost_measurements, /* processed_size= */ 17, batch); - EXPECT_EQ(GlobalBatchStats() + EXPECT_EQ(GlobalBatchStatsRegistry() .model(/* model_name= */ kModelName, /* op_name= */ "op_name") .batch_size(17) .tpu_cost() @@ -365,7 +383,7 @@ TEST(SplitBatchCostsAndRecordMetricsTest, GlobalBatchStatsProcessedSize) { // Get the original cumulative processed size. int original_cumulative_processed_size = - GlobalBatchStats() + GlobalBatchStatsRegistry() .model(/* model_name= */ kModelName, /* op_name= */ "op_name") .cumulative_processed_size(); @@ -377,7 +395,7 @@ TEST(SplitBatchCostsAndRecordMetricsTest, GlobalBatchStatsProcessedSize) { // that even though the batch size is 17, there is only one non-padding task, // so the cumulative processed size should be // original_cumulative_processed_size + 1. - EXPECT_EQ(GlobalBatchStats() + EXPECT_EQ(GlobalBatchStatsRegistry() .model(/* model_name= */ kModelName, /* op_name= */ "op_name") .cumulative_processed_size(), original_cumulative_processed_size + 1); @@ -394,12 +412,216 @@ TEST(SplitBatchCostsAndRecordMetricsTest, GlobalBatchStatsProcessedSize) { batch_cost_measurements, /* processed_size= */ 8, batch2); // Expect the cumulative processed size to be updated correctly. - EXPECT_EQ(GlobalBatchStats() + EXPECT_EQ(GlobalBatchStatsRegistry() .model(/* model_name= */ kModelName, /* op_name= */ "op_name") .cumulative_processed_size(), original_cumulative_processed_size + 4); } +class BatchResourceBaseTest : public ::testing::Test { + protected: + // Like BatchResourceBase but overrides abstract methods, one of which + // notifies the exposed process_func_batch_called() notification. + class MyBatchResource : public BatchResourceBase { + public: + using BatchResourceBase::BatchResourceBase; + + std::string DebugString() const override { return ""; } + + void ProcessFuncBatchImpl( + const BatchResourceBase::BatchTask& /* last_task */, + absl::Span /* inputs */, + std::vector* /* combined_outputs */, + std::function /* done */) const override { + process_func_batch_called_.Notify(); + } + + Notification& process_func_batch_called() { + return process_func_batch_called_; + } + + private: + mutable Notification process_func_batch_called_; + }; + + BatchResourceBaseTest() { + // The whole point of this test fixture is to create a usable batch function + // context, context_. + + // Create device_. + device_ = DeviceFactory::NewDevice("CPU", SessionOptions{}, + "/job:a/replica:0/task:0"); + + // Create batch_kernel_node_def. + NodeDefBuilder batch_function_builder("my_batch_node", "BatchFunction"); + batch_function_builder.Attr("max_batch_size", 128); + batch_function_builder.Attr("num_batch_threads", 8); + batch_function_builder.Attr("allowed_batch_sizes", {2, 4, 8}); + batch_function_builder.Attr("batch_timeout_micros", 100); + batch_function_builder.Attr("max_enqueued_batches", 100); + batch_function_builder.Attr("enable_large_batch_splitting", true); + std::vector input_dtypes = {DataType::DT_INT64, + DataType::DT_INT64}; + std::vector inputs; + inputs.push_back(NodeDefBuilder::NodeOut({"n1", 0, DataType::DT_INT64})); + inputs.push_back(NodeDefBuilder::NodeOut({"n2", 1, DataType::DT_INT64})); + batch_function_builder.Attr("Tin", input_dtypes); + batch_function_builder.Input(inputs); + batch_function_builder.Attr("Tcaptured", {DataType::DT_INT64}); + batch_function_builder.Input(std::vector{ + NodeDefBuilder::NodeOut({"n3", 1, DataType::DT_INT64})}); + batch_function_builder.Attr("Tout", {DataType::DT_INT64}); + NameAttrList f; + f.set_name("func_to_batch"); + batch_function_builder.Attr("f", f); + NodeDef batch_kernel_node_def; + TF_CHECK_OK(batch_function_builder.Finalize(&batch_kernel_node_def)); + + // Create batch_kernel_. + absl::Status op_kernel_creation_status; + batch_kernel_ = + CreateOpKernel(DEVICE_CPU, device_.get(), device_->GetAllocator({}), + batch_kernel_node_def, TF_GRAPH_DEF_VERSION, + &op_kernel_creation_status); + TF_CHECK_OK(op_kernel_creation_status); + CHECK(batch_kernel_ != nullptr); + + // Create input tensors. + input_tensor_ = Tensor(DataType::DT_INT64, TensorShape({5, 2, 1})); + input_tensor_values_ = { + TensorValue(&input_tensor_), + TensorValue(&input_tensor_), + TensorValue(&input_tensor_), + }; + + // Fill-in session_metadata_. + session_metadata_.set_name("my_model_name"); + + // Fill-in params_. + params_.device = device_.get(); + params_.op_kernel = batch_kernel_.get(); + params_.inputs = input_tensor_values_; + params_.session_metadata = &session_metadata_; + + // Create context_. + context_ = std::make_unique(¶ms_); + } + + std::unique_ptr device_; + + std::unique_ptr batch_kernel_; + + Tensor input_tensor_; + std::vector input_tensor_values_; + + SessionMetadata session_metadata_; + + OpKernelContext::Params params_; + + std::unique_ptr context_; +}; + +TEST_F(BatchResourceBaseTest, PassesCorrectModelBatchStatsToSbs) { + using BatchTask = BatchResourceBase::BatchTask; + using SharedBatchScheduler = SharedBatchScheduler; + + // Like SharedBatchScheduler but exposes the last QueueOptions passed to + // AddQueue as queue_options(). + class MySharedBatchScheduler : public SharedBatchScheduler { + public: + MySharedBatchScheduler() : SharedBatchScheduler::SharedBatchScheduler({}) {} + + absl::Status AddQueue( + const QueueOptions& options, + ProcessBatchCallback process_batch_callback, + std::unique_ptr>* queue) override { + queue_options_ = options; + return SharedBatchScheduler::AddQueue(options, process_batch_callback, + queue); + } + + const QueueOptions& queue_options() const { return queue_options_; } + + private: + QueueOptions queue_options_; + }; + + auto batcher = std::make_shared(); + + MyBatchResource* my_batch_resource = new MyBatchResource( + /* has_process_batch_function */ true, + /* batcher= */ batcher, + /* batcher_queue_options */ {}, + /* allowed_batch_sizes */ {}); + + TF_CHECK_OK(my_batch_resource->RegisterInput( + /* guid= */ + 0, + /* context= */ context_.get(), + /* batcher_queue_name= */ "batcher_queue_name", + /* create_batch_task_fn= */ + []() -> absl::StatusOr> { + return std::make_unique(); + }, + /* done_callback= */ [] {}, /* forced_warmup_batch_size= */ 0)); + + EXPECT_EQ(batcher->queue_options().model_batch_stats, + &GlobalBatchStatsRegistry().model(/* model_name= */ "my_model_name", + /* op_name= */ "my_batch_node")); + + // Wait for the batch timeout to expire and the scheduler to dump the only + // scheduled task back to the batch resource. If we don't do this, the + // scheduler will do this itself on destruction, when the resource has already + // been destroyed. + my_batch_resource->process_func_batch_called().WaitForNotificationWithTimeout( + absl::Seconds(1)); + + // This is how we have to destroy the BatchResource. + my_batch_resource->Unref(); +} + +TEST_F(BatchResourceBaseTest, ConfiguredBatchPaddingPolicyMetric) { + tensorflow::monitoring::testing::CellReader metric( + "/tensorflow/serving/batching/configured_batch_padding_policy"); + + std::shared_ptr> batcher; + TF_CHECK_OK( + SharedBatchScheduler::Create({}, &batcher)); + + MyBatchResource* my_batch_resource = new MyBatchResource( + /* has_process_batch_function */ true, + /* batcher= */ batcher, + /* batcher_queue_options */ + MyBatchResource::BatcherT::QueueOptions{ + .batch_padding_policy{kMinimizeTpuCostPerRequestPolicy}, + }, + /* allowed_batch_sizes */ {}); + + TF_CHECK_OK(my_batch_resource->RegisterInput( + /* guid= */ + 0, /* context= */ context_.get(), + /* batcher_queue_name= */ "batcher_queue_name", + /* create_batch_task_fn= */ + []() -> absl::StatusOr> { + return std::make_unique(); + }, + /* done_callback= */ [] {}, /* forced_warmup_batch_size= */ 0)); + + EXPECT_EQ(metric.Read(/* model_name= */ "my_model_name", + /* op_name= */ "my_batch_node"), + kMinimizeTpuCostPerRequestPolicy); + + // Wait for the batch timeout to expire and the scheduler to dump the only + // scheduled task back to the batch resource. If we don't do this, the + // scheduler will do this itself on destruction, when the resource has already + // been destroyed. + my_batch_resource->process_func_batch_called().WaitForNotificationWithTimeout( + absl::Seconds(1)); + + // This is how we have to destroy the BatchResource. + my_batch_resource->Unref(); +} + } // namespace } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler.h b/tensorflow/core/kernels/batching_util/batch_scheduler.h index c70972cc2bf6b4..1c50c552d6fa66 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/batch_scheduler.h @@ -32,7 +32,7 @@ limitations under the License. #include #include #include -#include +#include #include #include #include @@ -43,12 +43,11 @@ limitations under the License. #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/profiler/lib/traceme.h" #include "tsl/platform/criticality.h" +#include "tsl/profiler/lib/traceme.h" namespace tensorflow { namespace serving { @@ -252,7 +251,7 @@ int TaskQueue::size() const { // accept new tasks; a closed one cannot. A batch is monotonic: initially it is // open and tasks can be added to it; then it is closed and its set of tasks // remains fixed for the remainder of its life. A closed batch cannot be re- -// opened. Tasks can never be removed from a batch. +// opened. // // Type parameter TaskType must be a subclass of BatchTask. template @@ -304,6 +303,15 @@ class Batch { // Returns the TraceMe context id of this batch. uint64 traceme_context_id() const; + // Attempts to trim this batch to a new, smaller size (not to be confused with + // the number of tasks in the batch). On success, the trimmed tasks go into + // 'out_trimmed_tasks' in the same order the tasks were in this batch. + // + // The method might not succeed if it needs to split a large task to hit the + // correct size. + void TryTrimToNewSize( + int new_size, std::vector>& out_trimmed_tasks); + private: mutable mutex mu_; @@ -505,6 +513,45 @@ uint64 Batch::traceme_context_id() const { return traceme_context_id_; } +template +void Batch::TryTrimToNewSize( + int new_size, std::vector>& out_trimmed_tasks) { + mutex_lock l(mu_); + DCHECK_GT(new_size, 0); + DCHECK_LT(new_size, size_); + DCHECK(out_trimmed_tasks.empty()); + + // Index of the first task to trim away. It is possible that it is the index + // of a task of size larger than 1 that will have to be split in order to get + // to the target new_size. + int32 first_task_to_move = 0; + // The sum of sizes of tasks i, where i < first_task_to_move. + int32 size_of_previous_tasks = 0; + while (size_of_previous_tasks + tasks_[first_task_to_move]->size() <= + new_size) { + size_of_previous_tasks += tasks_[first_task_to_move]->size(); + first_task_to_move++; + // The loop must always stop before this check is tripped because new_size + // must never be larger than the size of the batch. + DCHECK_LT(first_task_to_move, tasks_.size()); + } + + // Check whether task 'first_task_to_move' will have to be split. + if (size_of_previous_tasks < new_size) { + // TODO: b/325954758 - Consider supporting splitting large tasks and then + // drop 'Try' from the method name. + return; + } + DCHECK_EQ(size_of_previous_tasks, new_size); + + // Actually trim. + out_trimmed_tasks.reserve(tasks_.size() - first_task_to_move); + std::move(tasks_.begin() + first_task_to_move, tasks_.end(), + std::back_inserter(out_trimmed_tasks)); + tasks_.resize(first_task_to_move); + size_ = new_size; +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc index e159c4373fdf90..2f9c9031776373 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_test.cc @@ -21,12 +21,13 @@ limitations under the License. #include #include #include +#include #include #include #include "absl/status/status.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" #include "tsl/platform/criticality.h" @@ -37,6 +38,7 @@ namespace { using ::testing::ElementsAre; using ::testing::Eq; +using ::testing::Pointer; using ::testing::Property; TEST(MixedPriorityBatchingPolicyTest, InvalidAttrValueError) { @@ -386,6 +388,53 @@ TEST(BatchTest, RemoveAllTasks) { EXPECT_THAT(batch.RemoveAllTasks(), ::testing::IsEmpty()); // third call } +TEST(BatchTest, TryTrimToNewSizeTrimsAndReturnsTrimmedElementsInOrder) { + Batch batch; + + auto task0 = new FakeTask(3); + batch.AddTask(std::unique_ptr(task0)); + + auto task1 = new FakeTask(5); + batch.AddTask(std::unique_ptr(task1)); + + auto task2 = new FakeTask(7); + batch.AddTask(std::unique_ptr(task2)); + + auto task3 = new FakeTask(9); + batch.AddTask(std::unique_ptr(task3)); + + std::vector> trimmed_tasks; + batch.TryTrimToNewSize(/* new_size= */ 8, + /* out_trimmed_tasks= */ trimmed_tasks); + + EXPECT_EQ(batch.size(), 8); + EXPECT_EQ(batch.num_tasks(), 2); + + EXPECT_THAT(trimmed_tasks, ElementsAre(Pointer(task2), Pointer(task3))); + + batch.Close(); // Batch::~Batch blocks until the batch is closed. +} + +TEST(BatchTest, TryTrimToNewSizeDoesNotTrimWhenItWouldNeedToSplitATask) { + Batch batch; + + auto task0 = new FakeTask(3); + batch.AddTask(std::unique_ptr(task0)); + + auto task1 = new FakeTask(5); + batch.AddTask(std::unique_ptr(task1)); + + std::vector> trimmed_tasks; + batch.TryTrimToNewSize(/* new_size= */ 4, + /* out_trimmed_tasks= */ trimmed_tasks); + + EXPECT_EQ(batch.size(), 8); + EXPECT_EQ(batch.num_tasks(), 2); + EXPECT_TRUE(trimmed_tasks.empty()); + + batch.Close(); // Batch::~Batch blocks until the batch is closed. +} + } // namespace } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h index 7e4382a9d862db..9a6deb1a530208 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils.h @@ -16,8 +16,15 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_ #define TENSORFLOW_CORE_KERNELS_BATCHING_UTIL_BATCH_SCHEDULER_UTILS_H_ +#include +#include #include +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -36,6 +43,114 @@ int GetPrevAllowedBatchSize(int batch_size, const std::vector& allowed_batch_sizes, bool disable_padding); +// Constants containing possible values for the batch_padding_policy argument +// of MaybeBatchDown. This argument specifies the policy that a batch scheduler +// is using when deciding what to do when, say, 18 requests need to be batched, +// but only 16 and 32 batch sizes are allowed. The following options are +// available. +// +// - PAD_UP: pad to size 32. +// - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the +// batch buffer. +// - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses +// to either PAD_UP or BATCH_DOWN so as to minimize the TPU costs per +// real request. In this case, it would compare (batch_16_cost / 16) and +// (batch_32_cost / 18). +// +inline constexpr absl::string_view kBatchDownPolicy = "BATCH_DOWN"; +inline constexpr absl::string_view kPadUpPolicy = "PAD_UP"; +inline constexpr absl::string_view kMinimizeTpuCostPerRequestPolicy = + "MINIMIZE_TPU_COST_PER_REQUEST"; + +// Trims the batch to the next allowed batch size when possible and when +// configured by batch_padding_policy. +// +// When trimming, this function puts the trimmed tasks go into the +// out_trimmed_tasks vector in the same order as they were in the batch. +template +void MaybeBatchDown(Batch& batch, + const std::vector& allowed_batch_sizes, + bool disable_padding, + absl::string_view batch_padding_policy, + ModelBatchStats* model_batch_stats, + std::vector>& out_trimmed_tasks) { + if (batch_padding_policy == kPadUpPolicy) { + // This is the default behavior of batch resource when it is given a batch + // size that doesn't match any of the allowed batch sizes. + return; + } + bool minimize_tpu_cost_per_request; + if (batch_padding_policy == kBatchDownPolicy) { + minimize_tpu_cost_per_request = false; + } else if (batch_padding_policy == kMinimizeTpuCostPerRequestPolicy) { + if (model_batch_stats == nullptr) { + LOG_FIRST_N(ERROR, 1) + << kMinimizeTpuCostPerRequestPolicy + << " batch padding policy has been chosen " + "but no ModelBatchStats passed to the batch scheduler; will " + "fall back on the " + << kPadUpPolicy << " policy."; + return; + } + minimize_tpu_cost_per_request = true; + } else { + LOG_FIRST_N(ERROR, 1) << "Unsupported batch_padding_policy: " + << batch_padding_policy << ", falling back on the " + << kPadUpPolicy << " policy."; + return; + } + + int32 batch_size = batch.size(); + + int32 pad_up_size = + GetNextAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding); + if (pad_up_size == batch_size) { + return; // Good, no padding is necessary. + } + + int32 batch_down_size = + GetPrevAllowedBatchSize(batch_size, allowed_batch_sizes, disable_padding); + if (batch_down_size == batch_size) { + return; // Can't batch down (e.g. no smaller batch size available). + } + + if (minimize_tpu_cost_per_request) { + // TODO: b/325954758 - Consider logging a warning here or elsewhere if + // a larger batch doesn't cost meaningfully cheaper than a smaller batch. + // TODO: b/325954758 - Consider logging a warning here or elsewhere if a + // smaller batch costs unreasonably cheaper than a larger one (assuming + // a batch cost model = constant_cost + batch_size * per_element_cost). + // TODO: b/325954758 - Consider occasionally picking either batch size so + // that we learn fresh costs of each batch size. For this code, it is not a + // large priority though because if we are in between two allowed batch + // sizes (say, 16 and 32), chances are that will occasionally organically + // get batches of exact sizes 16 and 32 (and then we pick those + // unconditionally). But if we explicitly occasionally explored other batch + // sizes, we wouldn't have to rely on this "chances are". For other + // applications of batch costs, we might also want to occasionally explore + // all allowed batch sizes and not just 16 and 32 from this example. + std::optional down_batch_cost = + model_batch_stats->batch_size(batch_down_size).tpu_cost().mean(); + std::optional up_batch_cost = + model_batch_stats->batch_size(pad_up_size).tpu_cost().mean(); + if (!down_batch_cost.has_value() || !up_batch_cost.has_value()) { + // We have no data about batch costs, let's just do nothing. + return; + } + + auto batch_down_cost_per_request = *down_batch_cost / batch_down_size; + auto pad_up_cost_per_request = *up_batch_cost / batch_size; + + if (pad_up_cost_per_request < batch_down_cost_per_request) { + // Abort batching down because it's cheaper to pad up. + return; + } + } + + // Batch down. + batch.TryTrimToNewSize(batch_down_size, out_trimmed_tasks); +} + } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc index 2bff515a57aeb9..e45cb46e29646c 100644 --- a/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_scheduler_utils_test.cc @@ -15,7 +15,14 @@ limitations under the License. #include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" +#include +#include +#include + #include +#include "absl/time/time.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" namespace tensorflow { namespace serving { @@ -66,6 +73,208 @@ TEST(GetPrevAllowedBatchSizeTest, GreaterThanMaxAllowedBatchSize) { EXPECT_EQ(GetPrevAllowedBatchSize(10, {2, 4, 8}, false), 8); } +class FakeTask : public BatchTask { + public: + explicit FakeTask(size_t size) : size_(size) {} + + size_t size() const override { return size_; } + + private: + const size_t size_; +}; + +TEST(MaybeBatchDownTest, PadUp) { + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kPadUpPolicy, + /* model_batch_stats= */ nullptr, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + // The batch must stay unchanged (for the batch resource to then pad it to the + // next allowed batch size, thus ending up in a pad-up behavior.) + EXPECT_EQ(batch.size(), 3); +} + +TEST(MaybeBatchDownTest, BatchDown) { + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kBatchDownPolicy, + /* model_batch_stats= */ nullptr, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + // The scheduler should trim the batch to a smaller allowed size that requires + // no padding. + EXPECT_EQ(batch.size(), 2); + // The trimmed part. + EXPECT_EQ(out_trimmed_tasks.size(), 1); +} + +TEST(MaybeBatchDownTest, BatchDownDoesNotSplitTasks) { + // Add tasks for size 3, but the second task is large and will have to be + // split if doing batch-down. + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(2)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kBatchDownPolicy, + /* model_batch_stats= */ nullptr, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + // The batch must stay unchanged due the fact that the current implementation + // doesn's support splitting large tasks. + EXPECT_EQ(batch.size(), 3); +} + +TEST(MaybeBatchDownTest, BatchDownDoesNothingWhenTheBatchSizeIsAlreadyAllowed) { + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {1, 2, 4, 8}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kBatchDownPolicy, + /* model_batch_stats= */ nullptr, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + // The batch should stay unchanged because it's already of an allowed size. + EXPECT_EQ(batch.size(), 4); +} + +TEST(MaybeBatchDownTest, BatchDownDoesNothingWhenNoSmallerAllowedSize) { + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {4, 8}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kBatchDownPolicy, + /* model_batch_stats= */ nullptr, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + // Can't batch down because there is no smaller allowed size. + EXPECT_EQ(batch.size(), 3); +} + +TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestPicksBatchDown) { + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + ModelBatchStats model_batch_stats; + model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2)); + model_batch_stats.batch_size(4).tpu_cost().Register(absl::Seconds(3.1)); + + std::vector> out_trimmed_tasks; + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy, + /* model_batch_stats= */ &model_batch_stats, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + EXPECT_EQ(batch.size(), 2); +} + +TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestPicksPadUp) { + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + ModelBatchStats model_batch_stats; + model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2)); + model_batch_stats.batch_size(4).tpu_cost().Register(absl::Seconds(2.9)); + + std::vector> out_trimmed_tasks; + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy, + /* model_batch_stats= */ &model_batch_stats, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + EXPECT_EQ(batch.size(), 3); +} + +TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestIsOkWithMissingCosts) { + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + ModelBatchStats model_batch_stats; + model_batch_stats.batch_size(2).tpu_cost().Register(absl::Seconds(2)); + // Not adding costs for batch 4. + + std::vector> out_trimmed_tasks; + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy, + /* model_batch_stats= */ &model_batch_stats, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + // No expectations as we do not expect a particular behavior. We just care + // that we don't crash. +} + +TEST(MaybeBatchDownTest, MinimizeTpuCostPerRequestDoesPadUpWhenNoModelStats) { + Batch batch; + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.AddTask(std::make_unique(1)); + batch.Close(); + + std::vector> out_trimmed_tasks; + MaybeBatchDown( + /* batch= */ batch, /* allowed_batch_sizes= */ {2, 4}, + /* disable_padding= */ false, + /* batch_padding_policy= */ kMinimizeTpuCostPerRequestPolicy, + /* model_batch_stats= */ nullptr, + /* out_trimmed_tasks= */ out_trimmed_tasks); + + EXPECT_EQ(batch.size(), 3); +} + } // namespace } // namespace serving diff --git a/tensorflow/core/kernels/batching_util/batch_stats.h b/tensorflow/core/kernels/batching_util/batch_stats.h index 9aefb743bac8f3..87c36fca0c02a1 100644 --- a/tensorflow/core/kernels/batching_util/batch_stats.h +++ b/tensorflow/core/kernels/batching_util/batch_stats.h @@ -23,11 +23,11 @@ limitations under the License. // The classes defined here are not supposed to be instantiated by the user. // Instead, this file provides a single entry point: // -// BatchStats& GlobalBatchStats(); +// BatchStatsRegistry& GlobalBatchStatsRegistry(); // // For example, to register batch cost, do: // -// GlobalBatchStats() +// GlobalBatchStatsRegistry() // .model(/* model_name= */ "m", /* op_name= */ "o") // .batch_size(4) // .tpu_cost @@ -36,7 +36,7 @@ limitations under the License. // To get the mean cost later, do: // // std::optional cost = -// .GlobalBatchStats() +// .GlobalBatchStatsRegistry() // .model(/* model_name= */ "m", /* op_name= */ "o") // .batch_size(4) // .tpu_cost @@ -58,14 +58,18 @@ limitations under the License. #include #include "absl/container/node_hash_map.h" -#include "absl/log/check.h" #include "absl/time/time.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tsl/platform/thread_annotations.h" namespace tensorflow::serving { +// Default values for when there is no recorded statistic in ModelBatchStats. +constexpr int64_t kNumBatchThreadsUnknown = -1; +constexpr int64_t kBatchTimeoutMicrosUnknown = -1; + // Tracks the average cost of registered samples. // // Thread-safe. @@ -167,6 +171,23 @@ class ModelBatchStats { return result; } + void SetNumBatchThreads(int64_t num_batch_threads) { + num_batch_threads_.store(num_batch_threads, std::memory_order_relaxed); + } + + int64_t num_batch_threads() const { + return num_batch_threads_.load(std::memory_order_relaxed); + } + + void SetBatchTimeoutMicros(int64_t batch_timeout_micros) { + batch_timeout_micros_.store(batch_timeout_micros, + std::memory_order_relaxed); + } + + int64_t batch_timeout_micros() const { + return batch_timeout_micros_.load(std::memory_order_relaxed); + } + private: mutable mutex mu_; @@ -184,12 +205,19 @@ class ModelBatchStats { // Can be used to generate an internal load metric per model. See // RegisterQuerySize for more details. std::atomic cumulative_processed_size_ = 0; + + // The number of batch threads assigned to this model. + std::atomic num_batch_threads_ = kNumBatchThreadsUnknown; + + // The timeout in microseconds for this model (after which the current batch + // is sent to be processed by the TPU). + std::atomic batch_timeout_micros_ = kBatchTimeoutMicrosUnknown; }; // Tracks batch statistics for all models. // // Thread-safe. -class BatchStats { +class BatchStatsRegistry { public: // Returns a reference to ModelBatchStats for the provided model_name and // op_name. @@ -236,8 +264,8 @@ class BatchStats { // Returns the global instance of BatchStats, to use used for all production // purposes (one should only instantiate individual classes from this file to // test them). -inline BatchStats& GlobalBatchStats() { - static BatchStats* instance = new BatchStats(); +inline BatchStatsRegistry& GlobalBatchStatsRegistry() { + static BatchStatsRegistry* instance = new BatchStatsRegistry(); return *instance; } diff --git a/tensorflow/core/kernels/batching_util/batch_stats_test.cc b/tensorflow/core/kernels/batching_util/batch_stats_test.cc index 96152777b30d37..5f5168cc24af1a 100644 --- a/tensorflow/core/kernels/batching_util/batch_stats_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_stats_test.cc @@ -27,12 +27,12 @@ namespace { using ::testing::UnorderedElementsAre; -TEST(BatchStatsTest, GlobalBatchStatsAlwaysReturnsTheSameInstance) { - ASSERT_EQ(&GlobalBatchStats(), &GlobalBatchStats()); +TEST(BatchStatsTest, GlobalBatchStatsRegistryAlwaysReturnsTheSameInstance) { + ASSERT_EQ(&GlobalBatchStatsRegistry(), &GlobalBatchStatsRegistry()); } TEST(BatchStatsTest, BasicOperation) { - BatchStats stats; + BatchStatsRegistry stats; stats.model(/* model_name= */ "m", /* op_name= */ "o") .batch_size(1) .tpu_cost() @@ -45,7 +45,7 @@ TEST(BatchStatsTest, BasicOperation) { } TEST(BatchStatsTest, ModelBatchStatsAreUniqueForEachModel) { - BatchStats stats; + BatchStatsRegistry stats; ASSERT_NE(&stats.model(/* model_name= */ "m", /* op_name= */ "o"), &stats.model(/* model_name= */ "m", /* op_name= */ "o2")); } @@ -79,7 +79,7 @@ TEST(BatchStatsTest, ProcessedSizeIsCorrect) { } TEST(BatchStatsTest, ModelOpNamesAreCorrect) { - BatchStats stats; + BatchStatsRegistry stats; // Register a cost for model "m" and op "o". stats.model(/* model_name= */ "m", /* op_name= */ "o") @@ -126,6 +126,28 @@ TEST(BatchStatsTest, BatchSizesAreCorrect) { ASSERT_THAT(stats.BatchSizes(), UnorderedElementsAre(1, 2, 4)); } +TEST(BatchStatsTest, BatchTimeoutIsCorrect) { + ModelBatchStats stats; + + // Originally the batch timeout is -1 if unassigned. + ASSERT_EQ(stats.batch_timeout_micros(), -1); + + // Assign a batch timeout of 100 microseconds. + stats.SetBatchTimeoutMicros(100); + ASSERT_EQ(stats.batch_timeout_micros(), 100); +} + +TEST(BatchStatsTest, NumBatchThreadsIsCorrect) { + ModelBatchStats stats; + + // Originally the number of batch threads is -1 if unassigned. + ASSERT_EQ(stats.num_batch_threads(), -1); + + // Assign a number of per-model batch threads. + stats.SetNumBatchThreads(16); + ASSERT_EQ(stats.num_batch_threads(), 16); +} + } // namespace } // namespace tensorflow::serving diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index 93ac0c922fb404..acea6496288ffd 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -29,13 +29,13 @@ limitations under the License. #include #include -#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/time/clock.h" #include "tensorflow/core/kernels/batching_util/batch_input_task.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" #include "tensorflow/core/kernels/batching_util/periodic_function.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" @@ -43,6 +43,7 @@ limitations under the License. #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -149,7 +150,7 @@ class SharedBatchScheduler const Options& options, std::shared_ptr>* scheduler); - ~SharedBatchScheduler(); + virtual ~SharedBatchScheduler(); // Adds a queue to which tasks may be submitted. The returned queue implements // the BatchScheduler API. Each queue has its own set of scheduling options, @@ -240,6 +241,18 @@ class SharedBatchScheduler // If true, the padding will not be appended. bool disable_padding = false; + // The padding policy to use. + // + // See the documentation for kPadUpPolicy for details. + string batch_padding_policy = string(kPadUpPolicy); + + // A pointer to a ModelBatchStats instance for this model. To be used for + // cost-based padding policy selection. + // + // If null, some other padding policy will be used if a cost-based one is + // requested. + ModelBatchStats* model_batch_stats = nullptr; + // If true, queue implementation would split high priority and low priority // inputs into two sub queues. bool enable_priority_queue = false; @@ -270,13 +283,15 @@ class SharedBatchScheduler MixedPriorityBatchingPolicy mixed_priority_batching_policy = MixedPriorityBatchingPolicy::kLowPriorityPaddingWithMaxBatchSize; }; - Status AddQueue(const QueueOptions& options, - ProcessBatchCallback process_batch_callback, - std::unique_ptr>* queue); + // This method is marked virtual for testing purposes only. + virtual Status AddQueue(const QueueOptions& options, + ProcessBatchCallback process_batch_callback, + std::unique_ptr>* queue); - private: + protected: explicit SharedBatchScheduler(const Options& options); + private: void GetNextWorkItem_Locked(internal::Queue** queue_for_batch_out, BatchUniquePtr* batch_to_process_out) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -588,6 +603,9 @@ class Queue { // The time at which the first task was added to the open (back-most) batch // in 'high_priority_batches_'. Valid iff that batch contains at least one // task. + // + // Note that when using a batch padding policy other than PAD_UP, this field + // might contain an approximate value (see ScheduleBatchWithEagerSplit). uint64 open_batch_start_time_micros_ TF_GUARDED_BY(mu_); // Whether this queue contains a batch that is eligible to be scheduled. @@ -920,7 +938,7 @@ Status Queue::Schedule(std::unique_ptr* task) { template Status Queue::ScheduleWithLazySplit(std::unique_ptr* task) { - profiler::TraceMe trace_me([task] { + tsl::profiler::TraceMe trace_me([task] { return profiler::TraceMeEncode( "ScheduleWithLazySplit", {{"batching_input_task_size", (*task)->size()}}); @@ -1055,7 +1073,7 @@ template Status Queue::ScheduleWithoutOrEagerSplit( std::unique_ptr* task) { const bool large_batch_splitting = options_.enable_large_batch_splitting; - profiler::TraceMe trace_me([task, large_batch_splitting] { + tsl::profiler::TraceMe trace_me([task, large_batch_splitting] { return profiler::TraceMeEncode( large_batch_splitting ? "ScheduleWithEagerSplit" : "ScheduleWithoutSplit", @@ -1223,7 +1241,37 @@ Queue::ScheduleBatchWithEagerSplit() { std::deque>>& batches = GetBatches(); // Consider closing the open batch at this time, to schedule it. if (batches.size() == 1 && IsOpenBatchSchedulable()) { + // Support BatchPaddingPolicy::kBatchDown and + // BatchPaddingPolicy::kMinimizeTpuCostPerRequest. We do this before + // starting a new batch because starting a new batch will close the old + // batch, making it read-only. + std::vector> trimmed_tasks; + MaybeBatchDown( + /* batch= */ *batches[0], + /* allowed_batch_sizes= */ options_.allowed_batch_sizes, + /* disable_padding= */ options_.disable_padding, + /* batch_padding_policy= */ options_.batch_padding_policy, + /* model_batch_stats= */ options_.model_batch_stats, + /* out_trimmed_tasks= */ trimmed_tasks); + StartNewBatch(); + + // Move the trimmed tasks, if any, into the new batch. + Batch& new_batch = *batches[1]; + for (std::unique_ptr& task : trimmed_tasks) { + new_batch.AddTask(std::move(task)); + } + if (!new_batch.empty()) { + // TODO - b/325954758: Reconsider the starting time of a trimmed batch. + // + // Ideally, we'd set open_batch_start_time_micros_ to time we received + // the first task, but we don't have this information here, so we're + // using NOW as the timestamp. An alternative solution that doesn't + // require adding time to each task would be to assume that requests + // arrived at a steady rate and therefore use a point between the old + // value of open_batch_start_time_micros_ and NOW. + open_batch_start_time_micros_ = env_->NowMicros(); + } } if (batches.size() >= 2) { diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc index 1c4073a065eef2..2a5afae82a2728 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler_test.cc @@ -27,11 +27,14 @@ limitations under the License. #include "absl/container/fixed_array.h" #include "absl/status/status.h" #include "absl/time/time.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_scheduler_utils.h" #include "tensorflow/core/kernels/batching_util/fake_clock_env.h" #include "tensorflow/core/lib/core/notification.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/status.h" @@ -39,7 +42,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/criticality.h" namespace tensorflow { @@ -1052,6 +1054,79 @@ TEST_P(SharedBatchSchedulerTest, ZeroQueueRewrittenToOneQueue) { } } +TEST_P(SharedBatchSchedulerTest, BatchPaddingPolicyBatchDown) { + if (enable_lazy_split()) { + GTEST_SKIP() + << "BatchPaddingPolicy::kBatchDown is not supported for lazy split."; + } + + // Set up a fake clock, which only advances when we explicitly tell it to. + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + + { + Notification first_batch_processed; + Notification second_batch_processed; + auto callback = [&](std::unique_ptr> batch) { + if (!first_batch_processed.HasBeenNotified()) { + // This is the main expectation of the test. + // + // The scheduler should have trimmed the batch to a smaller allowed + // size which requires no padding. + EXPECT_EQ(batch->size(), 2); + + first_batch_processed.Notify(); + return; + } + + if (!second_batch_processed.HasBeenNotified()) { + // Leftovers after the first batch. + EXPECT_EQ(batch->size(), 1); + + second_batch_processed.Notify(); + return; + } + + ADD_FAILURE() << "Batch callback must not be invoked more than expected"; + }; + + auto scheduler = CreateSharedBatchScheduler(1, &env); + + QueueOptions options = + CreateQueueOptions(/* max_execution_batch_size= */ 10, + /* input_batch_size_limit= */ 10, + /* batch_timeout_micros= */ 10, + /* max_enqueued_batches= */ 10); + + // The most interesting option for this test. + options.allowed_batch_sizes = {1, 2, 4, 8}; + options.batch_padding_policy = kBatchDownPolicy; + + auto queue = CreateQueue(scheduler, options, callback); + + // Schedule some tasks and ensure the scheduler calls the callback after a + // batch timeout has expired. + TF_ASSERT_OK(ScheduleTask(1, queue.get())); + TF_ASSERT_OK(ScheduleTask(1, queue.get())); + TF_ASSERT_OK(ScheduleTask(1, queue.get())); + env.AdvanceByMicroseconds(options.batch_timeout_micros); + first_batch_processed.WaitForNotification(); + + // Ensure the scheduler correctly updates the starting time of the new + // batch. + env.AdvanceByMicroseconds(options.batch_timeout_micros - 1); + EXPECT_FALSE(second_batch_processed.WaitForNotificationWithTimeout( + absl::Milliseconds(10))); + env.AdvanceByMicroseconds(1); + second_batch_processed.WaitForNotification(); + + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + // TODO(b/161857471): // Add test coverage when input-split and no-split returns differently. INSTANTIATE_TEST_SUITE_P( diff --git a/tensorflow/core/kernels/batchtospace_op.cc b/tensorflow/core/kernels/batchtospace_op.cc index 624b136d30a574..50ad9472a39198 100644 --- a/tensorflow/core/kernels/batchtospace_op.cc +++ b/tensorflow/core/kernels/batchtospace_op.cc @@ -64,8 +64,8 @@ static void BatchToSpaceOpCompute(OpKernelContext* context, orig_crops.shape().DebugString())); // To avoid out-of-bounds access in the case that the block_shape and/or // crops tensors are concurrently modified, we must copy the values. - gtl::InlinedVector block_shape; - gtl::InlinedVector crops; + absl::InlinedVector block_shape; + absl::InlinedVector crops; internal::spacetobatch::SubtleMustCopyFlat(orig_block_shape, &block_shape); internal::spacetobatch::SubtleMustCopyFlat(orig_crops, &crops); diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc index b60c5dd763923b..b4959d43d9c5e5 100644 --- a/tensorflow/core/kernels/bcast_ops.cc +++ b/tensorflow/core/kernels/bcast_ops.cc @@ -31,7 +31,7 @@ class BCastArgsOp : public OpKernel { OP_REQUIRES( ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const Tensor& in = ctx->input(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()), @@ -81,7 +81,7 @@ class BCastGradArgsOp : public OpKernel { OP_REQUIRES( ctx, ctx->num_inputs() == 2, errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); - gtl::InlinedVector shapes; + absl::InlinedVector shapes; for (int i = 0; i < ctx->num_inputs(); ++i) { const Tensor& in = ctx->input(i); OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in.shape()), diff --git a/tensorflow/core/kernels/bincount_op.cc b/tensorflow/core/kernels/bincount_op.cc index 1a1e55ed067fd3..d6f8d3dbad9ed0 100644 --- a/tensorflow/core/kernels/bincount_op.cc +++ b/tensorflow/core/kernels/bincount_op.cc @@ -81,7 +81,7 @@ struct BincountFunctor { Eigen::array reduce_dim({0}); output.device(context->eigen_cpu_device()) = partial_bins.any(reduce_dim).cast(); - return OkStatus(); + return absl::OkStatus(); } }; @@ -164,7 +164,7 @@ struct BincountFunctor { Eigen::array reduce_dim({0}); output.device(context->eigen_cpu_device()) = partial_bins.sum(reduce_dim); } - return OkStatus(); + return absl::OkStatus(); } }; @@ -209,7 +209,7 @@ struct BincountReduceFunctor { static_cast(err_neg_val))); } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/bucketize_op.cc b/tensorflow/core/kernels/bucketize_op.cc index 03dc11ffe62ad4..179a930da5790c 100644 --- a/tensorflow/core/kernels/bucketize_op.cc +++ b/tensorflow/core/kernels/bucketize_op.cc @@ -44,7 +44,7 @@ struct BucketizeFunctor { output(i) = first_bigger_it - boundaries_vector.begin(); } - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/tensorflow/core/kernels/checkpoint_callback_manager_test.cc b/tensorflow/core/kernels/checkpoint_callback_manager_test.cc index 250f8436e0ef40..cb39718908d360 100644 --- a/tensorflow/core/kernels/checkpoint_callback_manager_test.cc +++ b/tensorflow/core/kernels/checkpoint_callback_manager_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" @@ -26,7 +27,6 @@ limitations under the License. #include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace checkpoint { diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.cc b/tensorflow/core/kernels/conv_grad_shape_utils.cc index 00aceb02e31f5b..0be69d2689e7be 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.cc +++ b/tensorflow/core/kernels/conv_grad_shape_utils.cc @@ -95,7 +95,7 @@ Status ConvBackpropExtractAndVerifyDimension( Status ConvBackpropComputeDimensionsV2( StringPiece label, int num_spatial_dims, const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& out_backprop_shape, - const absl::Span& dilations, const std::vector& strides, + const absl::Span dilations, const std::vector& strides, Padding padding, absl::Span explicit_paddings, TensorFormat data_format, ConvBackpropDimensions* dims) { // The + 2 in the following line is for the batch and feature dimensions. diff --git a/tensorflow/core/kernels/conv_grad_shape_utils.h b/tensorflow/core/kernels/conv_grad_shape_utils.h index 8d105a9df92e0a..f61f53ee13cc38 100644 --- a/tensorflow/core/kernels/conv_grad_shape_utils.h +++ b/tensorflow/core/kernels/conv_grad_shape_utils.h @@ -44,7 +44,7 @@ struct ConvBackpropSpatialDimension { // Computed dimensions for a backwards convolution. struct ConvBackpropDimensions { // Information about each spatial dimension. - gtl::InlinedVector spatial_dims; + absl::InlinedVector spatial_dims; // Batch size. int64_t batch_size; @@ -80,7 +80,7 @@ Status ConvBackpropComputeDimensions(StringPiece label, int num_spatial_dims, Status ConvBackpropComputeDimensionsV2( StringPiece label, int num_spatial_dims, const TensorShape& input_shape, const TensorShape& filter_shape, const TensorShape& out_backprop_shape, - const absl::Span& dilations, const std::vector& strides, + absl::Span dilations, const std::vector& strides, Padding padding, absl::Span explicit_paddings, TensorFormat data_format, ConvBackpropDimensions* dims); diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 5517ef406fb5ad..c447763924a878 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -121,6 +121,10 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:split_utils", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:mutex", ], ) @@ -851,6 +855,7 @@ tf_kernel_library( "//tensorflow/core/data:dataset_utils", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:stats_utils", + "//tensorflow/core/data:unbounded_thread_pool", "//tensorflow/core/profiler/lib:traceme", "//tensorflow/core/profiler/lib:traceme_encode", "@com_google_absl//absl/base", @@ -1467,10 +1472,20 @@ tf_cc_test( ":tf_record_dataset_op", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/data:dataset_test_base", + "//tensorflow/core/data:name_utils", + "//tensorflow/core/framework:types_proto_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -1484,6 +1499,8 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/data:name_utils", + "//tensorflow/core/framework:dataset_options_proto_cc", + "@com_google_absl//absl/status", ], ) @@ -1495,9 +1512,14 @@ tf_kernel_library( ":window_dataset", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/data:name_utils", + "//tensorflow/core/framework:dataset_options_proto_cc", + "//tensorflow/core/framework:types_proto_cc", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", ], ) @@ -1513,12 +1535,16 @@ tf_cc_test( "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib_internal", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", "//tensorflow/core/data:dataset_test_base", - "//tensorflow/core/data:dataset_utils", + "//tensorflow/core/data:name_utils", "//tensorflow/core/data:serialization_utils", + "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", ], ) @@ -1529,12 +1555,14 @@ tf_kernel_library( deps = [ "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core/data:dataset_utils", - "//tensorflow/core/data:global_shuffle_utils", "//tensorflow/core/data:name_utils", "//tensorflow/core/data:split_utils", + "//tensorflow/core/framework:dataset_options_proto_cc", + "//tensorflow/core/framework:types_proto_cc", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", ], ) @@ -1551,11 +1579,16 @@ tf_cc_test( "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib_internal", + "//tensorflow/core:portable_gif_internal", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", "//tensorflow/core/data:dataset_test_base", - "//tensorflow/core/data:dataset_utils", + "//tensorflow/core/data:name_utils", + "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index 238706787c6b13..cd8acd1042c2da 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -288,17 +288,24 @@ class BatchDatasetOp::Dataset : public DatasetBase { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); + int64_t input_empty; + TF_RETURN_IF_ERROR( + reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty)); + if (ctx->restored_element_count().has_value()) { IteratorContext::Params params(ctx); params.restored_element_count = *ctx->restored_element_count() * dataset()->batch_size_; IteratorContext ctx_copy(params); - return RestoreInput(&ctx_copy, reader, input_impl_); + if (!static_cast(input_empty)) { + TF_RETURN_IF_ERROR(RestoreInput(&ctx_copy, reader, input_impl_)); + ctx->MergeCheckpoint(ctx_copy.checkpoint()); + } else { + input_impl_.reset(); + } + return absl::OkStatus(); } - int64_t input_empty; - TF_RETURN_IF_ERROR( - reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty)); if (!static_cast(input_empty)) { TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); } else { diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index de8f50db1b9903..b77af19b8a0ea5 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -880,13 +880,14 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase { TF_RETURN_IF_ERROR( WriteElementsToCheckpoint(writer, prefix(), cache_->data())); } + TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer)); return SaveInput(ctx, writer, iterator_); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { if (ctx->restored_element_count().has_value()) { - return global_shuffle_iterator_.Restore(ctx); + return global_shuffle_iterator_.Restore(prefix(), ctx, reader); } mutex_lock l(mu_); iterator_.reset(); diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index 8f380a29b87193..77ed1ce1ca6a69 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -14,13 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/concatenate_dataset_op.h" -#include +#include +#include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/split_utils.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/tensor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/mutex.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/thread_annotations.h" namespace tensorflow { namespace data { @@ -36,6 +44,30 @@ namespace data { constexpr char kIndex[] = "i"; constexpr char kInputImplUninitialized[] = "input_impl_uninitialized"; +constexpr char kElementCount[] = "element_count"; + +namespace { + +// Gets the next shuffled index by iterating through the `index_mapper` until +// 1. It is not a `NotFoundError` or +// 2. It is an `OutOfRangeError` or +// 3. It is an error other than `NotFoundError` or `OutOfRangeError` +absl::StatusOr GetNextShuffledIndex(const IndexMapperFn& index_mapper, + size_t& element_count) { + absl::StatusOr shuffled_index = absl::NotFoundError("default"); + + while (absl::IsNotFound(shuffled_index.status())) { + shuffled_index = index_mapper(element_count++); + if (absl::IsOutOfRange(shuffled_index.status())) { + return shuffled_index.status(); + } + if (!absl::IsNotFound(shuffled_index.status()) && !shuffled_index.ok()) { + return shuffled_index.status(); + } + } + return shuffled_index; +} +} // namespace class ConcatenateDatasetOp::Dataset : public DatasetBase { public: @@ -58,6 +90,12 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { &output_tensorshape)); output_shapes_.push_back(output_tensorshape); } + if (input_ != nullptr && !input_->RandomIndexingCompatible().ok()) { + random_indexing_compatible_ = input->RandomIndexingCompatible(); + } else if (to_concatenate_ != nullptr && + !to_concatenate_->RandomIndexingCompatible().ok()) { + random_indexing_compatible_ = to_concatenate_->RandomIndexingCompatible(); + } } ~Dataset() override { input_->Unref(); @@ -126,6 +164,10 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } + absl::Status RandomIndexingCompatible() const override { + return random_indexing_compatible_; + } + protected: Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, @@ -149,11 +191,15 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { bool SymbolicCheckpointCompatible() const override { return true; } Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + input_impls_.resize(2); + TF_ASSIGN_OR_RETURN(input_contexts_, CreateInputIteratorContexts(ctx, dataset())); TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( &input_contexts_[0], this, strings::StrCat(prefix(), "[0]"), - &input_impl_)); + &input_impls_[0])); + ctx->MergeCheckpoint(input_contexts_[0].checkpoint()); return absl::OkStatus(); } @@ -162,25 +208,115 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { std::vector* out_tensors, bool* end_of_sequence) override { mutex_lock l(mu_); - if (!input_impl_) { + if (!input_impls_[0] && !input_impls_[1]) { *end_of_sequence = true; return absl::OkStatus(); } - while (i_ < 2) { - TF_RETURN_IF_ERROR(input_impl_->GetNext(&input_contexts_[i_], - out_tensors, end_of_sequence)); + // Global shuffling + if (ctx->index_mapper()) { + if (input_impls_[1] == nullptr) { + // Creates the second iterator immediately in the case of + // global random shuffling. + TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator( + &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"), + &input_impls_[1])); + ctx->MergeCheckpoint(input_contexts_[1].checkpoint()); + } + + if (input_contexts_[0].index_mapper() == nullptr) { + IndexMapperFn left_index_mapper = + [index_mapper = ctx->index_mapper(), + left_cardinality = dataset()->input_cardinality_, + right_cardinality = dataset()->to_concatenate_cardinality_]( + size_t to_idx) -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(size_t from_idx, index_mapper(to_idx)); + + if (from_idx >= left_cardinality + right_cardinality) { + return absl::OutOfRangeError("Running out of elements."); + } + if (from_idx >= left_cardinality) { + // This has to return a status so that upstream global shuffle + // iterator will not treat it as an end of sequence. + return absl::NotFoundError("Skipping this element."); + } + return from_idx; + }; + + IndexMapperFn right_index_mapper = + [index_mapper = ctx->index_mapper(), + left_cardinality = dataset()->input_cardinality_, + right_cardinality = dataset()->to_concatenate_cardinality_]( + size_t to_idx) -> absl::StatusOr { + TF_ASSIGN_OR_RETURN(size_t from_idx, index_mapper(to_idx)); + + if (from_idx >= left_cardinality + right_cardinality) { + return absl::OutOfRangeError("Running out of elements."); + } + if (from_idx < left_cardinality) { + // This has to return a status so that upstream global shuffle + // iterator will not treat it as an end of sequence. + return absl::NotFoundError("Skipping this element."); + } + return from_idx - left_cardinality; + }; + + input_contexts_[0].SetIndexMapper(left_index_mapper); + input_contexts_[1].SetIndexMapper(right_index_mapper); + } + + // Materializes the shuffled index because we need this information + // to determine which iterator we need to call later. + + absl::StatusOr shuffled_index = + GetNextShuffledIndex(ctx->index_mapper(), element_count_); + + if (absl::IsOutOfRange(shuffled_index.status())) { + *end_of_sequence = true; + return absl::OkStatus(); + } + + TF_RETURN_IF_ERROR(shuffled_index.status()); + + // Routes the shuffled index to the correct input iterator. + bool temp_end_of_sequence = false; + absl::Status status = absl::OkStatus(); + if (shuffled_index.value() < dataset()->input_cardinality_) { + status = input_impls_[0]->GetNext(&input_contexts_[0], out_tensors, + &temp_end_of_sequence); + ctx->MergeCheckpoint(input_contexts_[0].checkpoint()); + } else { + status = input_impls_[1]->GetNext(&input_contexts_[1], out_tensors, + &temp_end_of_sequence); + ctx->MergeCheckpoint(input_contexts_[1].checkpoint()); + } + TF_RETURN_IF_ERROR(status); + + if (temp_end_of_sequence) { + *end_of_sequence = temp_end_of_sequence; + return absl::OkStatus(); + } + return absl::OkStatus(); + } + + for (; i_ < 2; ++i_) { + TF_RETURN_IF_ERROR(input_impls_[i_]->GetNext( + &input_contexts_[i_], out_tensors, end_of_sequence)); ctx->MergeCheckpoint(input_contexts_[i_].checkpoint()); if (!*end_of_sequence) { return absl::OkStatus(); } - if (++i_ < 2) { + if (i_ == 0) { + // Creates the second iterator only when the first iterator + // is exhausted to save memory usage. TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator( - &input_contexts_[i_], this, strings::StrCat(prefix(), "[1]"), - &input_impl_)); + &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"), + &input_impls_[1])); + ctx->MergeCheckpoint(input_contexts_[1].checkpoint()); } } *end_of_sequence = true; - input_impl_.reset(); + input_impls_[0].reset(); + input_impls_[1].reset(); return absl::OkStatus(); } @@ -196,10 +332,18 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kIndex, i_)); TF_RETURN_IF_ERROR( - writer->WriteScalar(prefix(), kInputImplUninitialized, - static_cast(!input_impl_))); - if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); + writer->WriteScalar(prefix(), kElementCount, element_count_)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 0), + static_cast(!input_impls_[0]))); + if (input_impls_[0]) { + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impls_[0])); + } + TF_RETURN_IF_ERROR(writer->WriteScalar( + prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 1), + static_cast(!input_impls_[1]))); + if (input_impls_[1]) { + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impls_[1])); } return absl::OkStatus(); } @@ -207,33 +351,96 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); - TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kIndex, &i_)); - int64_t input_uninitialized; - TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kInputImplUninitialized, - &input_uninitialized)); - if (static_cast(input_uninitialized)) { - input_impl_.reset(); + + int64_t input_uninitialized[2]; + TF_RETURN_IF_ERROR(reader->ReadScalar( + prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 0), + &input_uninitialized[0])); + if (static_cast(input_uninitialized[0])) { + input_impls_[0].reset(); + } + TF_RETURN_IF_ERROR(reader->ReadScalar( + prefix(), absl::StrFormat("%s[%d]", kInputImplUninitialized, 1), + &input_uninitialized[1])); + if (static_cast(input_uninitialized[1])) { + input_impls_[1].reset(); + } + + if (ctx->restored_element_count()) { + if (input_impls_.size() != 2) { + return absl::FailedPreconditionError( + "`Initialize` should be called before restoring from the " + "checkpoint."); + } + { + int64_t tmp_element_count; + TF_RETURN_IF_ERROR( + reader->ReadScalar(prefix(), kElementCount, &tmp_element_count)); + if (tmp_element_count < 0) { + return absl::FailedPreconditionError(absl::StrFormat( + "element_count should be >= 0. Got %d", tmp_element_count)); + } + element_count_ = static_cast(tmp_element_count); + } + + if (!static_cast(input_uninitialized[0])) { + if (!input_impls_[0]) { + return absl::FailedPreconditionError( + "Something went wrong internally. The first iterator should " + "exist because of `Initialize`."); + } + input_contexts_[0].set_restored_element_count( + *ctx->restored_element_count()); + TF_RETURN_IF_ERROR( + RestoreInput(&input_contexts_[0], reader, input_impls_[0])); + ctx->MergeCheckpoint(input_contexts_[0].checkpoint()); + } + + if (!static_cast(input_uninitialized[1])) { + TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator( + &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"), + &input_impls_[1])); + + input_contexts_[1].set_restored_element_count( + *ctx->restored_element_count()); + + TF_RETURN_IF_ERROR( + RestoreInput(&input_contexts_[1], reader, input_impls_[1])); + ctx->MergeCheckpoint(input_contexts_[1].checkpoint()); + } return absl::OkStatus(); } + + TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kIndex, &i_)); + if (!TF_PREDICT_TRUE(i_ >= 0 && i_ <= 2)) return errors::InvalidArgument("i_ must be in range [0, 2]."); - if (i_ == 1) { - TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator( - ctx, this, strings::StrCat(prefix(), "[1]"), &input_impl_)); - } else if (i_ == 2) { - input_impl_.reset(); + + if (!static_cast(input_uninitialized[0])) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impls_[0])); } - if (input_impl_) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + if (!static_cast(input_uninitialized[1])) { + TF_RETURN_IF_ERROR(dataset()->to_concatenate_->MakeIterator( + &input_contexts_[1], this, strings::StrCat(prefix(), "[1]"), + &input_impls_[1])); + ctx->MergeCheckpoint(input_contexts_[1].checkpoint()); + + TF_RETURN_IF_ERROR( + RestoreInput(&input_contexts_[1], reader, input_impls_[1])); + ctx->MergeCheckpoint(input_contexts_[1].checkpoint()); } + return absl::OkStatus(); } private: mutex mu_; int64_t i_ TF_GUARDED_BY(mu_); - std::unique_ptr input_impl_ TF_GUARDED_BY(mu_); - std::vector input_contexts_; + std::vector> input_impls_ TF_GUARDED_BY(mu_); + std::vector input_contexts_ TF_GUARDED_BY(mu_); + // Indicates `ctx->index_mapper()(element_count_)` is the next + // shuffled index. + size_t element_count_ TF_GUARDED_BY(mu_) = 0; }; Status MostSpecificCompatibleShape(const PartialTensorShape& ts1, @@ -257,6 +464,7 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase { const int64_t input_cardinality_; const int64_t to_concatenate_cardinality_; std::vector output_shapes_; + absl::Status random_indexing_compatible_ = absl::OkStatus(); }; ConcatenateDatasetOp::ConcatenateDatasetOp(OpKernelConstruction* ctx) diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc index f15bd3994c9f1f..936ad1f9d4c357 100644 --- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc @@ -173,17 +173,8 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - if (ctx->restored_element_count().has_value()) { - num_elements_ = *(ctx->restored_element_count()); - // If the dataset has reached the end of sequence, the restored element - // count could be cardinality + 1. - if (num_elements_ > dataset()->Cardinality()) { - num_elements_ = dataset()->Cardinality(); - } - } else { - TF_RETURN_IF_ERROR( - reader->ReadScalar(full_name("num_elements"), &num_elements_)); - } + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("num_elements"), &num_elements_)); return RestoreInput(ctx, reader, input_impl_); } diff --git a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc index 8bfc9ade778f1f..91b596c6273b66 100644 --- a/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/dense_to_sparse_batch_dataset_op.cc @@ -251,7 +251,7 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { const auto& t_flat = t.flat(); // TODO(mrry): Replace with a memcpy or something more // efficient. (Maybe an Eigen assign op?) - gtl::InlinedVector strides(row_ndims); + absl::InlinedVector strides(row_ndims); if (!strides.empty()) { strides[row_ndims - 1] = 1; for (int64_t row_dim = strides.size() - 2; row_dim >= 0; diff --git a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc index c937f61b2d9100..20d65e870f0edf 100644 --- a/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/global_shuffle_dataset_op.cc @@ -235,15 +235,8 @@ class GlobalShuffleDatasetOp::Dataset::Iterator TF_ASSIGN_OR_RETURN(element_position, parent_index_mapper(element_position)); } - // This could happen if the source dataset generates more elements than - // needed by the intermediate transformations. For example, when shuffling - // `range(10).batch(3, drop_remainder=True)`, the last element of `range` - // has index 9, which maps to the 4th batched element. However, since - // `batch` drops remainders, the cardinality is 3. In this case, the - // element position exceeds the max index. The caller is responsible to - // handle this case properly. if (element_position > max_index) { - return element_position; + return absl::OutOfRangeError("Out of range"); } if (max_index == 0) { return 0; @@ -265,6 +258,7 @@ class GlobalShuffleDatasetOp::Dataset::Iterator TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed, seed_)); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed2, seed2_)); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kSeed3, seed3_)); + TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); return absl::OkStatus(); } diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc index c93ccce81b4474..2852b443205205 100644 --- a/tensorflow/core/kernels/data/experimental/list_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/list_dataset_op.cc @@ -186,14 +186,16 @@ class ListDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - return split_provider_->Save( - [this](const std::string& key) { return full_name(key); }, writer); + TF_RETURN_IF_ERROR(split_provider_->Save( + [this](const std::string& key) { return full_name(key); }, writer)); + TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer)); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { if (ctx->restored_element_count().has_value()) { - return global_shuffle_iterator_.Restore(ctx); + return global_shuffle_iterator_.Restore(prefix(), ctx, reader); } return split_provider_->Restore( [this](const std::string& key) { return full_name(key); }, reader); diff --git a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc index 86f4b00386bab8..44e25cdb334a71 100644 --- a/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/experimental/list_dataset_op_test.cc @@ -96,7 +96,7 @@ class ListDatasetParams : public DatasetParams { const std::vector>& input_elements) { std::vector output_shapes; for (const auto& tensor : input_elements.front()) { - gtl::InlinedVector partial_dim_sizes; + absl::InlinedVector partial_dim_sizes; partial_dim_sizes.reserve(tensor.dims()); for (int i = 0; i < tensor.dims(); ++i) { partial_dim_sizes.push_back(tensor.dim_size(i)); diff --git a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc index 7021a23f02f211..cd9bb2aa6d4147 100644 --- a/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/parallel_interleave_dataset_op.cc @@ -164,7 +164,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { DatasetGraphDefBuilder* b, Node** output) const override { std::vector> inputs; - std::vector>> list_inputs; + std::vector>> list_inputs; int input_index = 0; Node* input_node; diff --git a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc index 3cc332024328e8..5682d1966eba4a 100644 --- a/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/unbatch_dataset_op.cc @@ -72,7 +72,7 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { if (batch_size_ < 0 && shape.dim_size(0) >= 0) { batch_size_ = shape.dim_size(0); } - gtl::InlinedVector partial_dim_sizes; + absl::InlinedVector partial_dim_sizes; for (int i = 1; i < shape.dims(); ++i) { partial_dim_sizes.push_back(shape.dim_size(i)); } diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index bdd95f0c8737ec..cd03a090febdad 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -141,7 +141,12 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { } absl::Status RandomIndexingCompatible() const override { - return random_indexing_compatible_; + return absl::UnimplementedError( + "Please consider applying maps on each dataset, concatenating them " + "into " + "one dataset and apply global shuffle dataset op onto the " + "dataset to achieve the same result as flat map with global " + "shuffling."); } protected: @@ -358,7 +363,10 @@ class FlatMapDatasetOp::Dataset : public DatasetBase { return absl::OkStatus(); } - // TODO(b/325112575): Refactor and reuse this code from weighted flat map. + // TODO: b/355241367 - This implementation is incorrect because IndexMapper + // should be stateless otherwise it would not be compatible with batch + // dataset op. + // See go/tf-data-random-access-iterator-for-concatenate for more info. IndexMapperFn GetFlatMapIndexMapper(IndexMapperFn parent_index_mapper, size_t input_dataset_index) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { diff --git a/tensorflow/core/kernels/data/map_defun_op_test.cc b/tensorflow/core/kernels/data/map_defun_op_test.cc index d48650cada3f82..aaf292ef365020 100644 --- a/tensorflow/core/kernels/data/map_defun_op_test.cc +++ b/tensorflow/core/kernels/data/map_defun_op_test.cc @@ -104,9 +104,10 @@ class MapDefunOpTest : public DatasetOpsTestBase { } // Creates a new `MapDefun` op kernel context. - Status CreateMapDefunContext(OpKernel* const op_kernel, - gtl::InlinedVector* const inputs, - std::unique_ptr* context) { + Status CreateMapDefunContext( + OpKernel* const op_kernel, + absl::InlinedVector* const inputs, + std::unique_ptr* context) { TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs)); TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context)); return absl::OkStatus(); @@ -243,7 +244,7 @@ TEST_P(ParameterizedMapDefunOpTest, NormalTests) { TestCase test_case = GetParam(); TF_ASSERT_OK(InitializeRuntime(test_case.map_defun_op_params)); auto input_tensors = test_case.map_defun_op_params.GetInputTensors(); - gtl::InlinedVector input_values; + absl::InlinedVector input_values; for (auto& input : input_tensors) { input_values.push_back(TensorValue(&input)); } @@ -272,7 +273,7 @@ TEST_F(MapDefunOpTest, InvalidArguments) { for (auto& test_case : test_cases) { TF_ASSERT_OK(InitializeRuntime(test_case.map_defun_op_params)); auto input_tensors = test_case.map_defun_op_params.GetInputTensors(); - gtl::InlinedVector input_values; + absl::InlinedVector input_values; for (auto& input : input_tensors) { input_values.push_back(TensorValue(&input)); } diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index fcb0c117eedb4d..2774d25747340b 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/stats_utils.h" +#include "tensorflow/core/data/unbounded_thread_pool.h" #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/model.h" @@ -83,6 +85,13 @@ constexpr char kErrorMessage[] = "error_message"; // Period between reporting dataset statistics. constexpr int kStatsReportingPeriodMillis = 1000; +// Factor used to determine the autotune parallelism limit when using an +// unbounded threadpool. The limit is determined by multiplying this factor +// by the default threadpool size, which is typically based on the number of +// CPU cores. Without this limit, we see autotune sometimes choose unreasonably +// large values for the parallelism, e.g. creating 300k threads. +constexpr int kUnboundedThreadpoolAutotuningFactor = 10; + } // namespace class ParallelMapDatasetOp::Dataset : public DatasetBase { @@ -92,17 +101,19 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { const std::vector& output_shapes, DeterminismPolicy deterministic, std::unique_ptr captured_func, - bool preserve_cardinality, int op_version) + bool preserve_cardinality, bool use_unbounded_threadpool, + int op_version) : Dataset(DatasetContext(ctx), input, num_parallel_calls, output_types, output_shapes, deterministic, std::move(captured_func), - preserve_cardinality, op_version) {} + preserve_cardinality, use_unbounded_threadpool, op_version) {} Dataset(DatasetContext dataset_context, const DatasetBase* input, int64_t num_parallel_calls, const DataTypeVector& output_types, const std::vector& output_shapes, DeterminismPolicy deterministic, std::unique_ptr captured_func, - bool preserve_cardinality, int op_version) + bool preserve_cardinality, bool use_unbounded_threadpool, + int op_version) : DatasetBase(std::move(dataset_context)), input_(input), num_parallel_calls_(num_parallel_calls), @@ -110,6 +121,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { output_shapes_(output_shapes), deterministic_(deterministic), preserve_cardinality_(preserve_cardinality), + use_unbounded_threadpool_(use_unbounded_threadpool), captured_func_(std::move(captured_func)), op_version_(op_version) { input_->Ref(); @@ -235,6 +247,12 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { b->BuildAttrValue(preserve_cardinality_, &preserve_cardinality_attr); attrs.emplace_back(kPreserveCardinality, preserve_cardinality_attr); + // Attr: use_unbounded_threadpool + AttrValue use_unbounded_threadpool_attr; + b->BuildAttrValue(use_unbounded_threadpool_, + &use_unbounded_threadpool_attr); + attrs.emplace_back(kUseUnboundedThreadpool, use_unbounded_threadpool_attr); + TF_RETURN_IF_ERROR(b->AddDataset( this, {std::make_pair(0, input_graph_node), @@ -256,6 +274,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { deterministic_(params.dataset->deterministic_.IsDeterministic() || params.dataset->deterministic_.IsDefault()), preserve_cardinality_(params.dataset->preserve_cardinality_), + use_unbounded_threadpool_(params.dataset->use_unbounded_threadpool_), autotune_(params.dataset->num_parallel_calls_ == model::kAutotune) {} ~Iterator() override { @@ -271,7 +290,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { Status Initialize(IteratorContext* ctx) override { mutex_lock l(*mu_); interleave_depth_ = ctx->interleave_depth(); - + if (use_unbounded_threadpool_) { + unbounded_thread_pool_ = std::make_unique( + ctx->env(), "tf_data_map_unbounded_thread_pool"); + } if (num_parallel_calls_->value == model::kAutotune) { num_parallel_calls_->value = GetAutotuneDefaultParallelism(ctx); } @@ -323,11 +345,15 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { std::shared_ptr CreateNode( IteratorContext* ctx, model::Node::Args args) const override { std::shared_ptr parameter; + double max_parallelism_value = ctx->runner_threadpool_size(); + if (use_unbounded_threadpool_) { + max_parallelism_value *= kUnboundedThreadpoolAutotuningFactor; + } if (num_parallel_calls_ && dataset()->num_parallel_calls_ == model::kAutotune) { parameter = model::MakeParameter( "parallelism", num_parallel_calls_, /*min=*/1, - /*max=*/ctx->runner_threadpool_size(), + /*max=*/max_parallelism_value, // This is to ensure before this op has seen its first element, // `MaximumBufferedBytes()` can use the correct `parameter->value` // to estimate the maximum buffer bytes. @@ -335,7 +361,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { } else { parameter = model::MakeParameter("parallelism", num_parallel_calls_, /*min=*/1, - /*max=*/ctx->runner_threadpool_size()); + /*max=*/max_parallelism_value); } std::optional estimated_element_size = dataset()->GetEstimatedElementSize(); @@ -394,10 +420,6 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - if (ctx->restored_element_count().has_value()) { - return RestoreInput(ctx, reader, input_impl_); - } - mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); DCHECK(invocation_results_.empty()); @@ -456,6 +478,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { std::make_pair("autotune", autotune_ ? "true" : "false")); result.push_back( std::make_pair("deterministic", deterministic_ ? "true" : "false")); + result.push_back( + std::make_pair("use_unbounded_threadpool", + use_unbounded_threadpool_ ? "true" : "false")); result.push_back(std::make_pair( "parallelism", parallelism == -1 @@ -543,7 +568,15 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { // Apply the map function on `input_element`, storing the result in // `result->return_values`, and invoking `done` when finished. - if (dataset()->captured_func_->use_inter_op_parallelism()) { + if (use_unbounded_threadpool_) { + auto runner_fn = [this](std::function fn) { + this->unbounded_thread_pool_->Schedule(fn); + }; + instantiated_captured_func_->RunAsync( + runner_fn, ctx->cancellation_manager(), ctx->collective_executor(), + std::move(input_element), &result->return_values, done, + model_node()); + } else if (dataset()->captured_func_->use_inter_op_parallelism()) { instantiated_captured_func_->RunAsync( ctx.get(), std::move(input_element), &result->return_values, std::move(done), model_node()); @@ -751,6 +784,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { const std::shared_ptr num_parallel_calls_; const bool deterministic_; const bool preserve_cardinality_; + const bool use_unbounded_threadpool_; const bool autotune_; // Counts the number of outstanding calls. int64_t num_calls_ TF_GUARDED_BY(*mu_) = 0; @@ -767,6 +801,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { bool cancelled_ TF_GUARDED_BY(*mu_) = false; std::unique_ptr runner_thread_ TF_GUARDED_BY(*mu_); std::unique_ptr stats_thread_ TF_GUARDED_BY(*mu_); + std::unique_ptr unbounded_thread_pool_; // Method for deregistering the cancellation callback. std::function deregister_fn_; @@ -784,6 +819,7 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase { const std::vector output_shapes_; const DeterminismPolicy deterministic_; const bool preserve_cardinality_; + const bool use_unbounded_threadpool_; const std::unique_ptr captured_func_; const int op_version_; // This is used for random access provided by Get(). @@ -812,12 +848,15 @@ ParallelMapDatasetOp::ParallelMapDatasetOp(OpKernelConstruction* ctx) } else { deterministic_ = DeterminismPolicy(DeterminismPolicy::Type::kDefault); } + use_unbounded_threadpool_ = false; } if (op_version_ == 2) { std::string deterministic; OP_REQUIRES_OK(ctx, ctx->GetAttr(kDeterministic, &deterministic)); OP_REQUIRES_OK( ctx, DeterminismPolicy::FromString(deterministic, &deterministic_)); + OP_REQUIRES_OK( + ctx, ctx->GetAttr(kUseUnboundedThreadpool, &use_unbounded_threadpool_)); } OP_REQUIRES_OK(ctx, ctx->GetAttr(kPreserveCardinality, &preserve_cardinality_)); @@ -849,10 +888,10 @@ void ParallelMapDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, metrics::RecordTFDataAutotune(kDatasetType); } - *output = - new Dataset(ctx, input, num_parallel_calls, output_types_, output_shapes_, - deterministic_, std::move(captured_func), - preserve_cardinality_, op_version_); + *output = new Dataset(ctx, input, num_parallel_calls, output_types_, + output_shapes_, deterministic_, + std::move(captured_func), preserve_cardinality_, + use_unbounded_threadpool_, op_version_); } std::unique_ptr MakeDataServiceUncompressDataset( @@ -867,7 +906,8 @@ std::unique_ptr MakeDataServiceUncompressDataset( /*num_parallel_calls=*/model::kAutotune, output_types, output_shapes, DeterminismPolicy(DeterminismPolicy::Type::kDefault), std::move(captured_function), - /*preserve_cardinality=*/true, /*op_version=*/2); + /*preserve_cardinality=*/true, + /*use_unbounded_threadpool=*/false, /*op_version=*/2); } namespace { diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.h b/tensorflow/core/kernels/data/parallel_map_dataset_op.h index 4e1e564346f874..efdf6339a20007 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.h +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.h @@ -38,6 +38,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { static constexpr const char* const kSloppy = "sloppy"; static constexpr const char* const kPreserveCardinality = "preserve_cardinality"; + static constexpr const char* const kUseUnboundedThreadpool = + "use_unbounded_threadpool"; explicit ParallelMapDatasetOp(OpKernelConstruction* ctx); @@ -54,6 +56,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { bool sloppy_; bool preserve_cardinality_; DeterminismPolicy deterministic_; + bool use_unbounded_threadpool_; friend std::unique_ptr MakeDataServiceUncompressDataset( DatasetBase* input, std::unique_ptr captured_function, diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc index 357e279d7396a8..cedc0e8adad743 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op_test.cc @@ -12,10 +12,10 @@ limitations under the License. #include "tensorflow/core/kernels/data/parallel_map_dataset_op.h" #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/dataset_test_base.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 0a7779c519a835..c6238a0f987a1d 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -297,11 +297,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - if (ctx->restored_element_count().has_value()) { - tsl::mutex_lock l(input_mu_); - return RestoreInput(ctx, reader, input_impl_); - } - mutex_lock input_l(input_mu_); mutex_lock l(*mu_); DCHECK(!prefetch_thread_); diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index bf71d672c86bef..5834494e5a043f 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -332,13 +332,14 @@ class RangeDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR( writer->WriteScalar(prefix(), kNext, counter_->Peek())); } + TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer)); return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { if (ctx->restored_element_count().has_value()) { - return global_shuffle_iterator_.Restore(ctx); + return global_shuffle_iterator_.Restore(prefix(), ctx, reader); } if (reader->Contains(prefix(), kHasSplitProvider)) { TF_RETURN_IF_ERROR(split_provider_->Restore( diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 820fc9ba6b0007..555bdbaad31322 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -319,7 +319,7 @@ class RepeatDatasetOp::Dataset : public DatasetBase { if (element_position >= input_cardinality) { // The input element position is out-of-range. The caller is // responsible for handle this case (e.g.: returning end_of_sequence). - return element_position; + return absl::OutOfRangeError("Finite repeat is out of range"); } // First, maps the input indices from @@ -356,28 +356,37 @@ class RepeatDatasetOp::Dataset : public DatasetBase { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); + int64_t input_empty; + TF_RETURN_IF_ERROR( + reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty)); + TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIteration, &i_)); + if (ctx->restored_element_count().has_value()) { CardinalityOptions options; options.set_compute_level( CardinalityOptions::CARDINALITY_COMPUTE_MODERATE); const int64_t input_cardinality = dataset()->input_->Cardinality(std::move(options)); - i_ = *ctx->restored_element_count() / input_cardinality; // For upstream iterators, the restored element count should be the // element count within the current repetition. IteratorContext::Params params(ctx); params.restored_element_count = - *ctx->restored_element_count() % input_cardinality; + *ctx->restored_element_count() % (input_cardinality); params.index_mapper = GetIndexMapper(ctx->index_mapper()); IteratorContext ctx_with_restored_element_count(params); - return RestoreInput(&ctx_with_restored_element_count, reader, - input_impl_); + if (!input_empty) { + // Needs to re-`MakeIterator` because `i_` might have changed. + TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( + ctx, this, nested_prefix(prefix(), i_), &input_impl_)); + TF_RETURN_IF_ERROR(RestoreInput(&ctx_with_restored_element_count, + reader, input_impl_)); + ctx->MergeCheckpoint(ctx_with_restored_element_count.checkpoint()); + } else { + input_impl_.reset(); + } + return absl::OkStatus(); } - TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIteration, &i_)); - int64_t input_empty; - TF_RETURN_IF_ERROR( - reader->ReadScalar(prefix(), kInputImplEmpty, &input_empty)); if (static_cast(!input_empty)) { TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator( ctx, this, nested_prefix(prefix(), i_), &input_impl_)); diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index 695263eee45c38..950be9b631e480 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -280,7 +280,7 @@ class SparseTensorSliceDatasetOp : public DatasetOpKernel { "is not currently supported.")); previous_batch_index = next_batch_index; } - gtl::InlinedVector std_order(dense_shape->NumElements(), 0); + absl::InlinedVector std_order(dense_shape->NumElements(), 0); TensorShape shape; OP_REQUIRES_OK(ctx, TensorShape::BuildTensorShape( dense_shape->vec(), &shape)); diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index a08890ec626175..d91027115ae383 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -187,12 +187,6 @@ class TakeDataset::FiniteIterator : public DatasetIterator { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - if (ctx->restored_element_count().has_value()) { - mutex_lock l(mu_); - i_ = *ctx->restored_element_count(); - return RestoreInput(ctx, reader, input_impl_); - } - mutex_lock l(mu_); TF_RETURN_IF_ERROR(reader->ReadScalar(prefix(), kCurIndex, &i_)); int64_t input_empty; diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index 02736e5a0ddd1a..fe2d564bee2c6c 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -184,13 +184,14 @@ class TensorDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kProduced, static_cast(produced_))); + TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer)); return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { if (ctx->restored_element_count().has_value()) { - return global_shuffle_iterator_.Restore(ctx); + return global_shuffle_iterator_.Restore(prefix(), ctx, reader); } mutex_lock l(mu_); diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index dad1e8ce9e9950..3e2374b453acf8 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -55,7 +55,7 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { replicate_on_split_(replicate_on_split) { for (const Tensor& t : tensors_) { dtypes_.push_back(t.dtype()); - gtl::InlinedVector element_dim_sizes; + absl::InlinedVector element_dim_sizes; // Handle scalar here. Check that everyone matches here? Or fail // at runtime? for (int i = 1; i < t.dims(); ++i) { @@ -206,14 +206,16 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase { Status SaveInternal(SerializationContext* ctx, IteratorStateWriter* writer) override { - return split_provider_->Save( - [this](const std::string& key) { return full_name(key); }, writer); + TF_RETURN_IF_ERROR(split_provider_->Save( + [this](const std::string& key) { return full_name(key); }, writer)); + TF_RETURN_IF_ERROR(global_shuffle_iterator_.Save(prefix(), ctx, writer)); + return absl::OkStatus(); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { if (ctx->restored_element_count().has_value()) { - return global_shuffle_iterator_.Restore(ctx); + return global_shuffle_iterator_.Restore(prefix(), ctx, reader); } return split_provider_->Restore( [this](const std::string& key) { return full_name(key); }, reader); diff --git a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc index 6a0148eb25b244..99a29ebebd17cb 100644 --- a/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/tf_record_dataset_op_test.cc @@ -14,10 +14,29 @@ limitations under the License. #include #include +#include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/dataset_test_base.h" +#include "tensorflow/core/data/name_utils.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/io/record_reader.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index 4ca4c2a8ac1f93..c9d2a5a9f42a75 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -17,10 +17,22 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/data/name_utils.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/thread_annotations.h" namespace tensorflow { namespace data { @@ -165,6 +177,7 @@ class WindowOp : public DatasetOpKernel { std::vector> elements; for (size_t i = 0; i < num_elements; ++i) { std::vector element; + element.reserve(element_size); for (size_t j = 0; j < element_size; ++j) { element.push_back(std::move(inputs[i * element_size + j])); } diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h index 6d698083c5a568..877d78369a397d 100644 --- a/tensorflow/core/kernels/data/window_dataset.h +++ b/tensorflow/core/kernels/data/window_dataset.h @@ -20,7 +20,9 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 2aa058078e1c39..1a2c96bb1a6c5e 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -14,10 +14,28 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/window_dataset_op.h" +#include "absl/log/check.h" +#include "absl/status/status.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/data/window_dataset.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/stringprintf.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow/core/platform/types.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/thread_annotations.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/data/window_dataset_op.h b/tensorflow/core/kernels/data/window_dataset_op.h index 186bce67638191..241e0f510883f9 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.h +++ b/tensorflow/core/kernels/data/window_dataset_op.h @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" diff --git a/tensorflow/core/kernels/data/window_dataset_op_test.cc b/tensorflow/core/kernels/data/window_dataset_op_test.cc index 7d8d5b6bc0192b..6dfc803f96bfb3 100644 --- a/tensorflow/core/kernels/data/window_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/window_dataset_op_test.cc @@ -14,10 +14,20 @@ limitations under the License. #include #include +#include +#include "absl/status/status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/dataset_test_base.h" -#include "tensorflow/core/data/dataset_utils.h" +#include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/serialization_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/variant_tensor_data.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index 1bda8e47e94191..f2891dd7eff70d 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -18,16 +18,26 @@ limitations under the License. #include #include +#include "absl/log/check.h" #include "absl/status/status.h" -#include "tensorflow/core/data/dataset_utils.h" #include "tensorflow/core/data/name_utils.h" #include "tensorflow/core/data/split_utils.h" #include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/dataset_options.pb.h" +#include "tensorflow/core/framework/model.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/types.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/thread_annotations.h" namespace tensorflow { @@ -262,6 +272,9 @@ class ZipDatasetOp::Dataset : public DatasetBase { mutex_lock l(mu_); // Note: When restoring, `SaveInternal` would not be called // if there is a global_shuffle_dataset_op.cc above this op. + int64_t inputs_empty; + TF_RETURN_IF_ERROR( + reader->ReadScalar(prefix(), kInputImplsEmpty, &inputs_empty)); if (ctx->restored_element_count()) { if (input_impls_.size() != dataset()->inputs_.size()) { return absl::FailedPreconditionError( @@ -273,14 +286,19 @@ class ZipDatasetOp::Dataset : public DatasetBase { "ctx->index_mapper() should be provided along with " "ctx->restored_element_count() when restoring."); } - for (const auto& input_impl : input_impls_) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl)); + if (static_cast(inputs_empty)) { + input_impls_.clear(); + } else { + for (int i = 0; i < input_impls_.size(); ++i) { + input_contexts_[i].set_restored_element_count( + ctx->restored_element_count().value()); + TF_RETURN_IF_ERROR( + RestoreInput(&input_contexts_[i], reader, input_impls_[i])); + ctx->MergeCheckpoint(input_contexts_[i].checkpoint()); + } } return absl::OkStatus(); } - int64_t inputs_empty; - TF_RETURN_IF_ERROR( - reader->ReadScalar(prefix(), kInputImplsEmpty, &inputs_empty)); if (static_cast(inputs_empty)) { input_impls_.clear(); } else { diff --git a/tensorflow/core/kernels/data/zip_dataset_op.h b/tensorflow/core/kernels/data/zip_dataset_op.h index 5aa92035cc7ae0..1e6b294be47736 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.h +++ b/tensorflow/core/kernels/data/zip_dataset_op.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_DATA_ZIP_DATASET_OP_H_ #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/data/zip_dataset_op_test.cc b/tensorflow/core/kernels/data/zip_dataset_op_test.cc index fcbd97f725491b..9d8eed6d496aef 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op_test.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op_test.cc @@ -14,7 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/zip_dataset_op.h" +#include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/data/dataset_test_base.h" +#include "tensorflow/core/data/name_utils.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace data { diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index d7c0c762fa7648..15ff88c5c229ff 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -190,7 +190,7 @@ class BaseDebugOp : public OpKernel { LOG(ERROR) << "Debug node of watch key " << debug_watch_key_->debug_node_name << " failed to publish debug tensor data to all URLs " - << str_util::Join(debug_urls_, ", ") + << absl::StrJoin(debug_urls_, ", ") << ", due to: " << status.message(); } return status; diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 7c865b62d0452b..b2930d4b45a670 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -255,7 +255,7 @@ class SymbolicGradientOp : public AsyncOpKernel { args.push_back(ctx->input(i)); } std::vector* rets = new std::vector; - profiler::TraceMe trace_me("SymbolicGradientOp"); + tsl::profiler::TraceMe trace_me("SymbolicGradientOp"); lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) { if (!status.ok()) { ctx->SetStatus(status); @@ -319,12 +319,12 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { handle = cached_entry->second; } else { VLOG(1) << "Instantiating " << func_name << " on " << target_device; - profiler::TraceMe activity( + tsl::profiler::TraceMe activity( [&] { return strings::StrCat("RemoteCall: Instantiate: ", func_name, " on ", target_device); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); FunctionLibraryRuntime::InstantiateOptions instantiate_opts; const auto* config = (ctx->function_library()) ? ctx->function_library()->config_proto() @@ -398,24 +398,24 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { auto* rets = new std::vector; VLOG(1) << "Running " << func_name << " on " << target_device << " with handle: " << handle; - profiler::TraceMe trace_me( + tsl::profiler::TraceMe trace_me( [&] { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( "RemoteCallOp", {{"func_name", func_name}, {"device", target_device}}); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); lib->Run( opts, handle, args, rets, [rets, done = std::move(done), func_name, ctx, cancel_mgr, target_device = std::move(function_target.first)](const Status& status) { - profiler::TraceMe activity( + tsl::profiler::TraceMe activity( [&] { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( "RemoteCallOpDone", {{"func_name", func_name}, {"device", target_device}}); }, - profiler::TraceMeLevel::kInfo); + tsl::profiler::TraceMeLevel::kInfo); if (!status.ok()) { ctx->SetStatus(status); } else { @@ -431,13 +431,13 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { string RemoteCallOp::TraceString(const OpKernelContext& ctx, bool verbose) const { - string trace_string = profiler::TraceMeOp( + string trace_string = tsl::profiler::TraceMeOp( strings::StrCat(name_view(), "__", func_.name()), type_string_view()); if (verbose) { string shape = ShapeTraceString(ctx); if (!shape.empty()) { - trace_string = - profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}}); + trace_string = tsl::profiler::TraceMeEncode(std::move(trace_string), + {{"shape", shape}}); } } return trace_string; diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc index 79c393facbb6f1..7cf465bbaa0ffe 100644 --- a/tensorflow/core/kernels/functional_ops.cc +++ b/tensorflow/core/kernels/functional_ops.cc @@ -199,7 +199,7 @@ class IfOp : public AsyncOpKernel { void Start() { FHandle handle = cond_ ? then_handle_ : else_handle_; rets_.clear(); - profiler::TraceMe trace_me("IfOp"); + tsl::profiler::TraceMe trace_me("IfOp"); lib_->Run( // Evaluate one of the branch. opts_, handle, args_, &rets_, @@ -378,7 +378,7 @@ class CaseOp : public AsyncOpKernel { branch = branch_handles_.size() - 1; } rets_.clear(); - profiler::TraceMe trace_me("CaseOp"); + tsl::profiler::TraceMe trace_me("CaseOp"); lib_->Run( // Evaluate one of the branch. opts_, branch_handles_[branch], args_, &rets_, @@ -633,7 +633,7 @@ class WhileOp : public AsyncOpKernel { std::unique_ptr body_frame_; void EvalCond() { - profiler::TraceMe trace_me("WhileOp-EvalCond"); + tsl::profiler::TraceMe trace_me("WhileOp-EvalCond"); lib_->Run( // Evaluate the condition. opts_, cond_handle_, args_, &rets_, @@ -669,7 +669,7 @@ class WhileOp : public AsyncOpKernel { } rets_.clear(); rets_.resize(args_.size()); - profiler::TraceMe trace_me("WhileOp-StartBody"); + tsl::profiler::TraceMe trace_me("WhileOp-StartBody"); lib_->Run( // Evaluate the body. opts_, body_handle_, body_frame_.get(), @@ -724,7 +724,7 @@ class WhileOp : public AsyncOpKernel { do { // Evaluate the cond function on the current loop variables. { - profiler::TraceMe trace_me("WhileOp-EvalCond"); + tsl::profiler::TraceMe trace_me("WhileOp-EvalCond"); TF_RETURN_IF_ERROR(lib->RunSync(opts, cond_handle, args, &cond_rets)); } if (cond_rets.size() != 1) { @@ -745,7 +745,7 @@ class WhileOp : public AsyncOpKernel { // Evaluate the body function on the current loop variables, to get an // updated vector of loop variables. { - profiler::TraceMe trace_me("WhileOp-StartBody"); + tsl::profiler::TraceMe trace_me("WhileOp-StartBody"); body_rets.resize(num_loop_vars); BodyFuncCallFrame call_frame(&args, &body_rets, loop_var_types); TF_RETURN_IF_ERROR(lib->RunSync(opts, body_handle, &call_frame)); @@ -982,7 +982,7 @@ class ForOp : public AsyncOpKernel { args_[1 + i] = std::move(rets_[i]); } rets_.clear(); - profiler::TraceMe trace_me("ForOp"); + tsl::profiler::TraceMe trace_me("ForOp"); lib_->Run(opts_, body_handle_, args_, &rets_, [this](const Status& s) { if (s.ok()) { *iter_ += delta_; diff --git a/tensorflow/core/kernels/gather_nd_op_test.cc b/tensorflow/core/kernels/gather_nd_op_test.cc index 3212068f389a95..2758fbb3a57fe1 100644 --- a/tensorflow/core/kernels/gather_nd_op_test.cc +++ b/tensorflow/core/kernels/gather_nd_op_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "absl/strings/match.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/fake_input.h" @@ -37,7 +38,6 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/image/resize_bicubic_op_test.cc b/tensorflow/core/kernels/image/resize_bicubic_op_test.cc index 77f17257281923..209dbbdd60761d 100644 --- a/tensorflow/core/kernels/image/resize_bicubic_op_test.cc +++ b/tensorflow/core/kernels/image/resize_bicubic_op_test.cc @@ -92,7 +92,13 @@ class ResizeBicubicOpTest : public OpsTestBase { std::array* weights, std::array* indices) { const int64_t in_loc = scale * out_loc; - const float delta = scale * out_loc - in_loc; + // Ensure that the following calculation is kept in a float to match the + // rounding done in the optimised case. Merging it with the following line + // keeps an intermediate value at higher precision and that leads to a + // divergence in the result. So keep the following two lines separate to + // ensure that the calculation is rounded as expected. + const float in_loc_float = scale * out_loc; + const float delta = in_loc_float - in_loc; const int64_t offset = lrintf(delta * kTableSize); const float* coeffs_tab = GetCoeffsTable(); *weights = {{coeffs_tab[offset * 2 + 1], coeffs_tab[offset * 2], diff --git a/tensorflow/core/kernels/linalg/einsum_op_impl.h b/tensorflow/core/kernels/linalg/einsum_op_impl.h index de78ca3444c291..2dfd62fc943dbd 100644 --- a/tensorflow/core/kernels/linalg/einsum_op_impl.h +++ b/tensorflow/core/kernels/linalg/einsum_op_impl.h @@ -612,11 +612,12 @@ class EinsumOp : public OpKernel { if (verbose) { string shape = ShapeTraceString(ctx); if (!shape.empty()) { - return profiler::TraceMeEncode( + return tsl::profiler::TraceMeEncode( std::move(op), {{"equation", equation}, {"shape", shape}}); } } - return profiler::TraceMeEncode(std::move(op), {{"equation", equation}}); + return tsl::profiler::TraceMeEncode(std::move(op), + {{"equation", equation}}); } private: diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 761920189c3933..d07b4b92dd2db5 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -257,7 +257,7 @@ void PartitionedCallOp::RunFunction(FunctionLibraryRuntime::Handle handle, std::vector* rets = new std::vector; const string& func_name = func_->name(); - profiler::TraceMe trace_me("PartitionedCallOp"); + tsl::profiler::TraceMe trace_me("PartitionedCallOp"); lib->Run(run_opts, handle, inputs, rets, [rets, done = std::move(done), ctx, func_name, step_container](const Status& status) { diff --git a/tensorflow/core/kernels/scatter_nd_op_test.cc b/tensorflow/core/kernels/scatter_nd_op_test.cc index 2e88088fa84ccd..02fa44f193b28f 100644 --- a/tensorflow/core/kernels/scatter_nd_op_test.cc +++ b/tensorflow/core/kernels/scatter_nd_op_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/match.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -38,7 +39,6 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/kernels/sendrecv_ops.cc b/tensorflow/core/kernels/sendrecv_ops.cc index 7f76cb475bd3f7..d15cc7feda3bd0 100644 --- a/tensorflow/core/kernels/sendrecv_ops.cc +++ b/tensorflow/core/kernels/sendrecv_ops.cc @@ -120,8 +120,8 @@ string SendOp::TraceString(const OpKernelContext& ctx, bool verbose) const { auto dst_it = attr.find("_dst"); const string& src = src_it != attr.end() ? src_it->second.s() : ""; const string& dst = dst_it != attr.end() ? dst_it->second.s() : ""; - string op = profiler::TraceMeOp(name_view(), type_string_view()); - return profiler::TraceMeEncode( + string op = tsl::profiler::TraceMeOp(name_view(), type_string_view()); + return tsl::profiler::TraceMeEncode( std::move(op), {{"from", src}, {"to", dst}, {"key", parsed_key_.FullKey()}}); } @@ -166,8 +166,8 @@ string RecvOp::TraceString(const OpKernelContext& ctx, bool verbose) const { auto dst_it = attr.find("_dst"); const string& src = src_it != attr.end() ? src_it->second.s() : ""; const string& dst = dst_it != attr.end() ? dst_it->second.s() : ""; - string op = profiler::TraceMeOp(name_view(), type_string_view()); - return profiler::TraceMeEncode( + string op = tsl::profiler::TraceMeOp(name_view(), type_string_view()); + return tsl::profiler::TraceMeEncode( std::move(op), {{"from", src}, {"to", dst}, {"key", parsed_key_.FullKey()}}); } diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index de5205c23201da..5256db35a1f228 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -127,6 +127,8 @@ class RangeOp : public OpKernel { #define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, CPUDevice, T) #define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, GPUDevice, T) +TF_CALL_half(REGISTER_CPU_KERNEL); +TF_CALL_bfloat16(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); TF_CALL_int32(REGISTER_CPU_KERNEL); @@ -134,6 +136,8 @@ TF_CALL_int64(REGISTER_CPU_KERNEL); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM +TF_CALL_half(REGISTER_GPU_KERNEL); +TF_CALL_bfloat16(REGISTER_GPU_KERNEL); TF_CALL_float(REGISTER_GPU_KERNEL); TF_CALL_double(REGISTER_GPU_KERNEL); TF_CALL_int64(REGISTER_GPU_KERNEL); diff --git a/tensorflow/core/kernels/sequence_ops_gpu.cu.cc b/tensorflow/core/kernels/sequence_ops_gpu.cu.cc index 205978fc1a4ecc..f33b8cc982d2d6 100644 --- a/tensorflow/core/kernels/sequence_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/sequence_ops_gpu.cu.cc @@ -58,6 +58,8 @@ struct RangeFunctor { } // namespace functor #define DEFINE_FUNCTOR(T) template struct functor::RangeFunctor; +TF_CALL_half(DEFINE_FUNCTOR); +TF_CALL_bfloat16(DEFINE_FUNCTOR); TF_CALL_float(DEFINE_FUNCTOR); TF_CALL_double(DEFINE_FUNCTOR); TF_CALL_int32(DEFINE_FUNCTOR); diff --git a/tensorflow/core/kernels/sequence_ops_test.cc b/tensorflow/core/kernels/sequence_ops_test.cc index 1985d631d23739..d0a079f1827428 100644 --- a/tensorflow/core/kernels/sequence_ops_test.cc +++ b/tensorflow/core/kernels/sequence_ops_test.cc @@ -68,6 +68,21 @@ TEST_F(RangeOpTest, Simple_D32) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +TEST_F(RangeOpTest, Simple_Half) { + MakeOp(DT_HALF); + + // Feed and run + AddInputFromList(TensorShape({}), {0.5}); + AddInputFromList(TensorShape({}), {2}); + AddInputFromList(TensorShape({}), {0.3}); + TF_ASSERT_OK(RunOpKernel()); + + // Check the output + Tensor expected(allocator(), DT_HALF, TensorShape({5})); + test::FillValues(&expected, {0.5, 0.8, 1.1, 1.4, 1.7}); + test::ExpectTensorEqual(expected, *GetOutput(0)); +} + TEST_F(RangeOpTest, Simple_Float) { MakeOp(DT_FLOAT); diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc index f295922b30a494..f29e73d9576d8d 100644 --- a/tensorflow/core/kernels/shape_ops.cc +++ b/tensorflow/core/kernels/shape_ops.cc @@ -52,6 +52,7 @@ REGISTER_KERNEL_BUILDER(Name("Shape") TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL); TF_CALL_bool(REGISTER_GPU_KERNEL); TF_CALL_variant(REGISTER_GPU_KERNEL); +TF_CALL_tstring(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL // A special GPU kernel for int32. diff --git a/tensorflow/core/kernels/special_math/special_math_op_bessel.cc b/tensorflow/core/kernels/special_math/special_math_op_bessel.cc index 8efa183655e3c3..e29042cea1cd04 100644 --- a/tensorflow/core/kernels/special_math/special_math_op_bessel.cc +++ b/tensorflow/core/kernels/special_math/special_math_op_bessel.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "Eigen/Core" // from @eigen_archive #include "tensorflow/core/kernels/cwise_ops_common.h" #include "tensorflow/core/kernels/special_math/special_math_op_misc_impl.h" diff --git a/tensorflow/core/kernels/spectrogram_op_test.cc b/tensorflow/core/kernels/spectrogram_op_test.cc index 791fddae49ad87..650024bb888563 100644 --- a/tensorflow/core/kernels/spectrogram_op_test.cc +++ b/tensorflow/core/kernels/spectrogram_op_test.cc @@ -25,13 +25,13 @@ limitations under the License. #include "tensorflow/cc/ops/audio_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/ops/math_ops.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/kernels/stochastic_cast_op_test.cc b/tensorflow/core/kernels/stochastic_cast_op_test.cc index b0a58356e85d0e..10d9eae13249ff 100644 --- a/tensorflow/core/kernels/stochastic_cast_op_test.cc +++ b/tensorflow/core/kernels/stochastic_cast_op_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "Eigen/Core" // from @eigen_archive +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -30,7 +31,6 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/lib/random/philox_random.h" namespace Eigen { diff --git a/tensorflow/core/kernels/uniform_quant_ops/BUILD b/tensorflow/core/kernels/uniform_quant_ops/BUILD index 5c158feb1c6c50..507a85779d3ab0 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/BUILD +++ b/tensorflow/core/kernels/uniform_quant_ops/BUILD @@ -191,7 +191,7 @@ tf_cc_test( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:test", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -205,6 +205,6 @@ tf_cc_test( "//tensorflow/core/platform:test", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc b/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc index 4d2e8699b71f9d..a331b8b1fb6db1 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/math_utils_test.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc index f3d2f2cb811913..c4f0ea50c4c663 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/tensor_utils_test.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "absl/status/status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc index 7ffc629bb43aaf..73121debd9c220 100644 --- a/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc +++ b/tensorflow/core/kernels/uniform_quant_ops/uniform_quantized_convolution_ops_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/numeric_types.h" @@ -27,7 +28,6 @@ limitations under the License. #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/core/lib/core/BUILD b/tensorflow/core/lib/core/BUILD index ff45b30174d280..3b31ec16f99dbc 100644 --- a/tensorflow/core/lib/core/BUILD +++ b/tensorflow/core/lib/core/BUILD @@ -226,6 +226,7 @@ filegroup( srcs = [ "status_test_util.h", "@local_tsl//tsl/lib/core:legacy_lib_core_status_test_util_header", + "@local_xla//xla/tsl/lib/core:legacy_lib_core_status_test_util_header", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h index ef333ef2d1d7c9..3c604ee854e9d0 100644 --- a/tensorflow/core/lib/core/status_test_util.h +++ b/tensorflow/core/lib/core/status_test_util.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_ #define TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_ +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #endif // TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_ diff --git a/tensorflow/core/lib/db/sqlite_test.cc b/tensorflow/core/lib/db/sqlite_test.cc index 39ea38e4fd7e01..0a0042fda80dcf 100644 --- a/tensorflow/core/lib/db/sqlite_test.cc +++ b/tensorflow/core/lib/db/sqlite_test.cc @@ -17,11 +17,11 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/lib/gtl/edit_distance.h b/tensorflow/core/lib/gtl/edit_distance.h index 9481ba75f96bb3..818ec69fd96fd7 100644 --- a/tensorflow/core/lib/gtl/edit_distance.h +++ b/tensorflow/core/lib/gtl/edit_distance.h @@ -59,7 +59,7 @@ inline int64_t LevenshteinDistance(const gtl::ArraySlice& s, if (s == t) return 0; // Create work vector - gtl::InlinedVector scratch_holder(t_size); + absl::InlinedVector scratch_holder(t_size); int64_t* scratch = scratch_holder.data(); diff --git a/tensorflow/core/lib/histogram/BUILD b/tensorflow/core/lib/histogram/BUILD index 04a698ff39dd2a..8701b2f5c49f6e 100644 --- a/tensorflow/core/lib/histogram/BUILD +++ b/tensorflow/core/lib/histogram/BUILD @@ -25,7 +25,7 @@ cc_library( "//tensorflow/core/platform:mutex", "//tensorflow/core/platform:thread_annotations", "//tensorflow/core/platform:types", - "@local_tsl//tsl/lib/histogram", + "@local_xla//xla/tsl/lib/histogram", ], alwayslink = True, ) @@ -35,7 +35,7 @@ filegroup( name = "mobile_srcs_only_runtime", srcs = [ "histogram.h", - "@local_tsl//tsl/lib/histogram:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/histogram:mobile_srcs_only_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -44,7 +44,7 @@ filegroup( name = "legacy_lib_histogram_all_headers", srcs = [ "histogram.h", - "@local_tsl//tsl/lib/histogram:legacy_lib_histogram_all_headers", + "@local_xla//xla/tsl/lib/histogram:legacy_lib_histogram_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/histogram/histogram.h b/tensorflow/core/lib/histogram/histogram.h index 551477cf483961..281e190f0bb615 100644 --- a/tensorflow/core/lib/histogram/histogram.h +++ b/tensorflow/core/lib/histogram/histogram.h @@ -19,12 +19,12 @@ limitations under the License. #include #include +#include "xla/tsl/lib/histogram/histogram.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/histogram/histogram.h" namespace tensorflow { diff --git a/tensorflow/core/lib/strings/BUILD b/tensorflow/core/lib/strings/BUILD index 72eb0a6dac308c..d8f4e6df21d573 100644 --- a/tensorflow/core/lib/strings/BUILD +++ b/tensorflow/core/lib/strings/BUILD @@ -51,7 +51,7 @@ cc_library( name = "proto_serialization", hdrs = ["proto_serialization.h"], deps = [ - "@local_tsl//tsl/lib/strings:proto_serialization", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], ) @@ -116,7 +116,7 @@ filegroup( "ordered_code.cc", "ordered_code.h", "proto_serialization.h", - "@local_tsl//tsl/lib/strings:mobile_srcs_only_runtime", + "@local_xla//xla/tsl/lib/strings:mobile_srcs_only_runtime", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -133,7 +133,7 @@ filegroup( "str_util.h", "strcat.h", "stringprintf.h", - "@local_tsl//tsl/lib/strings:legacy_lib_strings_all_headers", + "@local_xla//xla/tsl/lib/strings:legacy_lib_strings_all_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -165,7 +165,7 @@ filegroup( "str_util.h", "strcat.h", "stringprintf.h", - "@local_tsl//tsl/lib/strings:legacy_lib_string_headers", + "@local_xla//xla/tsl/lib/strings:legacy_lib_string_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) @@ -178,7 +178,7 @@ filegroup( "proto_serialization.h", "proto_text_util.h", "scanner.h", - "@local_tsl//tsl/lib/strings:legacy_lib_internal_public_string_headers", + "@local_xla//xla/tsl/lib/strings:legacy_lib_internal_public_string_headers", ], visibility = ["//tensorflow/core:__pkg__"], ) diff --git a/tensorflow/core/lib/strings/proto_serialization.h b/tensorflow/core/lib/strings/proto_serialization.h index 0c01708dadf4b2..e0c253f52dbe45 100644 --- a/tensorflow/core/lib/strings/proto_serialization.h +++ b/tensorflow/core/lib/strings/proto_serialization.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_ #define TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_ -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" namespace tensorflow { // NOLINTBEGIN(misc-unused-using-decls) diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index ebaade2c926c8f..b05c4125eaa9bd 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -157,7 +157,6 @@ Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g, offset_i.push_back(strings::StrCat("offset:offset:", i)); dx_i.push_back(strings::StrCat("dx_", i, ":output:0")); } - DataTypeVector dtype_list(N, T); // ConcatGrad(dim, x, dy): // for i in range(N): diff --git a/tensorflow/core/ops/batch_ops.cc b/tensorflow/core/ops/batch_ops.cc index 99d45512374584..6d21ee483a1948 100644 --- a/tensorflow/core/ops/batch_ops.cc +++ b/tensorflow/core/ops/batch_ops.cc @@ -76,9 +76,17 @@ REGISTER_OP("BatchFunction") // allowed. The following options are available. // // - PAD_UP: pad to size 32. + // - BATCH_DOWN: schedule a batch of size 16 and leave 2 requests in the + // batch buffer. + // - MINIMIZE_TPU_COST_PER_REQUEST: a smarter greedy policy that chooses + // to either PAD_UP or BATCH_DOWN so as to minimize the TPU costs per + // real request. In this case, it would compare (batch_16_cost / 16) and + // (batch_32_cost / 18). + // + // WARNING: Not all batch schedulers might support this attribute. .Attr( "batch_padding_policy: " - "{'PAD_UP'} = 'PAD_UP'") + "{'PAD_UP', 'BATCH_DOWN', 'MINIMIZE_TPU_COST_PER_REQUEST'} = 'PAD_UP'") .Attr("Tin: list(type)") .Attr("Tcaptured: list(type) >= 0") .Attr("Tout: list(type)") diff --git a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt index 8fecdf6b1490e7..d743b8e513a1b2 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/BatchFunction.pbtxt @@ -802,3 +802,152 @@ op { } is_distributed_communication: true } +op { + name: "BatchFunction" + input_arg { + name: "in_tensors" + type_list_attr: "Tin" + } + input_arg { + name: "captured_tensors" + type_list_attr: "Tcaptured" + } + output_arg { + name: "out_tensors" + type_list_attr: "Tout" + } + attr { + name: "f" + type: "func" + } + attr { + name: "num_batch_threads" + type: "int" + } + attr { + name: "max_batch_size" + type: "int" + } + attr { + name: "batch_timeout_micros" + type: "int" + } + attr { + name: "max_enqueued_batches" + type: "int" + default_value { + i: 10 + } + } + attr { + name: "allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "batching_queue" + type: "string" + default_value { + s: "" + } + } + attr { + name: "low_priority_max_batch_size" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "low_priority_batch_timeout_micros" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "low_priority_allowed_batch_sizes" + type: "list(int)" + default_value { + list { + } + } + } + attr { + name: "low_priority_max_enqueued_batches" + type: "int" + default_value { + i: 0 + } + } + attr { + name: "mixed_priority_policy" + type: "string" + default_value { + s: "low_priority_padding_with_max_batch_size" + } + allowed_values { + list { + s: "low_priority_padding_with_max_batch_size" + s: "low_priority_padding_with_next_allowed_batch_size" + s: "priority_isolation" + } + } + } + attr { + name: "batch_padding_policy" + type: "string" + default_value { + s: "PAD_UP" + } + allowed_values { + list { + s: "PAD_UP" + s: "BATCH_DOWN" + s: "MINIMIZE_TPU_COST_PER_REQUEST" + } + } + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tcaptured" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "enable_large_batch_splitting" + type: "bool" + default_value { + b: false + } + } + is_distributed_communication: true +} diff --git a/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt index 55e73b740adefd..1f3016b69303f5 100644 --- a/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt +++ b/tensorflow/core/ops/compat/ops_history_v2/ParallelMapDatasetV2.pbtxt @@ -290,3 +290,98 @@ op { } } } +op { + name: "ParallelMapDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + experimental_full_type { + type_id: TFT_DATASET + args { + type_id: TFT_FOR_EACH + args { + type_id: TFT_PRODUCT + } + args { + type_id: TFT_TENSOR + args { + type_id: TFT_VAR + s: "output_types" + } + } + args { + type_id: TFT_VAR + s: "output_types" + } + } + } + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } + attr { + name: "deterministic" + type: "string" + default_value { + s: "default" + } + } + attr { + name: "preserve_cardinality" + type: "bool" + default_value { + b: false + } + } + attr { + name: "use_unbounded_threadpool" + type: "bool" + default_value { + b: false + } + } + attr { + name: "metadata" + type: "string" + default_value { + s: "" + } + } +} diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index d347d0cc508abe..7e8121206e8fd9 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -206,6 +206,7 @@ REGISTER_OP("ParallelMapDatasetV2") // "true", "false", or "default". .Attr("deterministic: string = 'default'") .Attr("preserve_cardinality: bool = false") + .Attr("use_unbounded_threadpool: bool = false") .Attr("metadata: string = ''") .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, "output_types")) diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index d05df091b7da28..dcf9e2f0e666e5 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4112,6 +4112,8 @@ op { allowed_values { list { s: "PAD_UP" + s: "BATCH_DOWN" + s: "MINIMIZE_TPU_COST_PER_REQUEST" } } } @@ -33010,6 +33012,13 @@ op { b: false } } + attr { + name: "use_unbounded_threadpool" + type: "bool" + default_value { + b: false + } + } attr { name: "metadata" type: "string" diff --git a/tensorflow/core/platform/BUILD b/tensorflow/core/platform/BUILD index 2f179f74073009..1dc0e202a6e84a 100644 --- a/tensorflow/core/platform/BUILD +++ b/tensorflow/core/platform/BUILD @@ -317,7 +317,7 @@ tf_cc_test( ":test_main", "//tensorflow/core:protos_all_cc", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 178a68ccfcf65a..3080e97f03feb2 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/platform/cord.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/platform/strcat.h" #include "tensorflow/core/platform/stringpiece.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" namespace tsl { @@ -53,7 +53,7 @@ tensorflow::GraphDef CreateTestProto() { return g; } -static void ExpectHasSubstr(StringPiece s, StringPiece expected) { +static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) { EXPECT_TRUE(absl::StrContains(s, expected)) << "'" << s << "' does not contain '" << expected << "'"; } @@ -82,7 +82,7 @@ TEST_F(DefaultEnvTest, IncompleteReadOutOfRange) { TF_EXPECT_OK(env_->NewRandomAccessFile(filename, &f)); // Reading past EOF should give an OUT_OF_RANGE error - StringPiece result; + absl::string_view result; char scratch[3]; EXPECT_EQ(error::OUT_OF_RANGE, f->Read(0, 3, &result, scratch).code()); EXPECT_EQ(input, result); @@ -300,7 +300,7 @@ class TmpDirFileSystem : public NullFileSystem { TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; absl::Status FileExists(const string& dir, TransactionToken* token) override { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(dir, &scheme, &host, &path); if (path.empty()) return errors::NotFound(dir, " not found"); // The special "flushed" file exists only if the filesystem's caches have @@ -316,7 +316,7 @@ class TmpDirFileSystem : public NullFileSystem { } absl::Status CreateDir(const string& dir, TransactionToken* token) override { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(dir, &scheme, &host, &path); if (scheme != "tmpdirfs") { return errors::FailedPrecondition("scheme must be tmpdirfs"); @@ -335,7 +335,7 @@ class TmpDirFileSystem : public NullFileSystem { absl::Status IsDirectory(const string& dir, TransactionToken* token) override { - StringPiece scheme, host, path; + absl::string_view scheme, host, path; io::ParseURI(dir, &scheme, &host, &path); for (const auto& existing_dir : created_directories_) if (existing_dir == path) return absl::OkStatus(); @@ -405,7 +405,7 @@ TEST_F(DefaultEnvTest, LocalTempFilename) { // Read from the temporary file and check content. std::unique_ptr file_to_read; TF_CHECK_OK(env->NewRandomAccessFile(filename, &file_to_read)); - StringPiece content; + absl::string_view content; char scratch[1024]; CHECK_EQ( error::OUT_OF_RANGE, @@ -427,7 +427,7 @@ TEST_F(DefaultEnvTest, CreateUniqueFileName) { EXPECT_TRUE(env->CreateUniqueFileName(&filename, suffix)); EXPECT_TRUE(absl::StartsWith(filename, prefix)); - EXPECT_TRUE(str_util::EndsWith(filename, suffix)); + EXPECT_TRUE(absl::EndsWith(filename, suffix)); } TEST_F(DefaultEnvTest, GetProcessId) { diff --git a/tensorflow/core/platform/stringpiece.h b/tensorflow/core/platform/stringpiece.h index 17760cd7fee327..66040fc997173c 100644 --- a/tensorflow/core/platform/stringpiece.h +++ b/tensorflow/core/platform/stringpiece.h @@ -30,7 +30,7 @@ limitations under the License. namespace tensorflow { -using StringPiece = tsl::StringPiece; +using StringPiece = absl::string_view; } // namespace tensorflow diff --git a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h index 916b5305c77226..66de83fe1991b4 100644 --- a/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h +++ b/tensorflow/core/profiler/convert/trace_viewer/trace_events_to_json.h @@ -541,7 +541,7 @@ void TraceEventsToJson(const JsonTraceOptions& options, separator.Add(); output->Append(R"({"args":{"name":)", JsonEscape(device.name()), R"(},"name":"process_name","ph":"M","pid":)", device_id, - "}"); + R"(,"thread_count":)", device.resources_size(), "}"); } separator.Add(); output->Append(R"({"args":{"sort_index":)", device_id, diff --git a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc index 3c06bd37fa34b9..6c80adda05d5a0 100644 --- a/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc +++ b/tensorflow/core/profiler/convert/xplane_to_dcn_collective_stats_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system.h" #include "tensorflow/core/platform/statusor.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/core/profiler/protobuf/dcn_slack_analysis.pb.h" #include "tensorflow/core/profiler/utils/xplane_builder.h" #include "tensorflow/core/profiler/utils/xplane_utils.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/profiler/protobuf/xplane.pb.h" diff --git a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc index 5609fd7658d86d..8b228479872bcc 100644 --- a/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc +++ b/tensorflow/core/profiler/convert/xplane_to_op_metrics_db.cc @@ -227,8 +227,6 @@ OpMetricsDb ConvertTpuDeviceTraceXPlaneToOpMetricsDb( XPlaneVisitor plane = tsl::profiler::CreateTfXPlaneVisitor(&device_trace); using OpMetricBySymbol = absl::flat_hash_map; - absl::flat_hash_map flat_op_metric; - XEventsOpMetricsDbBuilder builder; plane.ForEachLine([&](const XLineVisitor& line) { diff --git a/tensorflow/core/profiler/utils/derived_timeline.cc b/tensorflow/core/profiler/utils/derived_timeline.cc index 9aef8808ff49c6..383aad17ec1bdc 100644 --- a/tensorflow/core/profiler/utils/derived_timeline.cc +++ b/tensorflow/core/profiler/utils/derived_timeline.cc @@ -69,7 +69,6 @@ inline std::string HloOpEventPrefix(const GpuEventStats& stats) { std::vector GetOrCreateHloOpEventsMetadata( XPlaneBuilder& xplane, const GpuEventStats& stats, const Symbol symbol) { DCHECK(stats.IsXlaOp()); - DCHECK(!stats.hlo_module_name.empty()); std::vector hlo_op_events_metadata; hlo_op_events_metadata.reserve(stats.hlo_op_names.size()); // Prepend an HLO module identifier so HLO operators with the same name but in diff --git a/tensorflow/core/profiler/utils/derived_timeline_test.cc b/tensorflow/core/profiler/utils/derived_timeline_test.cc index 15de9ff05e3e19..ae9decdc19d259 100644 --- a/tensorflow/core/profiler/utils/derived_timeline_test.cc +++ b/tensorflow/core/profiler/utils/derived_timeline_test.cc @@ -71,6 +71,30 @@ TEST(DerivedTimelineTest, HloModuleNameTest) { }); } +// Checks that HLO module events are expanded. +TEST(DerivedTimelineTest, NoHloModuleNameTest) { + const absl::string_view kKernelDetails = "kernel_details"; + XSpace space; + tsl::profiler::GroupMetadataMap group_metadata_map; + XPlane& plane = *GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0); + XPlaneBuilder plane_builder(&plane); + auto line_builder = plane_builder.GetOrCreateLine(0); + CreateXEvent(&plane_builder, &line_builder, "op1", 0, 100, + {{StatType::kKernelDetails, kKernelDetails}}); + CreateXEvent(&plane_builder, &line_builder, "op2", 200, 300, + {{StatType::kKernelDetails, kKernelDetails}}); + GenerateDerivedTimeLines(group_metadata_map, &space); + XPlaneVisitor plane_visitor = tsl::profiler::CreateTfXPlaneVisitor(&plane); + // Only the hlo module line is added and other empty lines are removed at the + // end. + EXPECT_EQ(plane_visitor.NumLines(), 1); + plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) { + if (line_visitor.Id() == 0) return; + EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule); + EXPECT_EQ(line_visitor.NumEvents(), 0); + }); +} + // Checks that the TF op events are expanded. TEST(DerivedTimelineTest, TfOpLineTest) { const absl::string_view kTfOpName = "mul:Mul"; diff --git a/tensorflow/core/profiler/utils/gpu_event_stats.cc b/tensorflow/core/profiler/utils/gpu_event_stats.cc index be4a9246ba4d7f..80de74edec0968 100644 --- a/tensorflow/core/profiler/utils/gpu_event_stats.cc +++ b/tensorflow/core/profiler/utils/gpu_event_stats.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/core/profiler/utils/gpu_event_stats.h" +#include + #include "absl/strings/str_split.h" #include "tensorflow/core/profiler/utils/xplane_schema.h" #include "tensorflow/core/profiler/utils/xplane_visitor.h" @@ -57,7 +59,7 @@ GpuEventStats::GpuEventStats(const XEventVisitor* event) { memcpy_details = stat.StrOrRefValue(); break; case StatType::kCorrelationId: - correlation_id = stat.IntValue(); + correlation_id = static_cast(stat.IntOrUintValue()); break; case StatType::kGroupId: group_id = stat.IntValue(); @@ -79,7 +81,7 @@ LaunchEventStats::LaunchEventStats(const XEventVisitor* event) { device_id = stat.IntOrUintValue(); break; case StatType::kCorrelationId: - correlation_id = stat.IntValue(); + correlation_id = static_cast(stat.IntOrUintValue()); break; case StatType::kGroupId: group_id = stat.IntValue(); diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index 70d5efc7c11a09..d6efbd1cd7a1b1 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -54,6 +54,7 @@ using tsl::profiler::kMetadataPlaneName; // NOLINT using tsl::profiler::kPythonTracerPlaneName; // NOLINT using tsl::profiler::kRoctracerApiPlaneName; // NOLINT using tsl::profiler::kSourceLineName; // NOLINT +using tsl::profiler::kSparseCorePlaneRegex; // NOLINT using tsl::profiler::kStepLineName; // NOLINT using tsl::profiler::kTensorFlowNameScopeLineName; // NOLINT using tsl::profiler::kTensorFlowOpLineName; // NOLINT diff --git a/tensorflow/core/protobuf/BUILD b/tensorflow/core/protobuf/BUILD index 86bf0017f3cfda..c5bfac7a5bd974 100644 --- a/tensorflow/core/protobuf/BUILD +++ b/tensorflow/core/protobuf/BUILD @@ -210,7 +210,7 @@ tf_proto_library( protodeps = [ ":error_codes_proto_impl", "//tensorflow/core/framework:protos_all", - "@local_tsl//tsl/protobuf:bfc_memory_map_proto", + "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", "@local_tsl//tsl/protobuf:coordination_config_proto", "@local_tsl//tsl/protobuf:rpc_options_proto", "@local_tsl//tsl/protobuf:status_proto", @@ -218,9 +218,9 @@ tf_proto_library( tags = ["alt_dep=//third_party/tensorflow/core:protos_all"], visibility = ["//visibility:public"], exports = [ - "@local_tsl//tsl/protobuf:bfc_memory_map_proto", "@local_tsl//tsl/protobuf:rpc_options_proto", "@local_tsl//tsl/protobuf:status_proto", + "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", ], ) diff --git a/tensorflow/core/protobuf/bfc_memory_map.proto b/tensorflow/core/protobuf/bfc_memory_map.proto index 2dbcbf00bc6102..fcde598787250f 100644 --- a/tensorflow/core/protobuf/bfc_memory_map.proto +++ b/tensorflow/core/protobuf/bfc_memory_map.proto @@ -2,6 +2,6 @@ syntax = "proto3"; package tensorflow.dummy; -import public "tsl/protobuf/bfc_memory_map.proto"; +import public "xla/tsl/protobuf/bfc_memory_map.proto"; option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index d383334d57620c..a04655fb2ce770 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1938 // Updated: 2024/7/29 +#define TF_GRAPH_DEF_VERSION 1960 // Updated: 2024/8/20 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // diff --git a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc index bdb6f9ee1e1cd2..5b881676decffa 100644 --- a/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc +++ b/tensorflow/core/runtime_fallback/kernel/attr_util_test.cc @@ -16,10 +16,10 @@ limitations under the License. #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tfrt/core_runtime/op_attr_type.h" // from @tf_runtime #include "tfrt/core_runtime/op_attrs.h" // from @tf_runtime #include "tfrt/support/forward_decls.h" // from @tf_runtime diff --git a/tensorflow/core/runtime_fallback/runtime/BUILD b/tensorflow/core/runtime_fallback/runtime/BUILD index ece9e5dbb6c437..45f433d2d732a9 100644 --- a/tensorflow/core/runtime_fallback/runtime/BUILD +++ b/tensorflow/core/runtime_fallback/runtime/BUILD @@ -195,6 +195,7 @@ cc_library( "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler", "//tensorflow/core/kernels/batching_util:batch_resource_base", "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs", + "//tensorflow/core/kernels/batching_util:batch_stats", "//tensorflow/core/kernels/batching_util:bounded_executor", "//tensorflow/core/kernels/batching_util:warmup", "//tensorflow/core/lib/core:refcount", @@ -205,7 +206,6 @@ cc_library( "//tensorflow/core/tfrt/fallback:op_kernel_runner", "//tensorflow/core/tfrt/utils:error_util", "//tensorflow/core/tfrt/utils:fallback_tensor", - "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", diff --git a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h index dae6eb35a43f59..86772a2a38d437 100644 --- a/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h +++ b/tensorflow/core/runtime_fallback/runtime/fallback_batch_kernel.h @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "tensorflow/core/framework/op_kernel.h" @@ -30,10 +29,12 @@ limitations under the License. #include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" #include "tensorflow/core/kernels/batching_util/batch_resource_base.h" #include "tensorflow/core/kernels/batching_util/batch_scheduler.h" +#include "tensorflow/core/kernels/batching_util/batch_stats.h" #include "tensorflow/core/kernels/batching_util/warmup.h" #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/random.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" @@ -215,6 +216,7 @@ void BatchFunctionFallbackKernel::ComputeAsync( batch_resource_options.batch_timeout_micros = batch_timeout_micros_; batch_resource_options.max_enqueued_batches = max_enqueued_batches_; batch_resource_options.allowed_batch_sizes = allowed_batch_sizes_; + batch_resource_options.batch_padding_policy = batch_padding_policy_; batch_resource_options.low_priority_max_batch_size = low_priority_max_batch_size_; batch_resource_options.low_priority_batch_timeout_micros = @@ -224,6 +226,13 @@ void BatchFunctionFallbackKernel::ComputeAsync( batch_resource_options.low_priority_allowed_batch_sizes = low_priority_allowed_batch_sizes_; + serving::ModelBatchStats& model_batch_stats = + serving::GlobalBatchStatsRegistry().model( + /* model_name= */ std::string(GetModelName(c)), + /* op_name= */ c->op_kernel().name()); + model_batch_stats.SetBatchTimeoutMicros(batch_timeout_micros_); + model_batch_stats.SetNumBatchThreads(num_batch_threads_); + std::unique_ptr new_resource; auto status = BatchResourceType::Create( c, batch_resource_options, batch_function_, diff --git a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc index e953c7088d5583..38b8fc37f3a432 100644 --- a/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc +++ b/tensorflow/core/runtime_fallback/runtime/runtime_fallback_batch_tf_opkernels.cc @@ -137,7 +137,8 @@ class FallbackBatchResource : public tensorflow::serving::BatchResourceBase { options.num_batch_threads, options.max_batch_size, options.batch_timeout_micros, options.max_enqueued_batches, options.allowed_batch_sizes, enable_large_batch_splitting, - disable_padding, options.low_priority_max_batch_size, + disable_padding, options.batch_padding_policy, + options.low_priority_max_batch_size, options.low_priority_batch_timeout_micros, options.low_priority_max_enqueued_batches, options.low_priority_allowed_batch_sizes, @@ -437,7 +438,7 @@ REGISTER_OP("_BatchFunctionFallback") // BatchFunction in core/ops/batch_ops.cc. .Attr( "batch_padding_policy: " - "{'PAD_UP'} = 'PAD_UP'") + "{'PAD_UP', 'BATCH_DOWN', 'MINIMIZE_TPU_COST_PER_REQUEST'} = 'PAD_UP'") .Attr("Tin: list(type)") .Attr("Tcaptured: list(type) >= 0") .Attr("Tout: list(type)") diff --git a/tensorflow/core/tfrt/common/BUILD b/tensorflow/core/tfrt/common/BUILD index ba79007691359a..eb9724f707ec5f 100644 --- a/tensorflow/core/tfrt/common/BUILD +++ b/tensorflow/core/tfrt/common/BUILD @@ -188,13 +188,13 @@ tf_cc_test( ":pjrt_state", ":pjrt_util", "//tensorflow/core:framework", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_xla//xla/pjrt/cpu:cpu_client", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -214,13 +214,13 @@ tf_cuda_cc_test( "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_xla//xla/pjrt:pjrt_client", "@local_xla//xla/pjrt:tfrt_cpu_pjrt_client", "@local_xla//xla/service:gpu_plugin", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/tfrt/common/pjrt_state_test.cc b/tensorflow/core/tfrt/common/pjrt_state_test.cc index fddd72ea050509..03dcdb7c8b9c23 100644 --- a/tensorflow/core/tfrt/common/pjrt_state_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_state_test.cc @@ -21,10 +21,10 @@ limitations under the License. #include #include "xla/pjrt/cpu/cpu_client.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/core/tfrt/common/pjrt_util_test.cc b/tensorflow/core/tfrt/common/pjrt_util_test.cc index 1361b72c2da686..48f774388d355d 100644 --- a/tensorflow/core/tfrt/common/pjrt_util_test.cc +++ b/tensorflow/core/tfrt/common/pjrt_util_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include #include "xla/pjrt/cpu/cpu_client.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/tfrt/common/pjrt_state.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/tensorflow/core/tfrt/fallback/fallback_state_test.cc b/tensorflow/core/tfrt/fallback/fallback_state_test.cc index d7d55311e7ffd4..2111171df01355 100644 --- a/tensorflow/core/tfrt/fallback/fallback_state_test.cc +++ b/tensorflow/core/tfrt/fallback/fallback_state_test.cc @@ -19,10 +19,10 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/const_op.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/tfrt/gpu/kernel/BUILD b/tensorflow/core/tfrt/gpu/kernel/BUILD index bd4f86131e3117..fef0e58310a334 100644 --- a/tensorflow/core/tfrt/gpu/kernel/BUILD +++ b/tensorflow/core/tfrt/gpu/kernel/BUILD @@ -13,20 +13,16 @@ cc_library( deps = [ ":gpu_runner", "//tensorflow/core:framework", - "//tensorflow/core/common_runtime:copy_tensor", "//tensorflow/core/framework:tensor", "//tensorflow/core/platform:status", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", - "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_utils", "//tensorflow/core/runtime_fallback/kernel:tensor_util", "//tensorflow/core/tfrt/utils:fallback_tensor", "//tensorflow/core/tfrt/utils:gpu_variables_table", - "//tensorflow/core/tfrt/utils:tensor_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@tf_runtime//:core_runtime", "@tf_runtime//:hostcontext", "@tf_runtime//:support", "@tf_runtime//:tensor_alwayslink", @@ -47,9 +43,11 @@ cc_library( "//tensorflow/compiler/jit:xla_launch_util", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", + "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/framework:function_proto_cc", + "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/platform:notification", "//tensorflow/core/platform:status", - "//tensorflow/core/platform:statusor", "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", "//tensorflow/core/tfrt/common:global_state", "//tensorflow/core/tfrt/utils:fallback_tensor", @@ -59,6 +57,7 @@ cc_library( "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", @@ -122,6 +121,7 @@ cc_library( "//tensorflow/core/common_runtime/gpu:gpu_serving_device_selector", "//tensorflow/core/platform:status", "//tensorflow/core/tfrt/runtime", + "@com_google_absl//absl/status", "@local_xla//xla/tsl/framework:serving_device_selector_policies", "@tf_runtime//:hostcontext", ], diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc index 3143b8bd7821ae..d4047d4d206043 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc +++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.cc @@ -41,15 +41,18 @@ limitations under the License. #include "xla/tsl/framework/device_id_manager.h" #include "xla/tsl/framework/serving_device_selector.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" #include "tensorflow/core/tfrt/common/global_state.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" @@ -231,7 +234,7 @@ TransferVariablesAndInputs( int device_idx, const llvm::SmallVector& args, tfrt::ArrayRef resource_indices, Device* cpu_device, absl::flat_hash_map gpu_devices, - tfrt::gpu::GpuVariablesTable& vars_table, + tfrt::gpu::GpuVariablesTable& vars_table, bool variables_are_shared, const tfrt::ExecutionContext& exec_ctx) { llvm::SmallVector> results; @@ -244,35 +247,51 @@ TransferVariablesAndInputs( TF_ASSIGN_OR_RETURN(const std::vector devices_on_platform, tsl::DeviceIdManager::GetTfDevicesOnPlatform( device_type, platform_device_id)); - const int platform_idx = platform_device_id.value(); absl::flat_hash_set resource_indices_set(resource_indices.begin(), resource_indices.end()); + // If variables are shared, there is only one copy of variables for all + // logical devices on the same physical GPU device; otherwise, each logical + // device has its own copy of variables. + const int cache_copy_idx = + variables_are_shared ? platform_device_id.value() : device_idx; + for (int i = 0, resource_idx = 0; i < args.size(); ++i) { if (resource_indices_set.contains(i)) { // Transfer resources. + VLOG(2) << "Transfer resource arg[" << i << "]."; tfrt::AsyncValueRef device_tensor; auto cached_device_variable = - vars_table.GetDeviceVariable(args[i], platform_idx); + vars_table.GetDeviceVariable(args[i], cache_copy_idx); if (cached_device_variable) { - VLOG(2) << "Cache hit for resource arg[" << i << "]"; + VLOG(2) << "Cache hit for resource arg[" << i << "]."; device_tensor = cached_device_variable.CopyRef(); } else { - VLOG(2) << "Cache miss for resource arg[" << i << "]"; - // Distribute variables on virtual devices on the same GPU. - const int idx = resource_idx % devices_on_platform.size(); - const int gpu_device_idx = devices_on_platform[idx].value(); + VLOG(2) << "Cache miss for resource arg[" << i << "]."; + + int gpu_device_idx; + if (variables_are_shared) { + // Distribute variables on logical devices on the same GPU. + const int idx = resource_idx % devices_on_platform.size(); + gpu_device_idx = devices_on_platform[idx].value(); + } else { + gpu_device_idx = device_idx; + } + + VLOG(2) << "Transfer the resource arg[" << i << "] to device " + << gpu_device_idx << "."; device_tensor = TransferTensorToDevice(exec_ctx, args[i], gpu_devices.at(gpu_device_idx)); - vars_table.AddOrUpdateDeviceVariable(args[i], platform_idx, + vars_table.AddOrUpdateDeviceVariable(args[i], cache_copy_idx, std::move(device_tensor)); device_tensor = - vars_table.GetDeviceVariable(args[i], platform_idx).CopyRef(); + vars_table.GetDeviceVariable(args[i], cache_copy_idx).CopyRef(); } results.push_back(device_tensor); ++resource_idx; } else { // Transfer inputs. + VLOG(2) << "Transfer input arg[" << i << "]."; tfrt::AsyncValueRef device_tensor = TransferTensorToDevice(exec_ctx, args[i], gpu_devices.at(device_idx)); results.push_back(device_tensor); @@ -356,6 +375,7 @@ GpuRunner::Run(const GpuRunInputs& run_inputs) { tsl::DeviceReservation device_reservation = serving_device_selector_->ReserveDevice(absl::StrCat(fingerprint)); const int device_idx = device_reservation.device_index(); + VLOG(1) << "GpuRunner selected device " << device_idx << "."; // Compile the program. const XlaCompiler::CompilationResult* compilation_result; @@ -368,10 +388,10 @@ GpuRunner::Run(const GpuRunInputs& run_inputs) { TF_ASSIGN_OR_RETURN( llvm::SmallVector> transferred_args, - TransferVariablesAndInputs(device_idx, *run_inputs.args, - run_inputs.resource_indices, - run_inputs.cpu_device, *run_inputs.gpu_devices, - vars_table_, *run_inputs.exec_ctx)); + TransferVariablesAndInputs( + device_idx, *run_inputs.args, run_inputs.resource_indices, + run_inputs.cpu_device, *run_inputs.gpu_devices, vars_table_, + /*variables_are_shared=*/false, *run_inputs.exec_ctx)); llvm::SmallVector, 4> transferred_args_to_wait; diff --git a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h index fc61eff2d28139..d292fedfbc4bfc 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h +++ b/tensorflow/core/tfrt/gpu/kernel/gpu_runner.h @@ -18,7 +18,10 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" #include "xla/tsl/framework/serving_device_selector.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/status.h" @@ -27,6 +30,7 @@ limitations under the License. #include "tensorflow/core/tfrt/utils/gpu_variables_table.h" #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime +#include "tfrt/support/forward_decls.h" // from @tf_runtime namespace tensorflow { namespace gpu { diff --git a/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc b/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc index 8cc6a6286abf75..43cb013da7b926 100644 --- a/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc +++ b/tensorflow/core/tfrt/gpu/kernel/gpurt_kernels.cc @@ -21,19 +21,15 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "tensorflow/core/common_runtime/copy_tensor.h" -#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" -#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_utils.h" #include "tensorflow/core/runtime_fallback/kernel/tensor_util.h" #include "tensorflow/core/tfrt/gpu/kernel/gpu_runner.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tensorflow/core/tfrt/utils/gpu_variables_table.h" -#include "tensorflow/core/tfrt/utils/tensor_util.h" -#include "tfrt/host_context/async_dispatch.h" // from @tf_runtime #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime #include "tfrt/host_context/attribute_utils.h" // from @tf_runtime #include "tfrt/host_context/execution_context.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc index 94e52ad23ed51a..48f3160f8138da 100644 --- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc +++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "xla/tsl/framework/serving_device_selector_policies.h" #include "tensorflow/core/common_runtime/gpu/gpu_serving_device_selector.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h index bb990224ea0fc9..452ccdd9b1804d 100644 --- a/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h +++ b/tensorflow/core/tfrt/gpu/kernel/tfrt_gpu_init.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ #define TENSORFLOW_CORE_TFRT_GPU_KERNEL_TFRT_GPU_INIT_H_ #include "xla/tsl/framework/serving_device_selector_policies.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/tfrt/runtime/runtime.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/graph_executor/BUILD b/tensorflow/core/tfrt/graph_executor/BUILD index 3a5d94536f7462..61d869fe2a767b 100644 --- a/tensorflow/core/tfrt/graph_executor/BUILD +++ b/tensorflow/core/tfrt/graph_executor/BUILD @@ -246,9 +246,9 @@ tf_cc_test( ":test_config_proto_cc", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/tfrt/graph_executor/config_test.cc b/tensorflow/core/tfrt/graph_executor/config_test.cc index bc3d18665b304d..fc1b54f4952fa6 100644 --- a/tensorflow/core/tfrt/graph_executor/config_test.cc +++ b/tensorflow/core/tfrt/graph_executor/config_test.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include #include "absl/status/status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/tfrt/graph_executor/config.pb.h" #include "tensorflow/core/tfrt/graph_executor/test_config.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc index 0a0f073fe589fa..c0e07b385f763c 100644 --- a/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc +++ b/tensorflow/core/tfrt/graph_executor/graph_executor_test.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" @@ -50,7 +51,6 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/interpreter/value.h" #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tfrt/cpp_tests/test_util.h" // from @tf_runtime #include "tfrt/host_context/resource_context.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/ifrt/BUILD b/tensorflow/core/tfrt/ifrt/BUILD index bd362354b9f420..919e0df3be45a0 100644 --- a/tensorflow/core/tfrt/ifrt/BUILD +++ b/tensorflow/core/tfrt/ifrt/BUILD @@ -30,6 +30,7 @@ cc_library( srcs = ["ifrt_serving_core_selector.cc"], hdrs = ["ifrt_serving_core_selector.h"], deps = [ + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", @@ -50,12 +51,71 @@ tf_cc_test( ], ) +cc_library( + name = "grid", + srcs = ["grid.cc"], + hdrs = ["grid.h"], + deps = [ + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + ], +) + +cc_library( + name = "ifrt_device_utils", + srcs = ["ifrt_device_utils.cc"], + hdrs = ["ifrt_device_utils.h"], + deps = [ + ":grid", + ":ifrt_config_proto_cc", + "//tensorflow/compiler/tf2xla:host_compute_metadata_proto_cc", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:attribute_map", + "@local_xla//xla/service:computation_placer_hdr", + ], +) + +tf_cc_test( + name = "ifrt_device_utils_test", + srcs = [ + "ifrt_device_utils_test.cc", + ], + tags = ["no_oss"], + deps = [ + ":ifrt_device_utils", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/python/ifrt", + "@local_xla//xla/python/ifrt:attribute_map", + "@local_xla//xla/python/ifrt:mock", + "@local_xla//xla/service:computation_placer_hdr", + ], +) + cc_library( name = "ifrt_serving_executable", srcs = ["ifrt_serving_executable.cc"], hdrs = ["ifrt_serving_executable.h"], deps = [ ":ifrt_config_proto_cc", + ":ifrt_device_utils", ":ifrt_loaded_variable_registry", ":ifrt_loaded_variable_utils", ":ifrt_restore_tensor_registry", @@ -85,7 +145,10 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", @@ -158,11 +221,21 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/python/ifrt", "@local_xla//xla/tsl/concurrency:ref_count", ], ) +cc_library( + name = "ifrt_model_restore_context", + hdrs = ["ifrt_model_restore_context.h"], + deps = [ + ":checkpoint_loader", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_library( name = "ifrt_model_context", srcs = ["ifrt_model_context.cc"], @@ -253,7 +326,6 @@ cc_library( ":sharding_utils", "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", "//tensorflow/core:framework", - "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -284,6 +356,7 @@ cc_library( "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core_no_xla", "//tensorflow/core/common_runtime/eager:tensor_handle", + "//tensorflow/core/framework:function_proto_cc", "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:fixed_array", @@ -348,10 +421,10 @@ tf_cc_test( "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/python/ifrt", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -371,16 +444,17 @@ tf_cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/ir:hlo", "@local_xla//xla/python/ifrt", "@local_xla//xla/python/ifrt:test_util", "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", "@local_xla//xla/tsl/concurrency:ref_count", + "@local_xla//xla/tsl/lib/core:status_test_util", "@tf_runtime//:hostcontext", ], ) @@ -427,6 +501,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core/framework:tensor_matcher", "//tensorflow/core/framework:tensor_testutil", + "//tensorflow/core/framework:types_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -535,3 +610,38 @@ tf_cc_test( "@tf_runtime//backends/cpu:tf_ops_alwayslink", ], ) + +cc_library( + name = "checkpoint_loader", + srcs = ["checkpoint_loader.cc"], + hdrs = ["checkpoint_loader.h"], + deps = [ + ":ifrt_loaded_variable_utils", + ":ifrt_restore_tensor_registry", + "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", + "//tensorflow/core:framework", + "//tensorflow/core/common_runtime:function", + "//tensorflow/core/framework:attr_value_proto_cc", + "//tensorflow/core/framework:node_def_util", + "//tensorflow/core/framework:tensor", + "//tensorflow/core/framework:types_proto_cc", + "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "//tensorflow/core/tfrt/mlrt/bytecode", + "//tensorflow/core/tfrt/mlrt/kernel:context", + "//tensorflow/core/tfrt/mlrt/kernel:kernel_runner_utils", + "//tensorflow/core/tfrt/mlrt/kernel:shard_restore_util", + "//tensorflow/core/tfrt/utils:fallback_tensor", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tstring", + "@local_xla//xla/python/ifrt", + "@tf_runtime//:hostcontext", + ], +) diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc new file mode 100644 index 00000000000000..a970b027e48b40 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.cc @@ -0,0 +1,359 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" + +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "xla/python/ifrt/future.h" +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_handle.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/kernel/context.h" +#include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h" +#include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/tstring.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { + +namespace { + +static constexpr int kNumRestoreClusters = 4; + +// A shard of variables to be restored. +struct RestoreVariableShard { + tensorflow::Tensor prefix; + tensorflow::Tensor tensor_names; + tensorflow::Tensor shape_and_slices; + std::vector var_handles; + tensorflow::AttrValue dtypes_attr_value; + std::vector restored_dtypes; + std::vector truncate_in_cast; +}; + +struct AsyncState { + explicit AsyncState( + const std::vector& input_tf_tensor_values, + const OpKernelContext::Params& params, int num_outputs, + const tensorflow::DeviceMgr& device_manager, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime) + : run_state(input_tf_tensor_values, params), + context(&run_state.params, num_outputs), + device_manager(device_manager), + process_function_library_runtime(process_function_library_runtime) {} + + tfrt_stub::OpKernelRunState run_state; + OpKernelContext context; + const tensorflow::DeviceMgr& device_manager; + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime; + + std::vector> results; +}; + +// Returns a casted tensor if successful. +absl::StatusOr Cast( + tensorflow::Tensor& in_tensor, tensorflow::DataType restored_dtype, + tensorflow::DataType cast_dtype, bool truncate_in_cast, + const tensorflow::DeviceMgr& device_manager, + const tensorflow::ProcessFunctionLibraryRuntime& + process_function_library_runtime, + OpKernelContext::Params& params) { + auto runner = + tfrt_stub::OpKernelRunner::Create( + /*op_name=*/ + "Cast", /*node_name=*/"Cast", params.device->name(), + /*num_args=*/1, + [&](tensorflow::AttrValueMap* attr_value_map) { + tensorflow::AttrValue restored_dtype_attr_value; + restored_dtype_attr_value.set_type(restored_dtype); + attr_value_map->insert({"SrcT", restored_dtype_attr_value}); + + tensorflow::AttrValue cast_dtype_attr_value; + cast_dtype_attr_value.set_type(cast_dtype); + attr_value_map->insert({"DstT", cast_dtype_attr_value}); + + tensorflow::AttrValue truncate_attr_value; + truncate_attr_value.set_b(truncate_in_cast); + attr_value_map->insert({"Truncate", truncate_attr_value}); + return absl::OkStatus(); + }, + device_manager, process_function_library_runtime) + .value(); + + std::vector input_tf_tensor_values; + input_tf_tensor_values.push_back(tensorflow::TensorValue(&in_tensor)); + + tf_mlrt::SetUpParams(runner, input_tf_tensor_values, params); + // Use persistent device instead of the per request device. + + OpKernelContext op_kernel_context(¶ms, /*num_outputs=*/1); + + runner.Run(&op_kernel_context); + + if (!op_kernel_context.status().ok()) { + return op_kernel_context.status(); + } + DCHECK_EQ(op_kernel_context.num_outputs(), 1); + return *(op_kernel_context.mutable_output(0)); +} + +absl::Status RunShard(RestoreVariableShard shard, + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue, + tf_mlrt::Context& context) { + if (!ifrt_restore_tensor_registry) { + return absl::InternalError("ifrt_restore_tensor_registry must not be null"); + } + if (!checkpoint_loader_work_queue) { + return absl::InternalError("checkpoint_loader_work_queue must not be null"); + } + const int num_outputs = shard.var_handles.size(); + DCHECK_EQ(num_outputs, shard.tensor_names.NumElements()); + auto& fallback_request_state = context.fallback_request_state(); + + // Use `tf.RestoreV2` to restore tensor. This will also populate + // tensorflow::ResourceManager. + // TODO(b/319045348): avoid populating tensorflow::ResourceManager if the + // variable is only used by device/IFRT. + // TODO(b/319045348): consider directly calling restore function such as that + // in /tensorflow/core/kernels/save_restore_v2_ops.cc + auto runner = + tfrt_stub::OpKernelRunner::Create( + /*op_name=*/ + "RestoreV2", /*node_name=*/"RestoreV2", + context.params().device->name(), + /*num_args=*/3, + [&](tensorflow::AttrValueMap* attr_value_map) { + attr_value_map->insert({"dtypes", shard.dtypes_attr_value}); + return absl::OkStatus(); + }, + fallback_request_state.device_manager(), + fallback_request_state.process_function_library_runtime()) + .value(); + + // Prepare the input tensors. + std::vector input_tf_tensor_values; + static constexpr int kNumInputArgs = 3; + input_tf_tensor_values.resize(kNumInputArgs); + // We need to keep these tensor alive + input_tf_tensor_values[0].tensor = &shard.prefix; + input_tf_tensor_values[1].tensor = &shard.tensor_names; + input_tf_tensor_values[2].tensor = &shard.shape_and_slices; + + auto& params = context.params(); + tf_mlrt::SetUpParams(runner, input_tf_tensor_values, params); + // Use persistent device instead of the per request device. + params.device = context.fallback_request_state().device_manager().HostCPU(); + + auto async_state = std::make_unique( + input_tf_tensor_values, params, num_outputs, + fallback_request_state.device_manager(), + fallback_request_state.process_function_library_runtime()); + + for (int i = 0; i < num_outputs; ++i) { + auto promise = xla::ifrt::Future::CreatePromise(); + auto future = xla::ifrt::Future(promise); + const ResourceHandle& var_handle = + shard.var_handles[i].tensor().scalar()(); + + TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape, + ifrt_serving::GetDtypeAndShape(var_handle)); + + std::string runtime_name = + ifrt_serving::GetRuntimeNameFromVarHandle(var_handle); + + ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo + restored_tensor_info = {false, std::move(dtype_and_shape), + std::move(future)}; + if (auto status = ifrt_restore_tensor_registry->TryRegister( + runtime_name, restored_tensor_info); + !status.ok()) { + // Propagate errors so that if already-registered futures are being waited + // on, they can be unblocked. + for (auto& result : async_state->results) { + std::move(result).Set(status); + }; + return status; + } + async_state->results.push_back(std::move(promise)); + } + + // Use dedicated work queue for restore operation. + checkpoint_loader_work_queue->AddTask([runner = std::move(runner), + async_state = std::move(async_state), + shard = std::move(shard)]() { + // Keep input tensor alive in `shard`. + auto* op_kernel_context_ptr = &async_state->context; + runner.Run(op_kernel_context_ptr); + + auto& op_kernel_context = async_state->context; + if (!op_kernel_context.status().ok()) { + for (auto& result : async_state->results) { + std::move(result).Set(op_kernel_context.status()); + } + return; + } + DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs()); + DCHECK_EQ(shard.truncate_in_cast.size(), op_kernel_context.num_outputs()); + + // TODO(b/343964091): consider to run multiple casts in parallel. + for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { + DCHECK(op_kernel_context.mutable_output(i)); + + if (op_kernel_context.mutable_output(i)->dtype() != + shard.restored_dtypes[i]) { + std::move(async_state->results[i]) + .Set(absl::InvalidArgumentError(absl::StrCat( + "The restored tensor has a different dtype than the " + "variable handle: ", + op_kernel_context.mutable_output(i)->dtype(), " vs. ", + shard.restored_dtypes[i]))); + return; + } + const ResourceHandle& var_handle = + shard.var_handles[i].tensor().scalar()(); + + if (shard.restored_dtypes[i] == var_handle.dtypes_and_shapes()[0].dtype) { + std::move(async_state->results[i]) + .Set(*std::move(op_kernel_context.mutable_output(i))); + } else { + absl::StatusOr cast_output = + Cast(*op_kernel_context.mutable_output(i), shard.restored_dtypes[i], + var_handle.dtypes_and_shapes()[0].dtype, + shard.truncate_in_cast[i], async_state->device_manager, + async_state->process_function_library_runtime, + async_state->run_state.params); + if (!cast_output.ok()) { + std::move(async_state->results[i]).Set(cast_output.status()); + } else { + std::move(async_state->results[i]).Set(*std::move(cast_output)); + } + } + } + }); + return absl::OkStatus(); +} + +int64_t GetSizeFromVarHandle(const ResourceHandle& handle) { + int size = 0; + for (auto& dtype_and_shape : handle.dtypes_and_shapes()) { + size += DataTypeSize(dtype_and_shape.dtype) * + dtype_and_shape.shape.num_elements(); + } + return size; +} + +} // namespace + +absl::Status CheckpointLoader::PrepareRestore( + mlir::OwningOpRef module) { + VLOG(1) << "Skip CheckpointLoader::PrepareRestore"; + return absl::OkStatus(); +} + +absl::Status CheckpointLoader::Load( + const tensorflow::tfrt_stub::FallbackTensor& prefix, + const std::vector& var_handles, + const tensorflow::tfrt_stub::FallbackTensor& tensor_names, + const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices, + const mlrt::bc::Vector& restored_dtypes, + const mlrt::bc::Vector& truncate_in_cast, tf_mlrt::Context& context) { + std::vector variable_sizes; + variable_sizes.reserve(var_handles.size()); + for (auto& handle : var_handles) { + variable_sizes.push_back(GetSizeFromVarHandle( + handle.tensor().scalar()())); + } + + std::vector> sharded_indices = tf_mlrt::ShardVariables( + kNumRestoreClusters, absl::MakeSpan(variable_sizes)); + + // Converts the names and slices back to the tensor. + auto vector_to_tensor = [](const std::vector& vec) { + tensorflow::Tensor tensor(tensorflow::DT_STRING, + TensorShape({static_cast(vec.size())})); + for (int i = 0; i < vec.size(); ++i) { + tensor.flat()(i) = vec[i]; + } + return tensor; + }; + + const auto& tensor_names_flat = tensor_names.tensor().flat(); + const auto& shape_and_slices_flat = + shape_and_slices.tensor().flat(); + + std::vector shards; + shards.reserve(sharded_indices.size()); + for (auto& sharded_index : sharded_indices) { + RestoreVariableShard shard; + shard.var_handles.reserve(sharded_index.size()); + shard.truncate_in_cast.reserve(sharded_index.size()); + shard.restored_dtypes.reserve(sharded_index.size()); + std::vector tensor_names; + std::vector shape_and_slices; + shape_and_slices.reserve(sharded_index.size()); + tensor_names.reserve(sharded_index.size()); + for (int index : sharded_index) { + tensor_names.push_back(tensor_names_flat(index)); + shape_and_slices.push_back(shape_and_slices_flat(index)); + shard.dtypes_attr_value.mutable_list()->add_type(restored_dtypes[index]); + shard.var_handles.push_back(var_handles[index]); + shard.restored_dtypes.push_back(restored_dtypes[index]); + shard.truncate_in_cast.push_back(truncate_in_cast[index]); + } + shard.prefix = prefix.tensor(); + shard.tensor_names = vector_to_tensor(tensor_names); + shard.shape_and_slices = vector_to_tensor(shape_and_slices); + shards.push_back(std::move(shard)); + } + for (const auto& shard : shards) { + TF_RETURN_IF_ERROR(RunShard(shard, ifrt_restore_tensor_registry_, + checkpoint_loader_work_queue_, context)); + } + return absl::OkStatus(); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/checkpoint_loader.h b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h new file mode 100644 index 00000000000000..ab4a2ab48e12aa --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/checkpoint_loader.h @@ -0,0 +1,67 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ + +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" +#include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" +#include "tensorflow/core/tfrt/mlrt/kernel/context.h" +#include "tensorflow/core/tfrt/utils/fallback_tensor.h" +#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime + +namespace tensorflow { +namespace ifrt_serving { + +// TODO(b/352551302) Move the unit test in ifrt_ops_kernel for restore to test +// this class's APIs. +// Implement the `CheckpointLoaderInterface` by using RestoreV2. +class CheckpointLoader { + public: + explicit CheckpointLoader( + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue) + : ifrt_restore_tensor_registry_(ifrt_restore_tensor_registry), + checkpoint_loader_work_queue_(checkpoint_loader_work_queue) {} + virtual ~CheckpointLoader() = default; + + // Called before `Load` to do some preparation work. + virtual absl::Status PrepareRestore(mlir::OwningOpRef module); + + // Load the checkpoint. This API is designed to be compatible with the + // `tf_mlrt.ifrt_restore_variable` kernel. + virtual absl::Status Load( + const tensorflow::tfrt_stub::FallbackTensor& prefix, + const std::vector& var_handles, + const tensorflow::tfrt_stub::FallbackTensor& tensor_names, + const tensorflow::tfrt_stub::FallbackTensor& shape_and_slices, + const mlrt::bc::Vector& restored_dtypes, + const mlrt::bc::Vector& truncate_in_cast, + tf_mlrt::Context& context); + + IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry_; + tfrt::ConcurrentWorkQueue* checkpoint_loader_work_queue_; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_CHECKPOINT_LOADER_H_ diff --git a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h b/tensorflow/core/tfrt/ifrt/grid.cc similarity index 62% rename from third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h rename to tensorflow/core/tfrt/ifrt/grid.cc index 73486915d3ae7d..672503ff9d53b3 100644 --- a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.h +++ b/tensorflow/core/tfrt/ifrt/grid.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,16 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_ -#define XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_ +#include "tensorflow/core/tfrt/ifrt/grid.h" #include -#include "xla/xla_data.pb.h" +#include "absl/strings/str_cat.h" -namespace xla { -std::string FrontendAttributesToString( - const FrontendAttributes& frontend_attributes); -} // namespace xla +namespace tensorflow { +namespace ifrt_serving { -#endif // XLA_HLO_IR_HLO_FRONTEND_ATTRIBUTES_H_ +std::string GridCoords::ToString() const { + return absl::StrCat("[", dim[0], ",", dim[1], ",", dim[2], ",", dim[3], "]"); +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/grid.h b/tensorflow/core/tfrt/ifrt/grid.h new file mode 100644 index 00000000000000..28e52809083166 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/grid.h @@ -0,0 +1,77 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_TFRT_IFRT_GRID_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_GRID_H_ + +#include +#include + +#include "absl/log/check.h" +#include "absl/strings/str_format.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Coordinates that identify a particular point in a 4-d grid (usually a TPU +// topology). +struct GridCoords { + int dim[4]; + + constexpr GridCoords(int d0, int d1, int d2, int d3) : dim{d0, d1, d2, d3} {} + GridCoords() : GridCoords(0, 0, 0, 0) {} + + static GridCoords Zeroes() { return GridCoords(0, 0, 0, 0); } + static GridCoords Ones() { return GridCoords(1, 1, 1, 1); } + + int operator[](int i) const { + DCHECK_GE(i, 0); + DCHECK_LT(i, 4); + return dim[i]; + } + + int& operator[](int i) { + DCHECK_GE(i, 0); + DCHECK_LT(i, 4); + return dim[i]; + } + + int Product() const { return dim[0] * dim[1] * dim[2] * dim[3]; } + + std::string ToString() const; + + template + friend void AbslStringify(Sink& sink, const GridCoords& value) { + absl::Format(&sink, "%s", value.ToString()); + } + + friend bool operator==(const GridCoords& a, const GridCoords& b) { + return a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3]; + } + + friend std::ostream& operator<<(std::ostream& os, const GridCoords& c) { + return os << c.ToString(); + } + + template + friend H AbslHashValue(H h, const GridCoords& c) { + return H::combine(std::move(h), c[0], c[1], c[2], c[3]); + } +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_GRID_H_ diff --git a/tensorflow/core/tfrt/ifrt/ifrt_device_utils.cc b/tensorflow/core/tfrt/ifrt/ifrt_device_utils.cc new file mode 100644 index 00000000000000..2a65007552829c --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_device_utils.cc @@ -0,0 +1,194 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tfrt/ifrt/ifrt_device_utils.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/service/computation_placer.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tfrt/ifrt/grid.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace ifrt_serving { + +static constexpr int kTpuTopologyRank = 4; // (x, y, z, core). + +absl::StatusOr> GetAssignedIfrtDevices( + const xla::ifrt::Client& ifrt_client, int num_replicas, + int num_cores_per_replica, + std::optional> device_assignment) { + const int num_devices = num_replicas * num_cores_per_replica; + + // No device coordinates from ifrt devices. This disallow the mapping from + // device assignment attribute to ifrt devices. + bool no_device_coordinates = false; + for (auto* device : ifrt_client.devices()) { + if (!device->Attributes().map().contains("coords") || + !device->Attributes().map().contains("core_on_chip")) { + no_device_coordinates = true; + break; + } + } + + // If there is no device assignment attribute or no device coordinates, + // get the default device assignment from IFRT. + if (!device_assignment || device_assignment->empty() || + no_device_coordinates) { + TF_ASSIGN_OR_RETURN(xla::DeviceAssignment xla_device_assignment, + ifrt_client.GetDefaultDeviceAssignment( + num_replicas, num_cores_per_replica)); + VLOG(3) << "Getting default device lists"; + std::vector devices; + devices.reserve(num_devices); + for (int replica_idx = 0; replica_idx < num_replicas; replica_idx++) { + for (int core_idx = 0; core_idx < num_cores_per_replica; core_idx++) { + // This relies on the IFRT implementation of GetDefaultDeviceAssignment + // that keeps device id the same between device assignment and ifrt + // device list. + auto device_id = xla_device_assignment(replica_idx, core_idx); + TF_ASSIGN_OR_RETURN( + xla::ifrt::Device * device, + ifrt_client.LookupDevice(xla::ifrt::DeviceId(device_id))); + devices.push_back(device); + } + } + return devices; + } + + // Devices ordered as in the device assignment attribute. + absl::flat_hash_map devices_from_attribute; + + // Each device is encoded by [x,y,z,c] at the attribute. + std::vector coord; + coord.reserve(kTpuTopologyRank); + int device_index = 0; + + for (auto coord_attr : *device_assignment) { + coord.push_back(coord_attr); + if (coord.size() == kTpuTopologyRank) { + devices_from_attribute.insert( + {GridCoords(coord[0], coord[1], coord[2], coord[3]), device_index}); + device_index++; + coord.clear(); + } + } + if (!coord.empty()) { + return absl::FailedPreconditionError( + absl::StrCat("Device assignment attribute is expected to be a multiple " + "of 4, but got ", + device_assignment->size())); + } + + if (devices_from_attribute.size() != num_devices) { + return absl::FailedPreconditionError( + absl::StrCat("Device assignment has ", devices_from_attribute.size(), + " devices, but expected ", num_devices)); + } + + struct IfrtDeviceGrid { + xla::ifrt::Device* device; + GridCoords grid; + int index_at_attribute; + }; + std::vector ifrt_devices; + ifrt_devices.reserve(num_devices); + + for (auto* device : ifrt_client.devices()) { + GridCoords grid; + auto coords_it = device->Attributes().map().find("coords"); + auto core_on_chip_it = device->Attributes().map().find("core_on_chip"); + if (coords_it != device->Attributes().map().end() && + core_on_chip_it != device->Attributes().map().end()) { + VLOG(3) << "Adding coords and core_on_chip attributes:" + << device->DebugString(); + auto coords_list = + std::get(coords_it->second); + auto core_on_chip = std::get( + core_on_chip_it->second); + + if (coords_list.value.size() != 3) { + return absl::InternalError(absl::StrCat( + "Expected coords to be of size 3, but got ", + coords_list.value.size(), " for device ", device->DebugString())); + } + grid = GridCoords(coords_list.value[0], coords_list.value[1], + coords_list.value[2], core_on_chip.value); + } else { + return absl::InternalError( + absl::StrCat("Device ", device->DebugString(), + " does not have coords or core_on_chip attribute.")); + } + + auto device_it_from_attribute = devices_from_attribute.find(grid); + if (device_it_from_attribute == devices_from_attribute.end()) { + VLOG(1) << "Device coordinates " << grid.ToString() + << " does not match any TPU device assigned " + << absl::StrJoin(*device_assignment, " "); + continue; + } + ifrt_devices.push_back( + {.device = device, + .grid = grid, + .index_at_attribute = device_it_from_attribute->second}); + } + + if (ifrt_devices.size() != num_devices) { + return absl::FailedPreconditionError(absl::StrCat( + "Match ", ifrt_devices.size(), " devices, but expected ", num_devices)); + } + + // Sort the devices by the order in the device assignment attribute. + absl::c_sort(ifrt_devices, [&](const auto& lhs, const auto& rhs) { + return lhs.index_at_attribute < rhs.index_at_attribute; + }); + + std::vector result; + result.reserve(ifrt_devices.size()); + for (auto& device_grid : ifrt_devices) { + result.push_back(device_grid.device); + VLOG(3) << "Device: " << device_grid.device->DebugString() + << " is assigned"; + } + return result; +} + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/ifrt_device_utils.h b/tensorflow/core/tfrt/ifrt/ifrt_device_utils.h new file mode 100644 index 00000000000000..f779aa62c37469 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_device_utils.h @@ -0,0 +1,63 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_DEVICE_UTILS_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_DEVICE_UTILS_H_ + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/statusor.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "xla/python/ifrt/array.h" +#include "xla/python/ifrt/client.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/executable.h" +#include "xla/python/ifrt/host_callback.h" +#include "xla/xla_data.pb.h" +#include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" + +namespace tensorflow { +namespace ifrt_serving { + +// Returns the assigned IFRT devices based on the device assignment attribute. +// +// params: +// ifrt_client: The ifrt client. +// num_replicas: The number of replicas. +// num_cores_per_replica: The number of cores per replica. +// +// device_assignment: The device assignment array encoded as +// [x0,y0,z0,core0,x1,y1,z1,core1, ...]. Optional. If not provided, the +// devices will be assigned based on the default order returned by the IFRT +// client. +// +// returns: +// The assigned devices. +absl::StatusOr> GetAssignedIfrtDevices( + const xla::ifrt::Client& ifrt_client, int num_replicas, + int num_cores_per_replica, + std::optional> device_assignment); + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_DEVICE_UTILS_H_ diff --git a/tensorflow/core/tfrt/ifrt/ifrt_device_utils_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_device_utils_test.cc new file mode 100644 index 00000000000000..6f36192d1f7c46 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_device_utils_test.cc @@ -0,0 +1,131 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tfrt/ifrt/ifrt_device_utils.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "xla/python/ifrt/attribute_map.h" +#include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/mock.h" +#include "xla/service/computation_placer.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow { +namespace ifrt_serving { +namespace { +using ::testing::ElementsAre; +using ::testing::Return; +using ::testing::ReturnRef; +using ::tsl::testing::StatusIs; + +static constexpr int kNumReplicas = 1; +static constexpr int kNumCoresPerReplica = 2; +// Intentionally have more devices than kNumReplicas * kNumCoresPerReplica for +// testing purposes. +static constexpr int kNumDevices = 4; +static constexpr int kDeviceIdOffset = 8; + +class IfrtDeviceUtilsTest : public ::testing::Test { + protected: + void SetUp() override { + mocked_devices_.reserve(kNumDevices); + devices_.reserve(kNumDevices); + for (int i = 0; i < kNumDevices; ++i) { + mocked_devices_.push_back(std::make_unique()); + ON_CALL(*mocked_devices_[i], Attributes()) + .WillByDefault(ReturnRef(device_attributes_maps_[i])); + ON_CALL(*mocked_devices_[i], Id()) + .WillByDefault(Return(xla::ifrt::DeviceId(kDeviceIdOffset + i))); + ON_CALL(client_, LookupDevice(xla::ifrt::DeviceId(kDeviceIdOffset + i))) + .WillByDefault(Return(mocked_devices_[i].get())); + + devices_.push_back(mocked_devices_[i].get()); + }; + + ON_CALL(client_, devices()).WillByDefault(Return(devices_)); + + // Default use the last two devices. + xla::DeviceAssignment assignment(kNumReplicas, kNumCoresPerReplica); + assignment(0, 0) = kDeviceIdOffset + 2; + assignment(0, 1) = kDeviceIdOffset + 3; + + ON_CALL(client_, + GetDefaultDeviceAssignment(kNumReplicas, kNumCoresPerReplica)) + .WillByDefault(Return(assignment)); + } + + xla::ifrt::MockClient client_; + std::vector> mocked_devices_; + + std::vector devices_; + std::vector device_attributes_maps_ = { + xla::ifrt::AttributeMap(xla::ifrt::AttributeMap::Map{ + {"coords", xla::ifrt::AttributeMap::Int64ListValue({1, 0, 0})}, + {"core_on_chip", xla::ifrt::AttributeMap::Int64Value(0)}}), + xla::ifrt::AttributeMap(xla::ifrt::AttributeMap::Map{ + {"coords", xla::ifrt::AttributeMap::Int64ListValue({1, 0, 0})}, + {"core_on_chip", xla::ifrt::AttributeMap::Int64Value(1)}}), + xla::ifrt::AttributeMap(xla::ifrt::AttributeMap::Map{ + {"coords", xla::ifrt::AttributeMap::Int64ListValue({2, 0, 0})}, + {"core_on_chip", xla::ifrt::AttributeMap::Int64Value(0)}}), + xla::ifrt::AttributeMap(xla::ifrt::AttributeMap::Map{ + {"coords", xla::ifrt::AttributeMap::Int64ListValue({2, 0, 0})}, + {"core_on_chip", xla::ifrt::AttributeMap::Int64Value(1)}}), + }; +}; + +TEST_F(IfrtDeviceUtilsTest, Basic) { + std::vector device_assignment_attr = {1, 0, 0, 1, 1, 0, 0, 0}; + TF_ASSERT_OK_AND_ASSIGN( + auto devices_from_attribute, + GetAssignedIfrtDevices(client_, kNumReplicas, kNumCoresPerReplica, + device_assignment_attr)); + EXPECT_THAT(devices_from_attribute, ElementsAre(devices_[1], devices_[0])); +} + +TEST_F(IfrtDeviceUtilsTest, SeparateXCoordinates) { + std::vector device_assignment_attr = {1, 0, 0, 1, 2, 0, 0, 0}; + TF_ASSERT_OK_AND_ASSIGN( + auto devices_from_attribute, + GetAssignedIfrtDevices(client_, kNumReplicas, kNumCoresPerReplica, + device_assignment_attr)); + EXPECT_THAT(devices_from_attribute, ElementsAre(devices_[1], devices_[2])); +} + +TEST_F(IfrtDeviceUtilsTest, EmptyDeviceAssignmentShallReturnDefault) { + TF_ASSERT_OK_AND_ASSIGN( + auto devices_from_attribute, + GetAssignedIfrtDevices(client_, kNumReplicas, kNumCoresPerReplica, + std::nullopt)); + EXPECT_THAT(devices_from_attribute, ElementsAre(devices_[2], devices_[3])); +} + +TEST_F(IfrtDeviceUtilsTest, MismatchCoordinatesShallFail) { + std::vector device_assignment_attr = {1, 0, 0, 1, 3, 0, 0, 0}; + auto status = GetAssignedIfrtDevices(client_, 1, 2, device_assignment_attr); + EXPECT_THAT(status, StatusIs(absl::StatusCode::kFailedPrecondition)); +} + +} // namespace + +} // namespace ifrt_serving +} // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc index 51d0f5a2b52e5d..41fb1a3bec91d5 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_executable_registry_test.cc @@ -54,9 +54,9 @@ limitations under the License. namespace tensorflow { namespace ifrt_serving { namespace { -const tsl::thread::ThreadPool& GetThreadPool() { +tsl::thread::ThreadPool& GetThreadPool() { constexpr int kMaxParallelism = 16; - static auto* const thread_pool = + static auto* thread_pool = new tsl::thread::ThreadPool(tsl::Env::Default(), tsl::ThreadOptions(), "IfrtSharding", kMaxParallelism); return *thread_pool; diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h index e799c571a246b7..d488d936776954 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_TFRT_IFRT_IFRT_LOADED_VARIABLE_REGISTRY_H_ #include +#include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" @@ -24,7 +25,10 @@ limitations under the License. #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/future.h" #include "xla/tsl/concurrency/ref_count.h" @@ -38,19 +42,29 @@ class IfrtLoadedVariableRegistry { // The key is per variable tensor per device assignment. For single -device // program, variables can be loaded on multiple devices with core selection. // For SPMD program, we currently assume all devices will be used, so we use - // set to make it compatible with SPMD. + // vector to make it compatible with SPMD. struct Key { - // We use a set to make it compatible with SPMD. - absl::flat_hash_set device_ids; + // We use a vector to make it compatible with SPMD because the order of the + // devices used for sharding must match the order of the devices used for + // xla compilation. + std::vector device_ids; std::string input_name; + xla::HloSharding hlo_sharding; template friend H AbslHashValue(H h, const Key& key) { - h = H::combine(std::move(h), key.input_name, key.device_ids); + h = H::combine(std::move(h), key.input_name, key.device_ids, + key.hlo_sharding); return h; } friend bool operator==(const Key& x, const Key& y) { - return x.input_name == y.input_name && x.device_ids == y.device_ids; + return x.input_name == y.input_name && x.device_ids == y.device_ids && + x.hlo_sharding == y.hlo_sharding; + } + + std::string ToString() const { + return absl::StrCat(input_name, ":", absl::StrJoin(device_ids, ","), ":", + hlo_sharding.ToString()); } }; diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc index 7a17a7e14ccb38..ff71481a490d60 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_set.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -28,7 +27,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" -#include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/future.h" @@ -53,14 +51,10 @@ absl::StatusOr> LoadIfrtVariable( std::shared_ptr ifrt_client, const tsl::thread::ThreadPool& thread_pool, const tensorflow::Tensor& variable, - const VariableDeviceShardingConfigProto& sharding_config) { - std::vector device_ids{sharding_config.device_ids().begin(), - sharding_config.device_ids().end()}; - TF_ASSIGN_OR_RETURN(xla::HloSharding hlo_sharding, - xla::HloSharding::FromProto(sharding_config.sharding())); + const VariableDeviceShardingConfig& sharding_config) { return tensorflow::ifrt_serving::MakeArrayFromTensor( - *ifrt_client, variable, sharding_config.device_ids(), hlo_sharding, - thread_pool); + *ifrt_client, variable, sharding_config.device_ids, + sharding_config.hlo_sharding, thread_pool); } } // namespace @@ -97,12 +91,11 @@ absl::Status AsyncLoadRestoredTensorAsIfrtLoadedVariable( const ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry, ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry, tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, - const VariableDeviceShardingConfigProto& sharding_config) { - absl::flat_hash_set device_ids{sharding_config.device_ids().begin(), - sharding_config.device_ids().end()}; + const VariableDeviceShardingConfig& sharding_config) { IfrtLoadedVariableRegistry::Key loaded_variable_key{ - .device_ids = std::move(device_ids), + .device_ids = sharding_config.device_ids, .input_name = std::string(runtime_name), + .hlo_sharding = sharding_config.hlo_sharding, }; if (ifrt_loaded_variable_registry.GetLoadedVariable(loaded_variable_key) .ok()) { diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h index 4d07d1a3771a8a..6fea3a576e45a2 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h @@ -18,11 +18,13 @@ limitations under the License. #include #include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/client.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" @@ -37,6 +39,12 @@ namespace ifrt_serving { // An index to indicate a non per-core executable bundle cache. inline constexpr int kNoCoreSelectedIndex = -1; +// TODO(b/352551302) Delete VariableDeviceShardingConfigProto. +struct VariableDeviceShardingConfig { + std::vector device_ids; + xla::HloSharding hlo_sharding; +}; + absl::StatusOr GetDtypeAndShape( const ResourceHandle& resource_handle); @@ -57,7 +65,7 @@ absl::Status AsyncLoadRestoredTensorAsIfrtLoadedVariable( const ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry, ifrt_serving::IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry, tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, - const VariableDeviceShardingConfigProto& sharding_config); + const VariableDeviceShardingConfig& sharding_config); } // namespace ifrt_serving } // namespace tensorflow diff --git a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc index 4777d0a3c18103..fe8e98884e22df 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils_test.cc @@ -23,12 +23,14 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/test_util.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/tensor.h" @@ -39,7 +41,6 @@ limitations under the License. #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -77,8 +78,10 @@ TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableNotFoundWrongName) { auto restore_work_queue = tfrt::CreateMultiThreadedWorkQueue( /*num_threads=*/4, /*num_blocking_threads=*/4); - VariableDeviceShardingConfigProto sharding_config; - sharding_config.add_device_ids(0); + VariableDeviceShardingConfig sharding_config = { + .device_ids = {0}, + .hlo_sharding = xla::HloSharding::Replicate(), + }; auto promise = xla::ifrt::Future::CreatePromise(); auto future = xla::ifrt::Future(promise); @@ -120,8 +123,10 @@ TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableSucceed) { auto restore_work_queue = tfrt::CreateMultiThreadedWorkQueue( /*num_threads=*/4, /*num_blocking_threads=*/4); - VariableDeviceShardingConfigProto sharding_config; - sharding_config.add_device_ids(0); + VariableDeviceShardingConfig sharding_config{ + .device_ids = {0}, + .hlo_sharding = xla::HloSharding::Replicate(), + }; auto promise = xla::ifrt::Future::CreatePromise(); auto future = xla::ifrt::Future(promise); @@ -140,6 +145,7 @@ TEST(ShardingUtilsTest, ShardTensorToIfrtLoadedVariableSucceed) { IfrtLoadedVariableRegistry::Key key{ .device_ids = {0}, .input_name = "var_x", + .hlo_sharding = sharding_config.hlo_sharding, }; TF_ASSERT_OK_AND_ASSIGN(auto v, loaded_variable_registry.GetLoadedVariable(key)); diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_context.cc b/tensorflow/core/tfrt/ifrt/ifrt_model_context.cc index fec8b8b099480e..7dad92a1f873b6 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_model_context.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_model_context.cc @@ -24,7 +24,7 @@ limitations under the License. namespace tensorflow { namespace ifrt_serving { -const tsl::thread::ThreadPool& IfrtModelContext::GetThreadPool() const { +tsl::thread::ThreadPool& IfrtModelContext::GetThreadPool() const { return thread_pool_; } diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h index 76a9622c0ef0e4..bc8f802ab8c75c 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_model_context.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_model_context.h @@ -55,7 +55,7 @@ class IfrtModelContext { explicit IfrtModelContext( std::shared_ptr client, IfrtServingCoreSelector* ifrt_serving_core_selector, - const tsl::thread::ThreadPool* thread_pool, + tsl::thread::ThreadPool* thread_pool, std::unique_ptr compilation_environment_proto) : client_(std::move(client)), ifrt_serving_core_selector_(ifrt_serving_core_selector), @@ -65,8 +65,7 @@ class IfrtModelContext { IfrtModelContext( std::shared_ptr client, IfrtServingCoreSelector* ifrt_serving_core_selector, - const tsl::thread::ThreadPool* thread_pool, - tensorflow::DeviceMgr* device_mgr, + tsl::thread::ThreadPool* thread_pool, tensorflow::DeviceMgr* device_mgr, tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, std::unique_ptr compilation_environment_proto, std::shared_ptr topology) @@ -90,7 +89,7 @@ class IfrtModelContext { return shape_representation_fn_; } - const tsl::thread::ThreadPool& GetThreadPool() const; + tsl::thread::ThreadPool& GetThreadPool() const; const IfrtLoadedVariableRegistry& GetLoadedVariableRegistry() const { return loaded_variable_registry_; @@ -139,7 +138,7 @@ class IfrtModelContext { std::shared_ptr topology_; IfrtServingCoreSelector* ifrt_serving_core_selector_; // May be nullptr - const tsl::thread::ThreadPool& thread_pool_; + tsl::thread::ThreadPool& thread_pool_; tensorflow::DeviceMgr* device_mgr_ = nullptr; // Not owned. tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_ = diff --git a/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h b/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h new file mode 100644 index 00000000000000..da9528eab6b023 --- /dev/null +++ b/tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ +#define TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" + +namespace tensorflow { +namespace ifrt_serving { + +inline constexpr absl::string_view kIfrtModelRestoreContextName = + "IfrtModelRestoreContext"; + +// A resource context that holds the `CheckpointLoader` for a model. We need a +// different context than `IfrtModelContext` because `IfrtModelContext` is too +// large to be a dependency of other libraries. +class IfrtModelRestoreContext { + public: + explicit IfrtModelRestoreContext( + std::unique_ptr checkpoint_loader) + : checkpoint_loader_(std::move(checkpoint_loader)) {} + + CheckpointLoader* checkpoint_loader() const { + return checkpoint_loader_.get(); + } + + private: + std::unique_ptr checkpoint_loader_; +}; + +} // namespace ifrt_serving +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_TFRT_IFRT_IFRT_MODEL_RESTORE_CONTEXT_H_ diff --git a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc index de0a27aecc4104..32519624d55faf 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include "absl/status/status.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "xla/python/ifrt/future.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h index 0ab53974be06f3..a4505cbab06f38 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/synchronization/mutex.h" #include "xla/tsl/framework/serving_device_selector.h" diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc index 45e492c3084060..a6d35bf1728290 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.cc @@ -34,8 +34,12 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "llvm/Support/FormatVariadic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/extract_callback.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h" @@ -63,6 +67,7 @@ limitations under the License. #include "xla/tsl/framework/serving_device_selector.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -70,6 +75,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_device_utils.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" @@ -119,39 +125,61 @@ absl::StatusOr> BuildDtypeAndShape( return dtypes_and_shapes; } -absl::StatusOr GetXlaDeviceAssignment( - const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata) { - if (!compile_metadata.has_device_assignment()) { - return absl::InternalError("No device assignment found."); +// Returns the device assignment from the given IFRT devices list. +absl::StatusOr GetRuntimeXlaDeviceAssignment( + const xla::ifrt::DeviceList& devices, int num_replicas, + int num_cores_per_replica) { + const int num_devices = num_replicas * num_cores_per_replica; + if (devices.size() != num_devices) { + return absl::InternalError( + absl::StrCat("Device assignment has ", devices.size(), + " devices, but expected ", num_devices)); } - TF_ASSIGN_OR_RETURN( - std::unique_ptr da, - xla::DeviceAssignment::Deserialize(compile_metadata.device_assignment())); - return *da; + xla::DeviceAssignment da(num_replicas, num_cores_per_replica); + int device_index = 0; + for (int replica_idx = 0; replica_idx < num_replicas; replica_idx++) { + for (int core_idx = 0; core_idx < num_cores_per_replica; + core_idx++, device_index++) { + da(replica_idx, core_idx) = devices[device_index]->Id().value(); + VLOG(3) << "Added IFRT device id: " << da(replica_idx, core_idx); + } + } + return da; } +static constexpr absl::string_view kDeviceAssignmentAttr = "device_assignment"; +static constexpr absl::string_view kEntryFuncName = "main"; + absl::StatusOr> GetAssignedDevices( - const xla::ifrt::Client& ifrt_client, - const tensorflow::tpu::TPUCompileMetadataProto& compile_metadata) { - TF_ASSIGN_OR_RETURN(auto device_assignment, - GetXlaDeviceAssignment(compile_metadata)); - const int num_devices = - device_assignment.replica_count() * device_assignment.computation_count(); - std::vector devices; - devices.reserve(num_devices); - for (int replica_idx = 0; replica_idx < device_assignment.replica_count(); - replica_idx++) { - for (int computation_idx = 0; - computation_idx < device_assignment.computation_count(); - computation_idx++) { - auto device_id = device_assignment(replica_idx, computation_idx); - TF_ASSIGN_OR_RETURN( - xla::ifrt::Device * device, - ifrt_client.LookupDevice(xla::ifrt::DeviceId(device_id))); - devices.push_back(device); + mlir::ModuleOp module, const xla::ifrt::Client& ifrt_client, + int num_replicas, int num_cores_per_replica) { + auto op = module.lookupSymbol(kEntryFuncName); + if (!op) { + return absl::InternalError("Could not find entry function in MLIR Module."); + } + + auto device_assignment_attr = + op->getAttrOfType(kDeviceAssignmentAttr); + std::optional> device_assignment_attr_val; + + if (device_assignment_attr && !device_assignment_attr.getValue().empty()) { + std::vector coords; + coords.reserve(num_replicas * num_cores_per_replica); + for (auto coord_attr : device_assignment_attr.getValue()) { + auto coord_attr_val = mlir::dyn_cast(coord_attr); + if (!coord_attr_val) { + return absl::InternalError( + llvm::formatv("Device assignment attribute is not an integer: {0}", + device_assignment_attr) + .str()); + } + coords.push_back(coord_attr_val.getInt()); } + device_assignment_attr_val = std::move(coords); } - return devices; + return GetAssignedIfrtDevices(ifrt_client, num_replicas, + num_cores_per_replica, + device_assignment_attr_val); } } // namespace @@ -161,7 +189,7 @@ IfrtServingExecutable::Create( int64_t program_id, absl::string_view model_name, absl::string_view signature_name, mlir::OwningOpRef module, std::shared_ptr client, - const tsl::thread::ThreadPool* thread_pool, + tsl::thread::ThreadPool* thread_pool, IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry, const IfrtRestoreTensorRegistry* ifrt_restore, tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, @@ -173,12 +201,21 @@ IfrtServingExecutable::Create( tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata, GetCompileMetadata(*module, *client)); + TF_ASSIGN_OR_RETURN( + std::vector assigned_devices, + GetAssignedDevices(*module, *client, + original_compile_metadata.num_replicas(), + original_compile_metadata.num_cores_per_replica())); + auto executable = absl::WrapUnique(new IfrtServingExecutable( program_id, model_name, signature_name, std::move(module), std::move(client), thread_pool, ifrt_loaded_variable_registry, ifrt_restore, checkpoint_loader_queue, device_mgr, std::move(shape_representation_fn), ifrt_serving_core_selector, - std::move(original_compile_metadata), compilation_environement_proto)); + std::move(original_compile_metadata), + xla::ifrt::DeviceList(xla::ifrt::DeviceList::Devices( + assigned_devices.begin(), assigned_devices.end())), + compilation_environement_proto)); return executable; } @@ -367,14 +404,17 @@ IfrtServingExecutable::CreateExecutableSynchronously( xla_compile_options.executable_build_options.set_num_partitions( num_partitions); - xla_compile_options.executable_build_options.set_use_spmd_partitioning(true); + xla_compile_options.executable_build_options.set_use_spmd_partitioning( + original_compile_metadata_.use_spmd_for_xla_partitioning()); xla_compile_options.parameter_is_tupled_arguments = false; // Use portable execution for single device + core selection. if (UsePortableExecution(compile_metadata)) { xla_compile_options.compile_portable_executable = true; } else { - TF_ASSIGN_OR_RETURN(xla::DeviceAssignment da, - GetXlaDeviceAssignment(tf2hlo_result.compile_metadata)); + TF_ASSIGN_OR_RETURN( + xla::DeviceAssignment da, + GetRuntimeXlaDeviceAssignment(assigned_device_list_, num_replicas, + num_partitions)); VLOG(2) << "Device assignment :" << da.ToString(); xla_compile_options.executable_build_options.set_device_assignment(da); } @@ -516,7 +556,7 @@ absl::StatusOr> IfrtServingExecutable::Execute( // `device_reservation` should be alive before the end of the execution. tsl::DeviceReservation device_reservation(kNoCoreSelectedIndex, nullptr); - std::vector devices; + xla::ifrt::DeviceList device_list; if (UsePortableExecution(compile_metadata)) { device_reservation = ifrt_serving_core_selector_->ReserveDevice(program_id_); @@ -526,19 +566,16 @@ absl::StatusOr> IfrtServingExecutable::Execute( TF_ASSIGN_OR_RETURN(xla::ifrt::Device * device, ifrt_client_->LookupDevice(xla::ifrt::DeviceId( device_reservation.device_index()))); - devices.push_back(device); + device_list = + xla::ifrt::DeviceList(xla::ifrt::DeviceList::Devices({device})); } else { - TF_ASSIGN_OR_RETURN(devices, - GetAssignedDevices(*ifrt_client_, compile_metadata)); + device_list = assigned_device_list_; } TF_ASSIGN_OR_RETURN(SharedCachedExecutableBundle executable_bundle, LookUpOrCreateExecutable( compile_metadata, absl::MakeSpan(dtypes_and_shapes)) .Await()); - xla::ifrt::DeviceList device_list( - xla::ifrt::DeviceList::Devices(devices.begin(), devices.end())); - if (executable_bundle->compile_metadata.args().size() != dtypes_and_shapes.size()) { return absl::InternalError(absl::StrCat( @@ -548,7 +585,7 @@ absl::StatusOr> IfrtServingExecutable::Execute( // Asynchronously load the restored variable tensors to Ifrt array. TF_RETURN_IF_ERROR(AsyncLoadIfrtArray(inputs, variable_arg_indices, - *executable_bundle, devices)); + *executable_bundle, device_list)); std::vector> args; args.reserve(inputs.size()); @@ -556,13 +593,19 @@ absl::StatusOr> IfrtServingExecutable::Execute( for (int i = 0; i < inputs.size(); i++) { if (variable_index < variable_arg_indices.size() && i == variable_arg_indices[variable_index]) { - absl::flat_hash_set device_ids; - for (const auto& device : devices) { - device_ids.insert(device->Id().value()); + std::vector device_ids; + device_ids.reserve(device_list.size()); + for (const auto& device : device_list) { + device_ids.push_back(device->Id().value()); } + TF_ASSIGN_OR_RETURN( + xla::HloSharding hlo_sharding, + xla::HloSharding::FromProto( + executable_bundle->compile_metadata.args()[i].sharding())); IfrtLoadedVariableRegistry::Key key{ .device_ids = std::move(device_ids), .input_name = inputs[i].scalar()(), + .hlo_sharding = std::move(hlo_sharding), }; TF_ASSIGN_OR_RETURN( auto loaded_variable, @@ -600,14 +643,15 @@ absl::StatusOr> IfrtServingExecutable::Execute( auto status = execution_result.status.Await(); TF_RETURN_IF_ERROR(status); - std::vector outputs; - if (executable_bundle->compile_metadata.retvals().size() != execution_result.outputs.size()) { return absl::InternalError(absl::StrCat( "Expect ", executable_bundle->compile_metadata.retvals().size(), " but got ", execution_result.outputs.size(), " outputs")); } + + std::vector> output_futures; + output_futures.reserve(execution_result.outputs.size()); for (int i = 0; i < execution_result.outputs.size(); ++i) { tensorflow::TensorShape tensor_shape; const tsl::RCReference& array_for_copy = @@ -621,13 +665,17 @@ absl::StatusOr> IfrtServingExecutable::Execute( TF_ASSIGN_OR_RETURN(auto hlo_sharding, xla::HloSharding::FromProto( metadata_retval.sharding())); - TF_ASSIGN_OR_RETURN( - tensorflow::Tensor tensor, - MakeTensorFromArray(*ifrt_client_, *array_for_copy, hlo_sharding, - device_list, thread_pool_)); - outputs.push_back(std::move(tensor)); + output_futures.push_back(MakeTensorFromArray(*ifrt_client_, *array_for_copy, + hlo_sharding, device_list, + thread_pool_)); } + std::vector outputs; + outputs.reserve(output_futures.size()); + for (auto& output_future : output_futures) { + TF_ASSIGN_OR_RETURN(auto tensor, output_future.Await()); + outputs.push_back(std::move(tensor)); + } return outputs; } @@ -635,7 +683,7 @@ absl::Status IfrtServingExecutable::AsyncLoadIfrtArray( absl::Span inputs, absl::Span variable_arg_indices, const CachedExecutableBundle& executable_bundle, - const std::vector& devices) { + const xla::ifrt::DeviceList& devices) { for (const int i : variable_arg_indices) { if (inputs[i].dtype() != tensorflow::DT_STRING || !tensorflow::TensorShapeUtils::IsScalar(inputs[i].shape())) { @@ -647,11 +695,15 @@ absl::Status IfrtServingExecutable::AsyncLoadIfrtArray( } std::string runtime_name = inputs[i].scalar()(); // TODO(b/339521818): Add test cases for OpSharding on variables. - VariableDeviceShardingConfigProto sharding_config; - *sharding_config.mutable_sharding() = - executable_bundle.compile_metadata.args()[i].sharding(); + TF_ASSIGN_OR_RETURN( + xla::HloSharding hlo_sharding, + xla::HloSharding::FromProto( + executable_bundle.compile_metadata.args()[i].sharding())); + VariableDeviceShardingConfig sharding_config{ + .hlo_sharding = std::move(hlo_sharding), + }; for (const auto& device : devices) { - sharding_config.add_device_ids(device->Id().value()); + sharding_config.device_ids.push_back(device->Id().value()); } TF_RETURN_IF_ERROR( diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h index 5d4e966771b0fe..9dfc2251c8328f 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable.h @@ -19,7 +19,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -65,7 +64,7 @@ class IfrtServingExecutable { absl::string_view signature_name, mlir::OwningOpRef module, std::shared_ptr client, - const tsl::thread::ThreadPool* thread_pool, + tsl::thread::ThreadPool* thread_pool, IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry, const IfrtRestoreTensorRegistry* ifrt_restore, tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, @@ -136,7 +135,7 @@ class IfrtServingExecutable { absl::string_view signature_name, mlir::OwningOpRef module, std::shared_ptr client, - const tsl::thread::ThreadPool* thread_pool, + tsl::thread::ThreadPool* thread_pool, IfrtLoadedVariableRegistry* ifrt_loaded_variable_registry, const IfrtRestoreTensorRegistry* ifrt_restore_tensor_registry, tfrt::ConcurrentWorkQueue* checkpoint_loader_queue, @@ -144,12 +143,14 @@ class IfrtServingExecutable { tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn, IfrtServingCoreSelector* ifrt_serving_core_selector, tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata, + xla::ifrt::DeviceList assigned_device_list, tsl::protobuf::Message* compilation_environment_proto) : program_id_(program_id), model_name_(std::string(model_name)), signature_name_(std::string(signature_name)), module_(std::move(module)), original_compile_metadata_(std::move(original_compile_metadata)), + assigned_device_list_(std::move(assigned_device_list)), ifrt_client_(std::move(client)), thread_pool_(*thread_pool), ifrt_loaded_variable_registry_(*ifrt_loaded_variable_registry), @@ -168,12 +169,13 @@ class IfrtServingExecutable { mlir::OwningOpRef module_ ABSL_GUARDED_BY(mutex_); // The original compile metadata. We need to keep it around to be able to - // test portable execution condition even if the Module itsel is already + // test portable execution condition even if the Module itself is already // released. tensorflow::tpu::TPUCompileMetadataProto original_compile_metadata_; + const xla::ifrt::DeviceList assigned_device_list_; std::shared_ptr ifrt_client_; - const tsl::thread::ThreadPool& thread_pool_; + tsl::thread::ThreadPool& thread_pool_; IfrtLoadedVariableRegistry& ifrt_loaded_variable_registry_; const IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry_; @@ -196,7 +198,7 @@ class IfrtServingExecutable { absl::Span inputs, absl::Span variable_arg_indices, const CachedExecutableBundle& executable_bundle, - const std::vector& devices); + const xla::ifrt::DeviceList& devices); absl::StatusOr> ConvertTensorToArray( const tensorflow::Tensor& tensor, diff --git a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc index c0bdef1bf97bc7..28d3efbe479a93 100644 --- a/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc +++ b/tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/python/ifrt/test_util.h" #include "xla/tsl/framework/serving_device_selector.h" #include "xla/tsl/framework/test_util/mock_serving_device_selector.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_matcher.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -41,7 +42,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/tstring.h" diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.cc b/tensorflow/core/tfrt/ifrt/sharding_utils.cc index 9621ce7e483b3c..dab557308277c1 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils.cc +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.cc @@ -490,13 +490,11 @@ MakeAssembledArrayFromHostBuffer(xla::ifrt::Client& ifrt_client, xla::ifrt::ArrayCopySemantics::kDonateInput); } -} // namespace - -absl::StatusOr MakeTensorFromArray( +absl::StatusOr> MakeTensorFromArrayHelper( xla::ifrt::Client& ifrt_client, xla::ifrt::Array& input_array, const xla::HloSharding& hlo_sharding, const xla::ifrt::DeviceList& device_list, - const tsl::thread::ThreadPool& thread_pool) { + tsl::thread::ThreadPool& thread_pool) { TF_ASSIGN_OR_RETURN(tensorflow::DataType data_type, ToTensorDataType(input_array.dtype())); tensorflow::TensorShape tensor_shape = ToTensorShape(input_array.shape()); @@ -504,6 +502,10 @@ absl::StatusOr MakeTensorFromArray( VLOG(2) << "Create tensor from array based on sharding: " << hlo_sharding.ToString(); + xla::ifrt::Promise promise = + xla::ifrt::Future::CreatePromise(); + xla::ifrt::Future output_tensor_future(promise); + if (hlo_sharding.IsReplicated()) { VLOG(1) << "Fast path for replication"; // fast path for replication. @@ -516,14 +518,22 @@ absl::StatusOr MakeTensorFromArray( "Not fully replicated output. Expected ", tensor_shape.DebugString(), " but got ", fully_replicated_array->shape().DebugString())); } + tensorflow::Tensor output_tensor(data_type, tensor_shape); - TF_RETURN_IF_ERROR( - fully_replicated_array - ->CopyToHostBuffer(output_tensor.data(), - GetByteStrides(data_type, tensor_shape), - xla::ifrt::ArrayCopySemantics::kAlwaysCopy) - .Await()); - return output_tensor; + fully_replicated_array + ->CopyToHostBuffer(output_tensor.data(), + GetByteStrides(data_type, tensor_shape), + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .OnReady([promise = std::move(promise), + output_tensor = + std::move(output_tensor)](absl::Status status) mutable { + if (!status.ok()) { + std::move(promise).Set(status); + return; + } + std::move(promise).Set(std::move(output_tensor)); + }); + return output_tensor_future; } else if (hlo_sharding.IsTileMaximal()) { // Maximal implies single device VLOG(1) << "Fast path for maximal"; @@ -535,13 +545,20 @@ absl::StatusOr MakeTensorFromArray( int64_t device_id = hlo_sharding.GetUniqueDevice(); tensorflow::Tensor output_tensor(data_type, tensor_shape); - TF_RETURN_IF_ERROR( - disassembled_array[device_id] - ->CopyToHostBuffer(output_tensor.data(), - GetByteStrides(data_type, tensor_shape), - xla::ifrt::ArrayCopySemantics::kAlwaysCopy) - .Await()); - return output_tensor; + disassembled_array[device_id] + ->CopyToHostBuffer(output_tensor.data(), + GetByteStrides(data_type, tensor_shape), + xla::ifrt::ArrayCopySemantics::kAlwaysCopy) + .OnReady([promise = std::move(promise), + output_tensor = + std::move(output_tensor)](absl::Status status) mutable { + if (!status.ok()) { + std::move(promise).Set(status); + return; + } + std::move(promise).Set(std::move(output_tensor)); + }); + return output_tensor_future; } auto ifrt_sharding = xla::ifrt::HloSharding::Create( @@ -646,12 +663,43 @@ absl::StatusOr MakeTensorFromArray( arrays_copy_status.push_back(std::move(copy_status)); } - TF_RETURN_IF_ERROR( - xla::ifrt::JoinFutures(absl::MakeSpan(arrays_copy_status)).Await()); + xla::ifrt::JoinFutures(absl::MakeSpan(arrays_copy_status)) + .OnReady([promise = std::move(promise), &ifrt_client, + input_tensors = std::move(input_tensors), num_concats, + data_type, tensor_shape, + &thread_pool](absl::Status status) mutable { + if (!status.ok()) { + std::move(promise).Set(status); + return; + } + thread_pool.Schedule( + [promise = std::move(promise), &ifrt_client, + input_tensors = std::move(input_tensors), + num_concats = std::move(num_concats), data_type = data_type, + tensor_shape = tensor_shape, &thread_pool]() mutable { + std::move(promise).Set(MakeTensorFromDisassembledTensors( + ifrt_client, absl::MakeSpan(input_tensors), num_concats, + data_type, tensor_shape, thread_pool)); + }); + }); + return output_tensor_future; +} + +} // namespace - return MakeTensorFromDisassembledTensors( - ifrt_client, absl::MakeSpan(input_tensors), num_concats, data_type, - tensor_shape, thread_pool); +xla::ifrt::Future MakeTensorFromArray( + xla::ifrt::Client& ifrt_client, xla::ifrt::Array& input_array, + const xla::HloSharding& hlo_sharding, + const xla::ifrt::DeviceList& device_list, + tsl::thread::ThreadPool& thread_pool) { + absl::StatusOr> output_tensor_future = + MakeTensorFromArrayHelper(ifrt_client, input_array, hlo_sharding, + device_list, thread_pool); + if (!output_tensor_future.ok()) { + return xla::ifrt::Future( + std::move(output_tensor_future).status()); + } + return *std::move(output_tensor_future); } absl::StatusOr> MakeArrayFromTensor( diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils.h b/tensorflow/core/tfrt/ifrt/sharding_utils.h index a1eb590ca80947..43dbe9e8bca8dd 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils.h +++ b/tensorflow/core/tfrt/ifrt/sharding_utils.h @@ -26,9 +26,11 @@ limitations under the License. #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/device.h" +#include "xla/python/ifrt/future.h" #include "xla/tsl/concurrency/ref_count.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tsl/platform/threadpool.h" namespace tensorflow { @@ -60,11 +62,11 @@ absl::StatusOr> MakeArrayFromTensor( // device_list: list of devices that is aligned with the order of device buffers // in the `input_array`. // -absl::StatusOr MakeTensorFromArray( +xla::ifrt::Future MakeTensorFromArray( xla::ifrt::Client& ifrt_client, xla::ifrt::Array& input_array, const xla::HloSharding& hlo_sharding, const xla::ifrt::DeviceList& device_list, - const tsl::thread::ThreadPool& thread_pool); + tsl::thread::ThreadPool& thread_pool); // A wrapper around xla::ShapeUtil::ByteStrides to get the byte strides of a // TensorFlow tensor. diff --git a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc index 48233d5f293c5b..f85b3243c36191 100644 --- a/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc +++ b/tensorflow/core/tfrt/ifrt/sharding_utils_test.cc @@ -34,13 +34,14 @@ limitations under the License. #include "xla/python/ifrt/test_util.h" #include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_matcher.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_tensor_utils.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/status_matchers.h" @@ -139,7 +140,8 @@ TEST_P(ReshardToTensorTest, MakeHostTensorFromDeviceArrays) { TF_ASSERT_OK_AND_ASSIGN( auto output_tensor, MakeTensorFromArray(*client, *assembled_array, GetParam().sharding, - device_list, thread_pool)); + device_list, thread_pool) + .Await()); EXPECT_THAT(GetParam().expected_out_tensor, TensorEq(output_tensor)); } diff --git a/tensorflow/core/tfrt/ifrt/tf_host_callback.cc b/tensorflow/core/tfrt/ifrt/tf_host_callback.cc index 5c5a48f4fc52b4..8beeddf82a92e3 100644 --- a/tensorflow/core/tfrt/ifrt/tf_host_callback.cc +++ b/tensorflow/core/tfrt/ifrt/tf_host_callback.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/device_factory.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/protobuf/config.pb.h" diff --git a/tensorflow/core/tfrt/ifrt/tf_host_callback.h b/tensorflow/core/tfrt/ifrt/tf_host_callback.h index a78b0e5d0aecea..5b73221e6d3afa 100644 --- a/tensorflow/core/tfrt/ifrt/tf_host_callback.h +++ b/tensorflow/core/tfrt/ifrt/tf_host_callback.h @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/protobuf/config.pb.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/ifrt/tf_host_callback_test.cc b/tensorflow/core/tfrt/ifrt/tf_host_callback_test.cc index bc67bbae34d94a..17240e361881c8 100644 --- a/tensorflow/core/tfrt/ifrt/tf_host_callback_test.cc +++ b/tensorflow/core/tfrt/ifrt/tf_host_callback_test.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/core/tfrt/kernels/BUILD b/tensorflow/core/tfrt/kernels/BUILD index 9716f4bf5cae14..817b8485dd44d6 100644 --- a/tensorflow/core/tfrt/kernels/BUILD +++ b/tensorflow/core/tfrt/kernels/BUILD @@ -62,7 +62,6 @@ tf_cc_test( "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_xla//xla/pjrt/cpu:cpu_client", @@ -72,6 +71,7 @@ tf_cc_test( "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", "@local_xla//xla/tsl/framework:serving_device_selector", "@local_xla//xla/tsl/framework/test_util:mock_serving_device_selector", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc index 3ae4d09bc054d3..cd29511d4d982f 100644 --- a/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc +++ b/tensorflow/core/tfrt/kernels/ifrt_program_ops_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/tsl/framework/serving_device_selector.h" #include "xla/tsl/framework/test_util/mock_serving_device_selector.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tensorflow/core/platform/test.h" #include "tensorflow/core/tfrt/ifrt/ifrt_executable_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_executable_test_util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/core/tfrt/mlrt/interpreter/BUILD b/tensorflow/core/tfrt/mlrt/interpreter/BUILD index 552959b1ce0c5e..10b5346a49553e 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/BUILD +++ b/tensorflow/core/tfrt/mlrt/interpreter/BUILD @@ -127,6 +127,8 @@ cc_library( ":future", ":value", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_xla//xla/tsl/concurrency:async_value", "@tf_runtime//:async_value", ], @@ -189,9 +191,9 @@ tf_cc_test( "@com_google_absl//absl/types:span", "@com_google_benchmark//:benchmark", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test_benchmark", + "@local_xla//xla/tsl/lib/core:status_test_util", "@tf_runtime//:hostcontext", ], ) diff --git a/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h b/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h index ceef6679b6fa8a..43d43422e60093 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h +++ b/tensorflow/core/tfrt/mlrt/interpreter/async_handle.h @@ -19,6 +19,8 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/mlrt/interpreter/future.h" #include "tensorflow/core/tfrt/mlrt/interpreter/value.h" @@ -141,6 +143,12 @@ class AsyncHandle { } auto& execution_context = *arg->Get(); + execution_context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: unwind AsyncHandle of context ", + absl::Hex(reinterpret_cast(execution_context_.get())), + " from context ", + absl::Hex(reinterpret_cast(&execution_context)), + " of state ", execution_context.state_))); execution_context.Await(std::move(*this)); } diff --git a/tensorflow/core/tfrt/mlrt/interpreter/execute.cc b/tensorflow/core/tfrt/mlrt/interpreter/execute.cc index f3ef9bc2822085..635935911aa221 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/execute.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/execute.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tfrt/mlrt/interpreter/execute.h" +#include #include #include @@ -178,7 +179,10 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { function_name = context.function_stack_.back().function_object().name(); } context.LogError(absl::InternalError(absl::StrCat( - "Start UnwindOnError from function ", function_name, " at pc: ", pc))); + "UnwindOnError: start from function ", function_name, + " with stack size: ", context.function_stack_.size(), " at pc: ", pc, + " for context ", absl::Hex(reinterpret_cast(&context)), + " at state ", context.state_))); while (!context.function_stack_.empty()) { DCHECK(context.state_ == ExecutionContext::State::kError); @@ -199,6 +203,11 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { reg.HandleError(context_value); if (context.state_ != ExecutionContext::State::kError) { DCHECK(context.state_ == ExecutionContext::State::kSuspended); + + context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: entering state", context.state_, " for context ", + absl::Hex(reinterpret_cast(&context))))); + // Rewind current pc so that the execution context come back to where // is is suspended. --pc; @@ -207,6 +216,12 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { } } + context.LogError(absl::InternalError( + absl::StrCat("UnwindOnError: unwinding function from ", pc, " to ", + current_function->pc_, " for context ", + absl::Hex(reinterpret_cast(&context)), + " at state ", context.state_))); + for (; context.state_ == ExecutionContext::State::kError && pc <= current_function->pc_; ++pc) { @@ -218,6 +233,10 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { reg.HandleError(context_value); if (context.state_ != ExecutionContext::State::kError) { DCHECK(context.state_ == ExecutionContext::State::kSuspended); + context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: entering state", context.state_, " for context ", + absl::Hex(reinterpret_cast(&context))))); + // Rewind current pc so that the execution context come back to where // is is suspended. --pc; @@ -230,6 +249,9 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { DCHECK(context.suspend_handler_) << "suspend_handler_ must be populated when the state is set to " "kSuspended."; + context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: suspended state ", context.state_, " for context ", + absl::Hex(reinterpret_cast(&context))))); std::move(context.suspend_handler_)([&context, pc]() { auto* work_queue = context.work_queue(); DCHECK(work_queue); @@ -247,8 +269,10 @@ void UnwindOnError(ExecutionContext& context, int64_t pc) { context.function_stack_.pop_back(); } - context.LogError(absl::InternalError( - absl::StrCat("Finish UnwindOnError for function ", function_name))); + context.LogError(absl::InternalError(absl::StrCat( + "UnwindOnError: done for function ", function_name, + " for context: ", absl::Hex(reinterpret_cast(&context)), + " at state ", context.state_))); // Context may no longer be valid after exit_handler_ is called. if (context.exit_handler_) { diff --git a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc index 0434f3c883a06e..97982e77e8c791 100644 --- a/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc +++ b/tensorflow/core/tfrt/mlrt/interpreter/interpreter_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/synchronization/notification.h" #include "absl/types/span.h" #include "benchmark/benchmark.h" // from @com_google_benchmark +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/async_handle.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h" #include "tensorflow/core/tfrt/mlrt/interpreter/register_span.h" #include "tensorflow/core/tfrt/mlrt/interpreter/value.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test_benchmark.h" #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/mlrt/kernel/BUILD b/tensorflow/core/tfrt/mlrt/kernel/BUILD index 9da1e45e6886d5..5e377f8f809153 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/BUILD +++ b/tensorflow/core/tfrt/mlrt/kernel/BUILD @@ -10,6 +10,7 @@ package( # copybara:uncomment "//learning/brain/tfrt:__subpackages__", # copybara:uncomment "//learning/serving/servables/tfrt:__subpackages__", "//tensorflow/core/tfrt/graph_executor:__subpackages__", + "//tensorflow/core/tfrt/ifrt:__subpackages__", "//tensorflow/core/tfrt/saved_model:__subpackages__", "//tensorflow/core/tfrt/tfrt_session:__subpackages__", ], @@ -67,19 +68,16 @@ cc_library( deps = [ ":context", ":kernel", - ":kernel_runner_utils", - ":shard_restore_util", - "//tensorflow/compiler/mlir/tfrt/transforms/ifrt:ifrt_types", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", - "//tensorflow/core/common_runtime:function", "//tensorflow/core/framework:attr_value_proto_cc", "//tensorflow/core/framework:types_proto_cc", "//tensorflow/core/platform:protobuf", - "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "//tensorflow/core/tfrt/ifrt:checkpoint_loader", "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc", "//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_utils", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context", "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/interpreter:context", @@ -89,13 +87,10 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:tstring", "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/python/ifrt", - "@tf_runtime//:hostcontext", ], alwayslink = 1, ) @@ -178,8 +173,8 @@ tf_cc_shared_test( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", + "@local_xla//xla/tsl/lib/core:status_test_util", "@tf_runtime//:hostcontext", "@tf_runtime//:ref_count", ], @@ -210,9 +205,11 @@ tf_cc_shared_test( "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_compat_request_state", "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "//tensorflow/core/tfrt/ifrt:checkpoint_loader", "//tensorflow/core/tfrt/ifrt:ifrt_config_proto_cc", "//tensorflow/core/tfrt/ifrt:ifrt_loaded_variable_registry", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context", "//tensorflow/core/tfrt/ifrt:ifrt_restore_tensor_registry", "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector", "//tensorflow/core/tfrt/mlrt/bytecode", @@ -230,7 +227,7 @@ tf_cc_shared_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/core:status_test_util", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:refcount", "@local_tsl//tsl/platform:status", @@ -242,6 +239,7 @@ tf_cc_shared_test( "@local_xla//xla/python/pjrt_ifrt:tfrt_cpu_client_test_lib", "@local_xla//xla/tsl/framework:serving_device_selector", "@local_xla//xla/tsl/framework/test_util:mock_serving_device_selector", + "@local_xla//xla/tsl/lib/core:status_test_util", "@tf_runtime//:hostcontext", ], ) diff --git a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc index 9d30f31a124ae7..5a584883452a2e 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/batch_kernel.cc @@ -234,7 +234,9 @@ class MlrtBatchResource : public tensorflow::serving::BatchResourceBase { options.num_batch_threads, options.max_batch_size, options.batch_timeout_micros, options.max_enqueued_batches, options.allowed_batch_sizes, enable_large_batch_splitting, - disable_padding, options.low_priority_max_batch_size, + disable_padding, + /* batch_padding_policy= */ options.batch_padding_policy, + options.low_priority_max_batch_size, options.low_priority_batch_timeout_micros, options.low_priority_max_enqueued_batches, options.low_priority_allowed_batch_sizes, diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc index ca9dd2271335fb..e5c7dbd1dc0c72 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel.cc @@ -25,39 +25,31 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "tensorflow/compiler/mlir/tfrt/transforms/ifrt/ifrt_types.h" #include "xla/python/ifrt/future.h" #include "xla/xla_data.pb.h" -#include "tensorflow/core/common_runtime/process_function_library_runtime.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_base.h" -#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/protobuf.h" // IWYU pragma: keep -#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_loaded_variable_utils.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" #include "tensorflow/core/tfrt/mlrt/interpreter/future.h" #include "tensorflow/core/tfrt/mlrt/kernel/context.h" #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h" -#include "tensorflow/core/tfrt/mlrt/kernel/kernel_runner_utils.h" -#include "tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" #include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" #include "tsl/platform/tstring.h" -#include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime using tensorflow::ifrt_serving::IfrtModelContext; @@ -65,14 +57,6 @@ namespace tensorflow { namespace tf_mlrt { namespace { -int64_t GetSizeFromVarHandle(const ResourceHandle& handle) { - int size = 0; - for (auto& dtype_and_shape : handle.dtypes_and_shapes()) { - size += DataTypeSize(dtype_and_shape.dtype) * - dtype_and_shape.shape.num_elements(); - } - return size; -} struct MlrtIfrtRestoreVariableKernel : mlrt::KernelFrame { using KernelFrame::KernelFrame; @@ -119,20 +103,8 @@ struct MlrtIfrtRestoreVariableKernel : mlrt::KernelFrame { // dynamically decide it based on the size of the variables. static constexpr int kNumRestoreClusters = 4; - // A shard of variables to be restored. - struct RestoreVariableShard { - tensorflow::Tensor prefix; - tensorflow::Tensor tensor_names; - tensorflow::Tensor shape_and_slices; - std::vector var_handles; - tensorflow::AttrValue dtypes_attr_value; - std::vector restored_dtypes; - std::vector truncate_in_cast; - }; - absl::Status InvokeHelper(); - absl::Status RunShard(RestoreVariableShard shard); absl::Status ValidateInput(); }; @@ -144,218 +116,6 @@ void MlrtIfrtRestoreVariableKernel::Invoke() { } } -// Returns a casted tensor if successful. -absl::StatusOr Cast( - tensorflow::Tensor& in_tensor, tensorflow::DataType restored_dtype, - tensorflow::DataType cast_dtype, bool truncate_in_cast, - const tensorflow::DeviceMgr& device_manager, - const tensorflow::ProcessFunctionLibraryRuntime& - process_function_library_runtime, - OpKernelContext::Params& params) { - auto runner = - tfrt_stub::OpKernelRunner::Create( - /*op_name=*/ - "Cast", /*node_name=*/"Cast", params.device->name(), - /*num_args=*/1, - [&](tensorflow::AttrValueMap* attr_value_map) { - tensorflow::AttrValue restored_dtype_attr_value; - restored_dtype_attr_value.set_type(restored_dtype); - attr_value_map->insert({"SrcT", restored_dtype_attr_value}); - - tensorflow::AttrValue cast_dtype_attr_value; - cast_dtype_attr_value.set_type(cast_dtype); - attr_value_map->insert({"DstT", cast_dtype_attr_value}); - - tensorflow::AttrValue truncate_attr_value; - truncate_attr_value.set_b(truncate_in_cast); - attr_value_map->insert({"Truncate", truncate_attr_value}); - return absl::OkStatus(); - }, - device_manager, process_function_library_runtime) - .value(); - - std::vector input_tf_tensor_values; - input_tf_tensor_values.push_back(tensorflow::TensorValue(&in_tensor)); - - SetUpParams(runner, input_tf_tensor_values, params); - // Use persistent device instead of the per request device. - - OpKernelContext op_kernel_context(¶ms, /*num_outputs=*/1); - - runner.Run(&op_kernel_context); - - if (!op_kernel_context.status().ok()) { - return op_kernel_context.status(); - } - DCHECK_EQ(op_kernel_context.num_outputs(), 1); - return *(op_kernel_context.mutable_output(0)); -} - -absl::Status MlrtIfrtRestoreVariableKernel::RunShard( - RestoreVariableShard shard) { - std::optional ifrt_model_context = - context().resource_context().GetResource( - "IfrtModelContext"); - if (!ifrt_model_context.has_value()) { - return absl::FailedPreconditionError( - "RestoreVariableOp: failed to fetch IfrtModelContext"); - } - const int num_outputs = shard.var_handles.size(); - DCHECK_EQ(num_outputs, shard.tensor_names.NumElements()); - auto& fallback_request_state = context().fallback_request_state(); - - // Use `tf.RestoreV2` to restore tensor. This will also populate - // tensorflow::ResourceManager. - // TODO(b/319045348): avoid populating tensorflow::ResourceManager if the - // variable is only used by device/IFRT. - // TODO(b/319045348): consider directly calling restore function such as that - // in /tensorflow/core/kernels/save_restore_v2_ops.cc - auto runner = - tfrt_stub::OpKernelRunner::Create( - /*op_name=*/ - "RestoreV2", /*node_name=*/"RestoreV2", - context().params().device->name(), - /*num_args=*/3, - [&](tensorflow::AttrValueMap* attr_value_map) { - attr_value_map->insert({"dtypes", shard.dtypes_attr_value}); - return absl::OkStatus(); - }, - fallback_request_state.device_manager(), - fallback_request_state.process_function_library_runtime()) - .value(); - - // Prepare the input tensors. - std::vector input_tf_tensor_values; - static constexpr int kNumInputArgs = 3; - input_tf_tensor_values.resize(kNumInputArgs); - // We need to keep these tensor alive - input_tf_tensor_values[0].tensor = &shard.prefix; - input_tf_tensor_values[1].tensor = &shard.tensor_names; - input_tf_tensor_values[2].tensor = &shard.shape_and_slices; - - auto& params = context().params(); - SetUpParams(runner, input_tf_tensor_values, params); - // Use persistent device instead of the per request device. - params.device = context().fallback_request_state().device_manager().HostCPU(); - - struct AsyncState { - explicit AsyncState( - const std::vector& input_tf_tensor_values, - const OpKernelContext::Params& params, int num_outputs, - const tensorflow::DeviceMgr& device_manager, - const tensorflow::ProcessFunctionLibraryRuntime& - process_function_library_runtime) - : run_state(input_tf_tensor_values, params), - context(&run_state.params, num_outputs), - device_manager(device_manager), - process_function_library_runtime(process_function_library_runtime) {} - - tfrt_stub::OpKernelRunState run_state; - OpKernelContext context; - const tensorflow::DeviceMgr& device_manager; - const tensorflow::ProcessFunctionLibraryRuntime& - process_function_library_runtime; - - std::vector> results; - }; - auto async_state = std::make_unique( - input_tf_tensor_values, params, num_outputs, - fallback_request_state.device_manager(), - fallback_request_state.process_function_library_runtime()); - - ifrt_serving::IfrtRestoreTensorRegistry& ifrt_restore_tensor_registry = - (*ifrt_model_context)->GetRestoreTensorRegistry(); - for (int i = 0; i < num_outputs; ++i) { - auto promise = xla::ifrt::Future::CreatePromise(); - auto future = xla::ifrt::Future(promise); - const ResourceHandle& var_handle = - shard.var_handles[i].tensor().scalar()(); - - TF_ASSIGN_OR_RETURN(ifrt_serving::DtypeAndShape dtype_and_shape, - ifrt_serving::GetDtypeAndShape(var_handle)); - - std::string runtime_name = - ifrt_serving::GetRuntimeNameFromVarHandle(var_handle); - - ifrt_serving::IfrtRestoreTensorRegistry::RestoredTensorInfo - restored_tensor_info = {false, std::move(dtype_and_shape), - std::move(future)}; - if (auto status = ifrt_restore_tensor_registry.TryRegister( - runtime_name, restored_tensor_info); - !status.ok()) { - // Propagate errors so that if already-registered futures are being waited - // on, they can be unblocked. - for (auto& result : async_state->results) { - std::move(result).Set(status); - }; - return status; - } - async_state->results.push_back(std::move(promise)); - } - - // Use dedicated work queue for restore operation. - DCHECK((*ifrt_model_context)->checkpoint_loader_queue() != nullptr); - (*ifrt_model_context) - ->checkpoint_loader_queue() - ->AddTask([runner = std::move(runner), - async_state = std::move(async_state), - shard = std::move(shard)]() { - // Keep input tensor alive in `shard`. - auto* op_kernel_context_ptr = &async_state->context; - runner.Run(op_kernel_context_ptr); - - auto& op_kernel_context = async_state->context; - if (!op_kernel_context.status().ok()) { - for (auto& result : async_state->results) { - std::move(result).Set(op_kernel_context.status()); - } - return; - } - DCHECK_EQ(shard.var_handles.size(), op_kernel_context.num_outputs()); - DCHECK_EQ(shard.truncate_in_cast.size(), - op_kernel_context.num_outputs()); - - // TODO(b/343964091): consider to run multiple casts in parallel. - for (int i = 0; i < op_kernel_context.num_outputs(); ++i) { - DCHECK(op_kernel_context.mutable_output(i)); - - if (op_kernel_context.mutable_output(i)->dtype() != - shard.restored_dtypes[i]) { - std::move(async_state->results[i]) - .Set(absl::InvalidArgumentError(absl::StrCat( - "The restored tensor has a different dtype than the " - "variable handle: ", - op_kernel_context.mutable_output(i)->dtype(), " vs. ", - shard.restored_dtypes[i]))); - return; - } - const ResourceHandle& var_handle = - shard.var_handles[i] - .tensor() - .scalar()(); - - if (shard.restored_dtypes[i] == - var_handle.dtypes_and_shapes()[0].dtype) { - std::move(async_state->results[i]) - .Set(*std::move(op_kernel_context.mutable_output(i))); - } else { - absl::StatusOr cast_output = Cast( - *op_kernel_context.mutable_output(i), shard.restored_dtypes[i], - var_handle.dtypes_and_shapes()[0].dtype, - shard.truncate_in_cast[i], async_state->device_manager, - async_state->process_function_library_runtime, - async_state->run_state.params); - if (!cast_output.ok()) { - std::move(async_state->results[i]).Set(cast_output.status()); - } else { - std::move(async_state->results[i]).Set(*std::move(cast_output)); - } - } - } - }); - return absl::OkStatus(); -} - absl::Status MlrtIfrtRestoreVariableKernel::ValidateInput() { if (prefix().tensor().NumElements() != 1) { return absl::InvalidArgumentError( @@ -398,65 +158,26 @@ absl::Status MlrtIfrtRestoreVariableKernel::ValidateInput() { } absl::Status MlrtIfrtRestoreVariableKernel::InvokeHelper() { - TF_RETURN_IF_ERROR(ValidateInput()); - - std::vector variable_sizes; - variable_sizes.reserve(var_handles().size()); - for (auto& handle : var_handles()) { - variable_sizes.push_back(GetSizeFromVarHandle( - handle.tensor().scalar()())); + std::optional model_restore_context = + context() + .resource_context() + .GetResource( + ifrt_serving::kIfrtModelRestoreContextName); + if (!model_restore_context.has_value()) { + return absl::InternalError( + "Did not find IfrtModelRestoreContext resource."); } - - std::vector> sharded_indices = - ShardVariables(kNumRestoreClusters, absl::MakeSpan(variable_sizes)); - - // Converts the names and slices back to the tensor. - auto vector_to_tensor = [](const std::vector& vec) { - tensorflow::Tensor tensor(tensorflow::DT_STRING, - TensorShape({static_cast(vec.size())})); - for (int i = 0; i < vec.size(); ++i) { - tensor.flat()(i) = vec[i]; - } - return tensor; - }; - - const auto& tensor_names_flat = tensor_names().tensor().flat(); - const auto& shape_and_slices_flat = - shape_and_slices().tensor().flat(); - - std::vector shards; - shards.reserve(sharded_indices.size()); - for (auto& sharded_index : sharded_indices) { - RestoreVariableShard shard; - shard.var_handles.reserve(sharded_index.size()); - shard.truncate_in_cast.reserve(sharded_index.size()); - shard.restored_dtypes.reserve(sharded_index.size()); - - std::vector tensor_names; - std::vector shape_and_slices; - shape_and_slices.reserve(sharded_index.size()); - tensor_names.reserve(sharded_index.size()); - for (int index : sharded_index) { - tensor_names.push_back(tensor_names_flat(index)); - shape_and_slices.push_back(shape_and_slices_flat(index)); - shard.dtypes_attr_value.mutable_list()->add_type( - restored_dtypes()[index]); - - shard.var_handles.push_back(var_handles()[index]); - shard.restored_dtypes.push_back(restored_dtypes()[index]); - shard.truncate_in_cast.push_back(truncate_in_cast()[index]); - } - - shard.prefix = prefix().tensor(); - shard.tensor_names = vector_to_tensor(tensor_names); - shard.shape_and_slices = vector_to_tensor(shape_and_slices); - shards.push_back(std::move(shard)); + if (*model_restore_context == nullptr) { + return absl::InternalError("IfrtModelRestoreContext must not be null."); } - - for (const auto& shard : shards) { - TF_RETURN_IF_ERROR(RunShard(shard)); + ifrt_serving::CheckpointLoader* checkpoint_loader = + (*model_restore_context)->checkpoint_loader(); + if (!checkpoint_loader) { + return absl::InternalError("CheckpointLoader must not be null."); } - return absl::OkStatus(); + return checkpoint_loader->Load(prefix(), var_handles(), tensor_names(), + shape_and_slices(), restored_dtypes(), + truncate_in_cast(), context()); } class MlrtIfrtLoadVariableKernel : public mlrt::KernelFrame { diff --git a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc index 83b5876cdae788..07fb83b1e6eb32 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/ifrt_ops_kernel_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/python/ifrt/future.h" #include "xla/python/ifrt/test_util.h" #include "xla/tsl/framework/test_util/mock_serving_device_selector.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/resource_var.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_matcher.h" @@ -43,8 +44,10 @@ limitations under the License. #include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_compat_request_state.h" #include "tensorflow/core/tfrt/fallback/fallback_state.h" #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" #include "tensorflow/core/tfrt/ifrt/ifrt_config.pb.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h" #include "tensorflow/core/tfrt/ifrt/ifrt_restore_tensor_registry.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" @@ -57,7 +60,6 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/kernel/context.h" #include "tensorflow/core/tfrt/mlrt/kernel/kernel.h" #include "tensorflow/core/tfrt/utils/fallback_tensor.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/refcount.h" #include "tsl/platform/status.h" @@ -403,6 +405,13 @@ class KernelTest : public ::testing::Test { .value(); ifrt_model_context_->set_checkpoint_loader_queue(restore_work_queue_.get()); + resource_context_ + .CreateResource( + ifrt_serving::kIfrtModelRestoreContextName, + std::make_unique( + &ifrt_model_context_->GetRestoreTensorRegistry(), + ifrt_model_context_->checkpoint_loader_queue())); + serving_device_selector_ = std::make_unique(); ifrt_core_selector_ = diff --git a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc index f9966959ea517d..a9aa89a0aebad5 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/kernel_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/strings/substitute.h" #include "absl/synchronization/notification.h" #include "absl/types/span.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/device.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/tfrt/fallback/device_with_custom_allocator.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/core/tfrt/mlrt/interpreter/interpreter_testutil.h" #include "tensorflow/core/tfrt/mlrt/kernel/batch_kernel.h" #include "tensorflow/core/tfrt/mlrt/kernel/context.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tfrt/concurrency/ref_count.h" // from @tf_runtime #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc b/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc index cd3f49f3d6b37c..16293c2ed5d4bb 100644 --- a/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc +++ b/tensorflow/core/tfrt/mlrt/kernel/shard_restore_util.cc @@ -66,7 +66,7 @@ std::vector> ShardVariables( }; std::priority_queue, decltype(cmp)> - min_heap; + min_heap(cmp); for (int i = 0; i < num_shards; ++i) { min_heap.push(RestoreVariableCluster()); } diff --git a/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.cc b/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.cc index 092e6c2697ea28..d887731e46d2e8 100644 --- a/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.cc +++ b/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_util.cc @@ -78,7 +78,7 @@ std::vector ParamFromEnvWithDefault(const char* var_name, bool ParamFromEnvBoolWithDefault(const char* var_name, bool default_value) { const char* val = std::getenv(var_name); - return (val) ? tensorflow::str_util::Lowercase(val) == "true" : default_value; + return (val) ? absl::AsciiStrToLower(val) == "true" : default_value; } } // namespace tf diff --git a/tensorflow/core/tfrt/saved_model/BUILD b/tensorflow/core/tfrt/saved_model/BUILD index 9e85c14baef362..5261546e6c6a0f 100644 --- a/tensorflow/core/tfrt/saved_model/BUILD +++ b/tensorflow/core/tfrt/saved_model/BUILD @@ -135,6 +135,8 @@ cc_library( "//tensorflow/core/tfrt/graph_executor", "//tensorflow/core/tfrt/graph_executor:export_mlir", "//tensorflow/core/tfrt/graph_executor:graph_execution_options", + "//tensorflow/core/tfrt/ifrt:checkpoint_loader", + "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context", "//tensorflow/core/tfrt/mlrt/bytecode", "//tensorflow/core/tfrt/mlrt/bytecode:executable", "//tensorflow/core/tfrt/mlrt/interpreter:context", diff --git a/tensorflow/core/tfrt/saved_model/saved_model.cc b/tensorflow/core/tfrt/saved_model/saved_model.cc index 62ad6550cd6c6d..84fbbff7401340 100644 --- a/tensorflow/core/tfrt/saved_model/saved_model.cc +++ b/tensorflow/core/tfrt/saved_model/saved_model.cc @@ -16,13 +16,11 @@ limitations under the License. #include #include -#include #include #include #include #include #include -#include #include #include @@ -70,6 +68,8 @@ limitations under the License. #include "tensorflow/core/tfrt/graph_executor/export_mlir.h" #include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h" #include "tensorflow/core/tfrt/graph_executor/graph_executor.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h" #include "tensorflow/core/tfrt/mlrt/bytecode/bytecode.h" #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/context.h" @@ -134,6 +134,34 @@ auto* saved_model_input_spec_validation_failure = "/tensorflow/tfrt/saved_model/input_spec_validation_failure", "Record the models that failed input spec validation.", "model_name"); +absl::Status PrepareRestore(mlir::MLIRContext* context, + ModelRuntimeContext* model_runtime_context, + const tensorflow::MetaGraphDef& meta_graph_def, + FallbackState& fallback_state, + const std::string& saved_model_dir, + const SavedModel::Options& options, + ifrt_serving::CheckpointLoader* checkpoint_loader) { + // Import the global MLIR with `import_user_signatures` as true so that we can + // analysis the global MLIR to retrieve data needed for restore. + mlir::OwningOpRef mlir_module_restore_analysis; + ASSIGN_OR_RETURN_IN_IMPORT( + mlir_module_restore_analysis, + ImportSavedModel( + context, meta_graph_def, fallback_state, saved_model_dir, + /*import_user_signatures=*/true, + options.graph_execution_options.run_placer_grappler_on_functions)); + + if (!checkpoint_loader) { + return absl::InternalError("Missing checkpoint loader."); + } + + TF_RETURN_IF_ERROR(checkpoint_loader->PrepareRestore( + std::move(mlir_module_restore_analysis))); + + LOG(INFO) << "Complete set restore metadata."; + return absl::OkStatus(); +} + tensorflow::Status RunBytecodeInitializers( const GraphExecutionOptions& options, const InitializersAndSignatures& initializers_and_signatures, @@ -596,6 +624,25 @@ absl::StatusOr> SavedModelImpl::LoadSavedModel( model_context.set_callable_options(nullptr); } + if (options.graph_execution_options.use_ifrt) { + std::optional + model_restore_context = + model_context.resource_context() + .GetResource( + ifrt_serving::kIfrtModelRestoreContextName); + if (!model_restore_context.has_value()) { + return absl::InternalError( + "Did not find IfrtModelRestoreContext resource."); + } + if (*model_restore_context == nullptr) { + return absl::InternalError("IfrtModelRestoreContexts must not be null."); + } + TF_RETURN_IF_ERROR( + PrepareRestore(&context, &model_context, meta_graph_def, + *fallback_state, std::string(saved_model_dir), options, + (*model_restore_context)->checkpoint_loader())); + } + GetDefaultInputValue(meta_graph_def.signature_def(), model_context, initializers_and_signatures.signature_map); diff --git a/tensorflow/core/tfrt/saved_model/tests/BUILD b/tensorflow/core/tfrt/saved_model/tests/BUILD index 3dfc07d245eaf1..c026800861ff2a 100644 --- a/tensorflow/core/tfrt/saved_model/tests/BUILD +++ b/tensorflow/core/tfrt/saved_model/tests/BUILD @@ -649,7 +649,9 @@ cc_library( "//tensorflow/core/platform:resource_loader", "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink", "//tensorflow/core/tfrt:ifrt_program_ops_op_lib", + "//tensorflow/core/tfrt/ifrt:checkpoint_loader", "//tensorflow/core/tfrt/ifrt:ifrt_model_context", + "//tensorflow/core/tfrt/ifrt:ifrt_model_restore_context", "//tensorflow/core/tfrt/ifrt:ifrt_serving_core_selector", "//tensorflow/core/tfrt/mlrt/kernel:ifrt_ops_kernel", "//tensorflow/core/tfrt/runtime", diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc index 8eaf2252618f9a..4f4caf0e028b52 100644 --- a/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc +++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_ifrt_test.cc @@ -26,14 +26,16 @@ limitations under the License. #include "xla/python/ifrt/client.h" #include "xla/python/ifrt/test_util.h" #include "xla/tsl/framework/test_util/mock_serving_device_selector.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/resource_loader.h" +#include "tensorflow/core/tfrt/ifrt/checkpoint_loader.h" #include "tensorflow/core/tfrt/ifrt/ifrt_model_context.h" +#include "tensorflow/core/tfrt/ifrt/ifrt_model_restore_context.h" #include "tensorflow/core/tfrt/ifrt/ifrt_serving_core_selector.h" #include "tensorflow/core/tfrt/runtime/runtime.h" #include "tensorflow/core/tfrt/saved_model/saved_model.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" @@ -77,10 +79,17 @@ TEST(SavedModelIfrt, Basic) { "IfrtModelContext", client, &core_selector, &GetThreadPool(), /*compilation_environment_proto=*/nullptr); - (*model_context.resource_context() - .GetResource( - "IfrtModelContext")) - ->set_checkpoint_loader_queue(work_queue.get()); + tensorflow::ifrt_serving::IfrtModelContext* ifrt_model_context = + (*model_context.resource_context() + .GetResource( + "IfrtModelContext")); + ifrt_model_context->set_checkpoint_loader_queue(work_queue.get()); + model_context.resource_context() + .CreateResource( + ifrt_serving::kIfrtModelRestoreContextName, + std::make_unique( + &ifrt_model_context->GetRestoreTensorRegistry(), + ifrt_model_context->checkpoint_loader_queue())); return absl::OkStatus(); }); diff --git a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc index 183b41805d4fab..605e4413ffffc1 100644 --- a/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc +++ b/tensorflow/core/tfrt/saved_model/tests/saved_model_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/backend_compiler.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tensorflow/core/tfrt/runtime/work_queue_interface.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" #include "tensorflow/core/tfrt/saved_model/saved_model_util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tfrt/host_context/concurrent_work_queue.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/saved_model/utils/BUILD b/tensorflow/core/tfrt/saved_model/utils/BUILD index b76aa6fa0b8c62..008fea2a0daf9d 100644 --- a/tensorflow/core/tfrt/saved_model/utils/BUILD +++ b/tensorflow/core/tfrt/saved_model/utils/BUILD @@ -60,8 +60,8 @@ tf_cc_shared_test( "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", + "@local_xla//xla/tsl/lib/core:status_test_util", "@tf_runtime//:bef", ], ) diff --git a/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc b/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc index deaf171d9d1162..2fc9c28436507d 100644 --- a/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc +++ b/tensorflow/core/tfrt/saved_model/utils/serialize_utils_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tfrt/transforms/mlrt/import_model.h" #include "tensorflow/compiler/mlir/tfrt/translate/import_model.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/tfrt/fallback/fallback_state.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" #include "tensorflow/core/tfrt/saved_model/saved_model_util.h" #include "tensorflow/core/tfrt/utils/utils.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tfrt/bef/bef_buffer.h" // from @tf_runtime diff --git a/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc b/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc index 2b6326824e61f2..b63fc769e42b7c 100644 --- a/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc +++ b/tensorflow/core/tfrt/tfrt_session/tfrt_session_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/const_op.h" #include "tensorflow/cc/saved_model/reader.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" @@ -44,7 +45,6 @@ limitations under the License. #include "tensorflow/core/tfrt/runtime/runtime.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" #include "tensorflow/core/tfrt/utils/thread_pool.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" namespace tensorflow { diff --git a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc index 330added9a6099..f5af931d344dd6 100644 --- a/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc +++ b/tensorflow/core/tfrt/utils/debug/node_io_dump_rewriter_test.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/math_ops.h" #include "tensorflow/cc/saved_model/reader.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor.h" @@ -45,7 +46,6 @@ limitations under the License. #include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/tfrt/saved_model/saved_model.h" #include "tensorflow/core/tfrt/saved_model/saved_model_testutil.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc index 498d07a0e41e23..c18230e0b431dc 100644 --- a/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc +++ b/tensorflow/core/tfrt/utils/tfrt_graph_execution_state.cc @@ -250,9 +250,9 @@ TfrtGraphExecutionState::CreateOptimizedGraph( DumpGraphDefToFile("before_pruning", graph_def); } - TF_ASSIGN_OR_RETURN( - result.graph, - CreatePrunedGraph(graph_def, build_graph_options.callable_options)); + TF_ASSIGN_OR_RETURN(result.graph, + CreatePrunedGraph(std::move(graph_def), + build_graph_options.callable_options)); DCHECK(result.graph); if (VLOG_IS_ON(1)) { diff --git a/tensorflow/core/tpu/graph_rewrite/BUILD b/tensorflow/core/tpu/graph_rewrite/BUILD index 273c822fd74df4..73fbacd589160b 100644 --- a/tensorflow/core/tpu/graph_rewrite/BUILD +++ b/tensorflow/core/tpu/graph_rewrite/BUILD @@ -1,5 +1,9 @@ # Contains graph rewrites for TPU runtimes and optimizations. +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) load( "//tensorflow/core/platform:build_config_root.bzl", "if_static", @@ -119,6 +123,7 @@ cc_library( "//tensorflow/core:session_options", "//tensorflow/core/common_runtime:function_body", "//tensorflow/core/common_runtime:function_utils", + "//tensorflow/core/config:flag_defs", "//tensorflow/core/tpu:tpu_compile_interface", "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/container:flat_hash_map", @@ -131,6 +136,7 @@ cc_library( "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:status", "@local_xla//xla:status_macros", + "@local_xla//xla/tsl/util:env_var", ] + if_static( [ "//tensorflow/core/common_runtime:function", @@ -140,6 +146,26 @@ cc_library( ), ) +tf_cc_test( + name = "encapsulate_tpu_computations_pass_test", + srcs = ["encapsulate_tpu_computations_pass_test.cc"], + deps = [ + ":encapsulate_tpu_computations_pass", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", + "//tensorflow/core:lib_internal", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/common_runtime:optimization_registry", + "//tensorflow/core/config:flag_defs", + ], +) + cc_library( name = "distributed_tpu_rewrite_pass_internal", srcs = ["distributed_tpu_rewrite_pass_internal.cc"], diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc index 7b58c9e4c10f31..bca30520071c66 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_helpers.cc @@ -77,7 +77,7 @@ Status DistributedTPURewriteHelpers::GetSystemDevice( return errors::InvalidArgument( "System devices cannot be part " "of multiple different jobs. Found: ", - str_util::Join(job_names, ",")); + absl::StrJoin(job_names, ",")); } // Identify the lexicographically first device from the list of diff --git a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc index 9e63af5b0fb4ad..a13b3caba2fc17 100644 --- a/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/distributed_tpu_rewrite_pass.cc @@ -573,7 +573,7 @@ Status FindHostComputeKeyPlaceholderNodes( for (Node* node : graph->op_nodes()) { if (node->type_string() == "Placeholder" && - str_util::EndsWith(node->name(), "_key_placeholder")) { + absl::EndsWith(node->name(), "_key_placeholder")) { const AttrValue* call_node_attr = node->attrs().Find("_host_compute_call_node"); if (call_node_attr != nullptr) { diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc index dc2fe786101e29..9370cd6b01ab1c 100644 --- a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.cc @@ -48,6 +48,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function_utils.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/config/flag_defs.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph_to_functiondef.h" @@ -151,8 +152,8 @@ Status RewriteSubgraph(const std::vector& arg_source_tensors, retvals.push_back(n); } else if (n->type_string() == "TPUReplicateMetadata") { metadata_node = n; - } else if (!str_util::StrContains(n->requested_device(), - DEVICE_TPU_REPLICATED_CORE)) { + } else if (!absl::StrContains(n->requested_device(), + DEVICE_TPU_REPLICATED_CORE)) { // If an operator isn't assigned to a TPU core device, assign it to // TPU_REPLICATED_CORE without a specific core ID. For some operators, // such as variable reads/writes, the operator may be assigned to non-TPU @@ -2481,10 +2482,32 @@ Status LiftOutsideCompilationOnlyArgs(Graph* g, FunctionLibraryRuntime* flr, return absl::OkStatus(); } +// TODO(b/355263902): Encapsulation fails for some non-TPU graphs that are +// missing full variable shape information. Remove this path once the +// underlying issue is fixed. +bool ShouldSkipEncapsulationForNonTPUGraph() { + return flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.value(); +} + } // namespace /*static*/ Status EncapsulateTPUComputationsPass::Encapsulate( std::unique_ptr* graph, FunctionLibraryDefinition* flib_def) { + // If the graph does not contain any TPU computations, there is nothing to do. + if (ShouldSkipEncapsulationForNonTPUGraph()) { + bool found_tpu_replicate = false; + for (const Node* n : (*graph)->nodes()) { + if (n->attrs().Find(kTPUReplicateAttr) != nullptr) { + found_tpu_replicate = true; + break; + } + } + if (!found_tpu_replicate) { + VLOG(1) << "No TPU replicate found, skipping encapsulation"; + return absl::OkStatus(); + } + } + // Check for undeclared outputs before Encapsulation, so we can give a better // error message. // TODO(phawkins): merge this with the encapsulation code to avoid the extra diff --git a/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc new file mode 100644 index 00000000000000..a21cdaec4dbc72 --- /dev/null +++ b/tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass_test.cc @@ -0,0 +1,90 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/tpu/graph_rewrite/encapsulate_tpu_computations_pass.h" + +#include + +#include "xla/tsl/lib/core/status_test_util.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/config/flag_defs.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +std::unique_ptr CreateGraph() { + // c = a + b + auto g = std::make_unique(OpRegistry::Global()); + auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT); + auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT); + auto tmp = test::graph::Add(g.get(), in0, in1); + auto ret = test::graph::Retval(g.get(), 0, tmp); + g->AddControlEdge(in1, ret); + FixupSourceAndSinkEdges(g.get()); + return g; +} + +TEST(EncapsulateTPUComputationsPassTest, NonTPUGraph) { + auto g = CreateGraph(); + GraphOptimizationPassOptions options; + options.graph = &g; + options.flib_def = g->mutable_flib_def(); + + EncapsulateTPUComputationsPass pass; + TF_ASSERT_OK(pass.Run(options)); + + int nodes_meeting_expectations = 0; + + for (const auto* node : g->nodes()) { + if (!IsSource(node) && !IsSink(node)) { + ASSERT_TRUE(node->attrs().Find("_xla_inferred_shapes")); + ++nodes_meeting_expectations; + } + } + EXPECT_EQ(nodes_meeting_expectations, 4); +} + +TEST(EncapsulateTPUComputationsPassTest, SkipEncapsulationForNonTPUGraph) { + flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.reset(true); + auto g = CreateGraph(); + GraphOptimizationPassOptions options; + options.graph = &g; + options.flib_def = g->mutable_flib_def(); + + EncapsulateTPUComputationsPass pass; + TF_ASSERT_OK(pass.Run(options)); + + int nodes_meeting_expectations = 0; + + for (const auto* node : g->nodes()) { + if (!IsSource(node) && !IsSink(node)) { + ASSERT_FALSE(node->attrs().Find("_xla_inferred_shapes")); + ++nodes_meeting_expectations; + } + } + EXPECT_EQ(nodes_meeting_expectations, 4); + + flags::Global().enable_skip_encapsulation_for_non_tpu_graphs.reset(false); +} + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops.cc b/tensorflow/core/tpu/kernels/sharding_util_ops.cc index 4ca817b23ebc13..643f5bc588f334 100644 --- a/tensorflow/core/tpu/kernels/sharding_util_ops.cc +++ b/tensorflow/core/tpu/kernels/sharding_util_ops.cc @@ -178,7 +178,6 @@ class XlaSplitNDBaseOp : public XlaSplitNDShared { bool resource, OpKernelContext* ctx, const std::function& assign_or_copy_value_fn, const Tensor* input) { - const auto& input_shape = input->shape().dim_sizes(); absl::string_view input_name = resource ? kResourceName : kTensorName; auto allocate_output_fn = [&](int i, const TensorShape& output_slice_shape, diff --git a/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc b/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc index e62e9d71c04e6f..084802a89b5137 100644 --- a/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc +++ b/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/protobuf/error_codes.pb.h" diff --git a/tensorflow/core/tpu/kernels/sharding_utils_test.cc b/tensorflow/core/tpu/kernels/sharding_utils_test.cc index cd583df8a57bef..552a637a84b206 100644 --- a/tensorflow/core/tpu/kernels/sharding_utils_test.cc +++ b/tensorflow/core/tpu/kernels/sharding_utils_test.cc @@ -26,11 +26,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/status.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.cc b/tensorflow/core/tpu/kernels/sparse_core_layout.cc index 2f4a945be745cb..bc4c416f120ce2 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_layout.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_layout.cc @@ -72,10 +72,10 @@ SparseCoreLayoutStacker::SparseCoreLayoutStacker(int num_partitions, activation_mem_bytes_limit_(GetXlaSparseCoreStackingMemLimit()), variable_shard_bytes_limit_(GetXlaSparseCoreStackingTableShardLimit()) {} -absl::Status SparseCoreLayoutStacker::AddTable(tsl::StringPiece table_name, +absl::Status SparseCoreLayoutStacker::AddTable(absl::string_view table_name, int64_t table_height, int64_t table_width, - tsl::StringPiece group, + absl::string_view group, int64_t output_samples) { if (stacks_by_group_.empty()) { // First call? VLOG(1) << "Stacking parameters: stacking_enabled_ = " << stacking_enabled_ diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.h b/tensorflow/core/tpu/kernels/sparse_core_layout.h index c1d22f330c3882..9f4697c2b910c9 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_layout.h +++ b/tensorflow/core/tpu/kernels/sparse_core_layout.h @@ -84,8 +84,8 @@ class SparseCoreLayoutStacker { // // Be sure you call AddTable in a deterministic order; the details of the // stacking will depend on the order you call AddTable. - absl::Status AddTable(tsl::StringPiece table_name, int64_t table_height, - int64_t table_width, tsl::StringPiece group, + absl::Status AddTable(absl::string_view table_name, int64_t table_height, + int64_t table_width, absl::string_view group, int64_t output_samples); // Get the information about each table out. diff --git a/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc b/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc index ca2b1d5b353475..d854eab1513b75 100644 --- a/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc +++ b/tensorflow/core/tpu/tpu_embedding_configuration_proto_rewrite_test.cc @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status_matchers.h" #include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/test.h" diff --git a/tensorflow/core/tpu/tpu_embedding_errors_test.cc b/tensorflow/core/tpu/tpu_embedding_errors_test.cc index 3dbb182a97abaf..f0a8d869b797ef 100644 --- a/tensorflow/core/tpu/tpu_embedding_errors_test.cc +++ b/tensorflow/core/tpu/tpu_embedding_errors_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/errors.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/protobuf/error_codes.pb.h" diff --git a/tensorflow/core/util/autotune_maps/BUILD b/tensorflow/core/util/autotune_maps/BUILD index 8b89487f0b0d9b..990edbe549f3eb 100644 --- a/tensorflow/core/util/autotune_maps/BUILD +++ b/tensorflow/core/util/autotune_maps/BUILD @@ -52,8 +52,8 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/protobuf:dnn_proto_cc", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], ) @@ -118,7 +118,7 @@ tf_cuda_library( "conv_parameters.h", ], cuda_deps = [ - "@local_tsl//tsl/lib/strings:proto_serialization", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], deps = [ ":conv_parameters_proto_cc", @@ -182,12 +182,12 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core/platform:status", "//tensorflow/core/platform:str_util", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/protobuf:dnn_proto_cc", "@local_xla//xla:status_macros", "@local_xla//xla/stream_executor:dnn", "@local_xla//xla/stream_executor:platform_manager", "@local_xla//xla/stream_executor/gpu:gpu_init", + "@local_xla//xla/tsl/lib/strings:proto_serialization", ], ) diff --git a/tensorflow/core/util/autotune_maps/autotune_serialize.cc b/tensorflow/core/util/autotune_maps/autotune_serialize.cc index 63470c09df5f87..c601502a0d0512 100644 --- a/tensorflow/core/util/autotune_maps/autotune_serialize.cc +++ b/tensorflow/core/util/autotune_maps/autotune_serialize.cc @@ -25,13 +25,13 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/str_util.h" #include "tensorflow/core/util/activation_mode.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_autotune_maps.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/protobuf/dnn.pb.h" namespace tensorflow { diff --git a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc index baa68aae1131c1..0bd1122c132238 100644 --- a/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc +++ b/tensorflow/core/util/autotune_maps/conv_map_wrapper.cc @@ -19,9 +19,9 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tensorflow/core/util/autotune_maps/autotune_map.pb.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/protobuf/dnn.pb.h" namespace tensorflow { diff --git a/tensorflow/core/util/autotune_maps/conv_parameters.cc b/tensorflow/core/util/autotune_maps/conv_parameters.cc index 63436938980b68..a620e39c2b2afe 100644 --- a/tensorflow/core/util/autotune_maps/conv_parameters.cc +++ b/tensorflow/core/util/autotune_maps/conv_parameters.cc @@ -19,9 +19,9 @@ limitations under the License. #include #include "absl/strings/str_format.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/hash.h" #include "tensorflow/core/util/autotune_maps/conv_parameters.pb.h" -#include "tsl/lib/strings/proto_serialization.h" namespace tensorflow { diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h index 5c8a5dbfda4fcf..61d1fb5a19d538 100644 --- a/tensorflow/core/util/bcast.h +++ b/tensorflow/core/util/bcast.h @@ -199,7 +199,6 @@ BCastList::BCastList(const BCastList::Vec (&x)[N], prev_is_one[i] = false; current_is_one[i] = false; } - Vec output; bool output_dim_set = false; int64_t output_dim = -1; bool none_is_one = true; diff --git a/tensorflow/core/util/dump_graph_test.cc b/tensorflow/core/util/dump_graph_test.cc index d24eccf54d34e7..935ca41a7e9d26 100644 --- a/tensorflow/core/util/dump_graph_test.cc +++ b/tensorflow/core/util/dump_graph_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/util/dump_graph.h" #include "absl/strings/match.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -23,7 +24,6 @@ limitations under the License. #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace tensorflow { diff --git a/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc b/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc index fe27a6c5055cf3..ad28eebba468a6 100644 --- a/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc +++ b/tensorflow/core/util/quantization/uniform_quant_ops_params_test.cc @@ -16,9 +16,9 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/core/util/sparse/sparse_tensor.cc b/tensorflow/core/util/sparse/sparse_tensor.cc index 48ce3b5b13cee8..75dffd02fed286 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.cc +++ b/tensorflow/core/util/sparse/sparse_tensor.cc @@ -257,7 +257,7 @@ Status SparseTensor::IndicesValidHelper() const { if (!valid) { return errors::InvalidArgument(index, " is out of bounds: need 0 <= index < [", - str_util::Join(shape_, ","), "]"); + absl::StrJoin(shape_, ","), "]"); } if (!increasing) { return errors::InvalidArgument( diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h index 469502b3f63ce2..ec04070b90024b 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.h +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -447,8 +447,8 @@ inline SparseTensor SparseTensor::Concat( << "All SparseTensors' shapes must match except on the concat dim. " << "Concat dim: " << primary_dim << ", mismatched shape at dim: " << cdim - << ". Expecting shape like: [" << str_util::Join(final_shape, ",") - << "] but saw shape: [" << str_util::Join(st_shape, ",") << "]"; + << ". Expecting shape like: [" << absl::StrJoin(final_shape, ",") + << "] but saw shape: [" << absl::StrJoin(st_shape, ",") << "]"; } // Update dimension of final shape diff --git a/tensorflow/core/util/strided_slice_op_test.cc b/tensorflow/core/util/strided_slice_op_test.cc index cbe097683be662..6eb961c2f2a250 100644 --- a/tensorflow/core/util/strided_slice_op_test.cc +++ b/tensorflow/core/util/strided_slice_op_test.cc @@ -22,12 +22,12 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/test.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc index 6e9c20d0a39671..05f5d0f9636d04 100644 --- a/tensorflow/core/util/util.cc +++ b/tensorflow/core/util/util.cc @@ -151,10 +151,12 @@ bool IsDataTypeSupportedByOneDNNOnThisCPU(const DataType& dt) { } else if (dt == DT_HALF) { // Float16 is not supported in oneDNN v2.x #ifdef ENABLE_ONEDNN_V3 - result = (TestCPUFeature(port::CPUFeature::AVX512BW) && - (TestCPUFeature(port::CPUFeature::AVX512_FP16) || - TestCPUFeature(port::CPUFeature::AMX_FP16) || - TestCPUFeature(port::CPUFeature::AVX_NE_CONVERT))); + // Some CPUs that don't support AVX-512 use AVX-NE-CONVERT to cast to and + // from FP32 + result = ((TestCPUFeature(port::CPUFeature::AVX512BW) && + (TestCPUFeature(port::CPUFeature::AVX512_FP16) || + TestCPUFeature(port::CPUFeature::AMX_FP16))) || + TestCPUFeature(port::CPUFeature::AVX_NE_CONVERT)); if (result) VLOG(2) << "CPU supports " << DataType_Name(dt); #endif // ENABLE_ONEDNN_V3 } else { diff --git a/tensorflow/dtensor/cc/dtensor_device.cc b/tensorflow/dtensor/cc/dtensor_device.cc index 76abf34544f88f..6600a6b23ebd9d 100644 --- a/tensorflow/dtensor/cc/dtensor_device.cc +++ b/tensorflow/dtensor/cc/dtensor_device.cc @@ -1716,7 +1716,7 @@ void DTensorDevice::ModuleToExecutionFunctions( absl::flat_hash_set control_ret_nodes; GraphExportConfig export_config; RETURN_C_STATUS_IF_NOT_OK( - tensorflow::tf2xla::v2::ConvertMlirToGraph( + tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( *lowering_context.module, export_config, &(lowering_context.graph), flib_def, &control_ret_nodes), status); diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index 97f7d3d2a7a93d..f304d843096efb 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -436,6 +436,7 @@ cc_library( "@com_google_absl//absl/types:optional", "@llvm-project//llvm:Support", "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", ], alwayslink = True, diff --git a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc index cff404f2095fec..a89a07521eb939 100644 --- a/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.cc @@ -235,7 +235,6 @@ GetSpecsFromLabelsAndMap( std::vector sharding_specs(layout_rank); absl::flat_hash_map dimension_use_count; - absl::flat_hash_set dimension_use_set; for (const auto& label_and_indices : label_to_index) { const auto& loc = label_to_sharding_spec.find(label_and_indices.first); if (loc != label_to_sharding_spec.end()) { diff --git a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc index cac70c2b9848a6..61d51226141168 100644 --- a/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/matmul_spmd_expander.cc @@ -61,7 +61,6 @@ void GetTransposeSettings(mlir::Operation* op, bool* left_transposed, } // namespace StatusOr MatMulSPMDExpander::ExpandOp(mlir::Operation* op) { - absl::flat_hash_set reduced_dims; bool left_transposed; bool right_transposed; TF_ASSIGN_OR_RETURN(const Layout left_layout, diff --git a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc index 07e121b88424fb..737d1f562bb8ff 100644 --- a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/log/log.h" @@ -388,9 +389,6 @@ StatusOr ExpandSaveV2Op(mlir::Operation* op) { auto save_v2 = mlir::cast(op); mlir::OpBuilder builder(save_v2); - - absl::flat_hash_map, Layout>> - tensor_shape_layout_map; std::vector metadata; for (const auto& it : llvm::enumerate(save_v2.getTensors())) { mlir::Value tensor = it.value(); diff --git a/tensorflow/dtensor/mlir/sparse_expander_common.h b/tensorflow/dtensor/mlir/sparse_expander_common.h index 4496043bed384f..9d6115067ae2a8 100644 --- a/tensorflow/dtensor/mlir/sparse_expander_common.h +++ b/tensorflow/dtensor/mlir/sparse_expander_common.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/optional.h" #include "mlir/Bytecode/BytecodeOpInterface.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/dtensor/mlir/spmd_expander_common.h b/tensorflow/dtensor/mlir/spmd_expander_common.h index 90b5ba5346bc34..0a35ce8032b07b 100644 --- a/tensorflow/dtensor/mlir/spmd_expander_common.h +++ b/tensorflow/dtensor/mlir/spmd_expander_common.h @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project diff --git a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc index 802d46fd27ecde..2e24e5d1f9db4c 100644 --- a/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc +++ b/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc @@ -165,7 +165,7 @@ Status UpdateMetadataProtoXlaSpmd(const Mesh& mesh_config, mesh_name = ""; } const std::vector& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name]; - VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", "); + VLOG(1) << "tpu_core_ids: " << absl::StrJoin(tpu_core_ids, ", "); xla::DeviceAssignmentProto device_assignment; device_assignment.set_replica_count(1); @@ -223,7 +223,7 @@ Status UpdateMetadataProtoDtensorSpmd(const Mesh& mesh_config, mesh_name = ""; } const std::vector& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name]; - VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", "); + VLOG(1) << "tpu_core_ids: " << absl::StrJoin(tpu_core_ids, ", "); xla::DeviceAssignmentProto device_assignment; device_assignment.set_replica_count(num_replicas); diff --git a/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc b/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc index ff29775e6ad5d9..475e08c28269f8 100644 --- a/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc +++ b/tensorflow/dtensor/tests/layout_to_xla_sharding_test.cc @@ -22,11 +22,11 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "benchmark/benchmark.h" // from @com_google_benchmark #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/dtensor/cc/dstatus.h" #include "tensorflow/dtensor/cc/tensor_layout.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace tensorflow { diff --git a/tensorflow/examples/speech_commands/accuracy_utils_test.cc b/tensorflow/examples/speech_commands/accuracy_utils_test.cc index cf4f5bad49c31d..7edd1b42689382 100644 --- a/tensorflow/examples/speech_commands/accuracy_utils_test.cc +++ b/tensorflow/examples/speech_commands/accuracy_utils_test.cc @@ -15,11 +15,11 @@ limitations under the License. #include "tensorflow/examples/speech_commands/accuracy_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/examples/speech_commands/recognize_commands_test.cc b/tensorflow/examples/speech_commands/recognize_commands_test.cc index 1730d064037821..1f13e2499362e4 100644 --- a/tensorflow/examples/speech_commands/recognize_commands_test.cc +++ b/tensorflow/examples/speech_commands/recognize_commands_test.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/examples/speech_commands/recognize_commands.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD index 4a1d5bda4f394d..d49976b8cf3886 100644 --- a/tensorflow/lite/BUILD +++ b/tensorflow/lite/BUILD @@ -126,6 +126,8 @@ filegroup( name = "tflite_internal_cc_3p_api_deps_src_all", srcs = [ ":tflite_internal_cc_3p_api_deps_src", + "//tensorflow/compiler/mlir/lite:tflite_internal_cc_3p_api_deps_src", + "//tensorflow/compiler/mlir/lite/core/api:tflite_internal_cc_3p_api_deps_src", "//tensorflow/compiler/mlir/lite/schema:tflite_internal_cc_3p_api_deps_src", "//tensorflow/lite/core:macros.h", "//tensorflow/lite/core/acceleration/configuration/c:tflite_internal_cc_3p_api_deps_src", @@ -141,7 +143,6 @@ filegroup( filegroup( name = "tflite_internal_cc_3p_api_deps_src", srcs = [ - ":allocation.cc", ":allocation.h", ":array.cc", ":array.h", @@ -150,7 +151,6 @@ filegroup( ":minimal_logging.cc", ":minimal_logging.h", ":minimal_logging_android.cc", - ":mmap_allocation.cc", ":mutable_op_resolver.cc", ":mutable_op_resolver.h", ":op_resolver.h", @@ -482,22 +482,16 @@ cc_library( cc_library( name = "allocation", - srcs = [ - "allocation.cc", - ] + select({ - ":tflite_mmap_disabled": [ - "mmap_allocation_disabled.cc", - ], - "//conditions:default": [ - "mmap_allocation.cc", - ], - }), hdrs = [ "allocation.h", + "//tensorflow/compiler/mlir/lite:allocation.h", ], compatible_with = get_compatible_with_portable(), copts = tflite_copts_warnings(), - deps = ["//tensorflow/lite/core/api:error_reporter"], + deps = [ + "//tensorflow/compiler/mlir/lite:allocation", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + ], ) cc_library( @@ -830,7 +824,6 @@ cc_library( deps = [ ":minimal_logging", "//tensorflow/lite/core/api:error_reporter", - "//tensorflow/lite/core/c:common", ], ) diff --git a/tensorflow/lite/CMakeLists.txt b/tensorflow/lite/CMakeLists.txt index 09e9ed33c61626..1aa7dec994944a 100644 --- a/tensorflow/lite/CMakeLists.txt +++ b/tensorflow/lite/CMakeLists.txt @@ -275,12 +275,6 @@ list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*with_selected_ops\\.cc$") # Exclude tensorflow_profiler_logger files. list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*tensorflow_profiler_logger\\.cc$") -if(_TFLITE_ENABLE_MMAP) - list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation_disabled\\.cc$") -else() - list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation\\.cc$") -endif() - # Handle TFLite logging source. list(FILTER TFLITE_SRCS EXCLUDE REGEX ".*minimal_logging_.*\\.cc$") if("${CMAKE_SYSTEM_NAME}" STREQUAL "Android") @@ -373,7 +367,9 @@ if(TFLITE_ENABLE_GPU) list(APPEND TFLITE_DELEGATES_GPU_SRCS ${TFLITE_SOURCE_DIR}/delegates/gpu/api.cc ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.cc + ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate.h ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.cc + ${TFLITE_SOURCE_DIR}/delegates/gpu/delegate_options.h ${TFLITE_SOURCE_DIR}/delegates/gpu/tflite_profile.cc ${TFLITE_SOURCE_DIR}/experimental/acceleration/compatibility/android_info.cc ${TFLITE_DELEGATES_GPU_CL_SRCS} @@ -681,10 +677,25 @@ set(_ALL_TFLITE_SRCS ${TF_SOURCE_DIR}/compiler/mlir/lite/utils/string_utils.h ${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.h ${TF_SOURCE_DIR}/compiler/mlir/lite/experimental/remat/metadata_util.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/macros.h ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.h ${TF_SOURCE_DIR}/compiler/mlir/lite/core/model_builder_base.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.h + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/error_reporter.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/core/api/verifier.h + ${TF_SOURCE_DIR}/compiler/mlir/lite/allocation.h + ${TF_SOURCE_DIR}/compiler/mlir/lite/allocation.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/mmap_allocation.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/mmap_allocation_disabled.cc ${TFLITE_SOURCE_DIR}/schema/schema_generated.h ) + +if(_TFLITE_ENABLE_MMAP) + list(FILTER _ALL_TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation_disabled\\.cc$") +else() + list(FILTER _ALL_TFLITE_SRCS EXCLUDE REGEX ".*mmap_allocation\\.cc$") +endif() + add_library(tensorflow-lite ${_ALL_TFLITE_SRCS} ) @@ -774,6 +785,9 @@ set(TFLITE_GENERATED_HEADERS_DIR ${CMAKE_BINARY_DIR}/tensorflow/lite) # Add the profiling proto directory. add_subdirectory(${TFLITE_SOURCE_DIR}/profiling/proto) +# Add the tf example directory. +add_subdirectory(${TF_SOURCE_DIR}/core/example ${CMAKE_BINARY_DIR}/example_proto_generated) + # The benchmark tool. add_subdirectory(${TFLITE_SOURCE_DIR}/tools/benchmark) diff --git a/tensorflow/lite/allocation.h b/tensorflow/lite/allocation.h index 6840646a115310..b2a03a66ae36bf 100644 --- a/tensorflow/lite/allocation.h +++ b/tensorflow/lite/allocation.h @@ -18,136 +18,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_ALLOCATION_H_ #define TENSORFLOW_LITE_ALLOCATION_H_ -#include - -#include -#include -#include - -#include "tensorflow/lite/core/api/error_reporter.h" - -namespace tflite { - -/// A memory allocation handle. This could be a mmap or shared memory. -class Allocation { - public: - virtual ~Allocation() {} - - enum class Type { - kMMap, - kFileCopy, - kMemory, - }; - - /// Base pointer of this allocation - virtual const void* base() const = 0; - /// Size in bytes of the allocation - virtual size_t bytes() const = 0; - /// Whether the allocation is valid - virtual bool valid() const = 0; - /// Return the type of the Allocation. - Type type() const { return type_; } - - protected: - Allocation(ErrorReporter* error_reporter, Type type) - : error_reporter_(error_reporter), type_(type) {} - ErrorReporter* error_reporter_; - - private: - const Type type_; -}; - -/// Note that not all platforms support MMAP-based allocation. -/// Use `IsSupported()` to check. -class MMAPAllocation : public Allocation { - public: - /// Loads and maps the provided file to a memory region. - MMAPAllocation(const char* filename, ErrorReporter* error_reporter); - - /// Maps the provided file descriptor to a memory region. - /// Note: The provided file descriptor will be dup'ed for usage; the caller - /// retains ownership of the provided descriptor and should close accordingly. - MMAPAllocation(int fd, ErrorReporter* error_reporter); - - /// Maps the provided file descriptor, with the given offset and length (both - /// in bytes), to a memory region. - /// Note: The provided file descriptor will be dup'ed for usage; the caller - /// retains ownership of the provided descriptor and should close accordingly. - MMAPAllocation(int fd, size_t offset, size_t length, - ErrorReporter* error_reporter); - - ~MMAPAllocation() override; - const void* base() const override; - size_t bytes() const override; - bool valid() const override; - - int fd() const { return mmap_fd_; } - - // The start address of the mmapped buffer. - // This will be base() rounded down to the nearest page boundary. - const void* mmapped_buffer() const { return mmapped_buffer_; } - - // The size of the mmapped buffer. - size_t mmapped_buffer_size() const { return bytes() + offset_in_buffer_; } - - // Offset of mmapped_buffer() in the file referenced by the file descriptor. - size_t mmapped_buffer_offset_in_file() const { - return offset_of_buffer_in_file_; - } - - static bool IsSupported(); - - protected: - // Data required for mmap. - int mmap_fd_ = -1; // mmap file descriptor - const void* mmapped_buffer_; - size_t buffer_size_bytes_ = 0; - // Used when the address to mmap is not page-aligned. - size_t offset_in_buffer_ = 0; - size_t offset_of_buffer_in_file_ = 0; - - private: - // Assumes ownership of the provided `owned_fd` instance. - MMAPAllocation(ErrorReporter* error_reporter, int owned_fd); - - // Assumes ownership of the provided `owned_fd` instance, and uses the given - // offset and length (both in bytes) for memory mapping. - MMAPAllocation(ErrorReporter* error_reporter, int owned_fd, size_t offset, - size_t length); -}; - -class FileCopyAllocation : public Allocation { - public: - /// Loads the provided file into a heap memory region. - FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); - ~FileCopyAllocation() override; - const void* base() const override; - size_t bytes() const override; - bool valid() const override; - - private: - std::unique_ptr copied_buffer_; - size_t buffer_size_bytes_ = 0; -}; - -class MemoryAllocation : public Allocation { - public: - /// Provides a (read-only) view of the provided buffer region as an - /// allocation. - /// Note: The caller retains ownership of `ptr`, and must ensure it remains - /// valid for the lifetime of the class instance. - MemoryAllocation(const void* ptr, size_t num_bytes, - ErrorReporter* error_reporter); - ~MemoryAllocation() override; - const void* base() const override; - size_t bytes() const override; - bool valid() const override; - - private: - const void* buffer_; - size_t buffer_size_bytes_ = 0; -}; - -} // namespace tflite +#include "tensorflow/compiler/mlir/lite/allocation.h" #endif // TENSORFLOW_LITE_ALLOCATION_H_ diff --git a/tensorflow/lite/c/BUILD b/tensorflow/lite/c/BUILD index 19cdd37ed4f549..f1664849f36e50 100644 --- a/tensorflow/lite/c/BUILD +++ b/tensorflow/lite/c/BUILD @@ -292,7 +292,10 @@ cc_test( size = "small", srcs = ["c_api_signature_runner_test.cc"], copts = tflite_copts(), - data = ["//tensorflow/lite:testdata/multi_signatures.bin"], + data = [ + "//tensorflow/lite:testdata/multi_signatures.bin", + "//tensorflow/lite:testdata/no_signatures.bin", + ], deps = [ ":c_api", "//tensorflow/lite/core/c:c_api", diff --git a/tensorflow/lite/c/c_api_signature_runner_test.cc b/tensorflow/lite/c/c_api_signature_runner_test.cc index 30614e5d7e59f5..61af71ffd863a6 100644 --- a/tensorflow/lite/c/c_api_signature_runner_test.cc +++ b/tensorflow/lite/c/c_api_signature_runner_test.cc @@ -24,6 +24,94 @@ limitations under the License. namespace tflite { namespace { +TEST(SignatureRunnerTest, TestNoSignatures) { + TfLiteModel* model = TfLiteModelCreateFromFile( + "tensorflow/lite/testdata/no_signatures.bin"); + ASSERT_NE(model, nullptr); + + TfLiteInterpreter* interpreter = + TfLiteInterpreterCreate(model, /*optional_options=*/nullptr); + ASSERT_NE(interpreter, nullptr); + + int nun_signatures = TfLiteInterpreterGetSignatureCount(interpreter); + ASSERT_EQ(nun_signatures, 0); + + ASSERT_EQ(TfLiteInterpreterGetSignatureRunner(interpreter, "foo"), nullptr); + + TfLiteSignatureRunner* runner = + TfLiteInterpreterGetSignatureRunner(interpreter, nullptr); + ASSERT_NE(runner, nullptr); + + int num_interpreter_inputs = + TfLiteInterpreterGetInputTensorCount(interpreter); + int num_runner_inputs = TfLiteSignatureRunnerGetInputCount(runner); + ASSERT_EQ(num_runner_inputs, num_interpreter_inputs); + + for (int i = 0; i < num_interpreter_inputs; ++i) { + auto* interpreter_input_tensor = + TfLiteInterpreterGetInputTensor(interpreter, i); + ASSERT_NE(interpreter_input_tensor, nullptr); + auto* interpreter_input_name = TfLiteTensorName(interpreter_input_tensor); + ASSERT_NE(interpreter_input_name, nullptr); + auto* runner_input_name = TfLiteSignatureRunnerGetInputName(runner, i); + ASSERT_NE(runner_input_name, nullptr); + EXPECT_STREQ(runner_input_name, interpreter_input_name); + auto* runner_input_tensor = + TfLiteSignatureRunnerGetInputTensor(runner, interpreter_input_name); + ASSERT_NE(runner_input_tensor, nullptr); + ASSERT_EQ(runner_input_tensor, interpreter_input_tensor); + } + + int num_interpreter_outputs = + TfLiteInterpreterGetOutputTensorCount(interpreter); + int num_runner_outputs = TfLiteSignatureRunnerGetOutputCount(runner); + ASSERT_EQ(num_runner_outputs, num_interpreter_outputs); + + for (int i = 0; i < num_interpreter_outputs; ++i) { + auto* interpreter_output_tensor = + TfLiteInterpreterGetOutputTensor(interpreter, i); + ASSERT_NE(interpreter_output_tensor, nullptr); + auto* interpreter_output_name = TfLiteTensorName(interpreter_output_tensor); + ASSERT_NE(interpreter_output_name, nullptr); + auto* runner_output_name = TfLiteSignatureRunnerGetOutputName(runner, i); + ASSERT_NE(runner_output_name, nullptr); + EXPECT_STREQ(runner_output_name, interpreter_output_name); + auto* runner_output_tensor = + TfLiteSignatureRunnerGetOutputTensor(runner, interpreter_output_name); + ASSERT_NE(runner_output_tensor, nullptr); + ASSERT_EQ(runner_output_tensor, interpreter_output_tensor); + } + + std::array input_dims{2}; + ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor( + runner, "x1", input_dims.data(), input_dims.size()), + kTfLiteOk); + ASSERT_EQ(TfLiteSignatureRunnerResizeInputTensor( + runner, "x2", input_dims.data(), input_dims.size()), + kTfLiteOk); + ASSERT_EQ(TfLiteSignatureRunnerAllocateTensors(runner), kTfLiteOk); + TfLiteTensor* input1 = TfLiteSignatureRunnerGetInputTensor(runner, "x1"); + ASSERT_NE(input1, nullptr); + TfLiteTensor* input2 = TfLiteSignatureRunnerGetInputTensor(runner, "x2"); + ASSERT_NE(input2, nullptr); + ASSERT_EQ(TfLiteSignatureRunnerGetInputTensor(runner, "foo"), nullptr); + const TfLiteTensor* output = + TfLiteSignatureRunnerGetOutputTensor(runner, "Identity"); + ASSERT_NE(output, nullptr); + ASSERT_EQ(TfLiteSignatureRunnerGetOutputTensor(runner, "foo"), nullptr); + input1->data.f[0] = -8; + input1->data.f[1] = 0.5; + input2->data.f[0] = -1; + input2->data.f[1] = 1.5; + ASSERT_EQ(TfLiteSignatureRunnerInvoke(runner), kTfLiteOk); + ASSERT_EQ(output->data.f[0], 0); + ASSERT_EQ(output->data.f[1], 2); + + TfLiteSignatureRunnerDelete(runner); + TfLiteInterpreterDelete(interpreter); + TfLiteModelDelete(model); +} + TEST(SignatureRunnerTest, TestMultiSignatures) { TfLiteModel* model = TfLiteModelCreateFromFile( "tensorflow/lite/testdata/multi_signatures.bin"); diff --git a/tensorflow/lite/core/BUILD b/tensorflow/lite/core/BUILD index d3939b91f911ea..4309e28baf8e38 100644 --- a/tensorflow/lite/core/BUILD +++ b/tensorflow/lite/core/BUILD @@ -43,9 +43,7 @@ cc_library( ], compatible_with = get_compatible_with_portable(), copts = tflite_copts() + tflite_copts_warnings(), - visibility = [ - "//tensorflow/lite:__subpackages__", - ], + visibility = ["//tensorflow/lite:__subpackages__"], deps = [ ":cc_api_stable", ":signature_runner", diff --git a/tensorflow/lite/core/api/BUILD b/tensorflow/lite/core/api/BUILD index 08ac033fcb0f77..6613d1c3e14c96 100644 --- a/tensorflow/lite/core/api/BUILD +++ b/tensorflow/lite/core/api/BUILD @@ -11,7 +11,6 @@ package( filegroup( name = "tflite_internal_cc_3p_api_deps_src", srcs = [ - ":error_reporter.cc", ":error_reporter.h", ":op_resolver.cc", ":op_resolver.h", @@ -68,32 +67,41 @@ cc_library( ], deps = [ ":error_reporter", + "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/lite/core/c:common", "//tensorflow/lite/schema:schema_fbs", - "//tensorflow/lite/schema:schema_utils", - "@flatbuffers//:runtime_cc", ], ) cc_library( name = "error_reporter", - srcs = ["error_reporter.cc"], - hdrs = ["error_reporter.h"], + hdrs = [ + "error_reporter.h", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter.h", + ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = [ "//visibility:public", ], - deps = [], + deps = [ + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + ], ) cc_library( name = "verifier", - hdrs = ["verifier.h"], + hdrs = [ + "verifier.h", + "//tensorflow/compiler/mlir/lite/core/api:verifier.h", + ], compatible_with = get_compatible_with_portable(), copts = tflite_copts(), visibility = ["//visibility:public"], - deps = [":error_reporter"], + deps = [ + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + "//tensorflow/compiler/mlir/lite/core/api:verifier", + ], ) cc_library( @@ -108,24 +116,19 @@ cc_library( deps = [":op_resolver"], ) -cc_test( - name = "error_reporter_test", - size = "small", - srcs = ["error_reporter_test.cc"], - deps = [ - ":api", - "@com_google_googletest//:gtest_main", - ], -) - cc_test( name = "op_resolver_test", size = "small", srcs = ["op_resolver_test.cc"], deps = [ ":api", - "//tensorflow/lite/schema:schema_conversion_utils", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", + "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", + "//tensorflow/lite/schema:schema_fbs", "@com_google_googletest//:gtest_main", + "@flatbuffers//:runtime_cc", ], ) @@ -136,7 +139,6 @@ cc_test( deps = [ ":op_resolver", ":op_resolver_internal", - "//tensorflow/lite:builtin_ops", "//tensorflow/lite:framework", "//tensorflow/lite:mutable_op_resolver", "//tensorflow/lite/core/kernels:builtin_ops", @@ -151,6 +153,7 @@ cc_test( srcs = ["flatbuffer_conversions_test.cc"], deps = [ ":api", + "//tensorflow/compiler/mlir/lite/core/api:error_reporter", "//tensorflow/lite:string", "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/c:c_api_types", diff --git a/tensorflow/lite/core/api/error_reporter.h b/tensorflow/lite/core/api/error_reporter.h index 1e0ef7dc913a44..f9106046b2f231 100644 --- a/tensorflow/lite/core/api/error_reporter.h +++ b/tensorflow/lite/core/api/error_reporter.h @@ -15,58 +15,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ #define TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ -#include - -namespace tflite { - -/// A functor that reports error to supporting system. Invoked similar to -/// printf. -/// -/// Usage: -/// ErrorReporter foo; -/// foo.Report("test %d", 5); -/// or -/// va_list args; -/// foo.Report("test %d", args); // where args is va_list -/// -/// Subclass ErrorReporter to provide another reporting destination. -/// For example, if you have a GUI program, you might redirect to a buffer -/// that drives a GUI error log box. -class ErrorReporter { - public: - virtual ~ErrorReporter() = default; - /// Converts `args` to character equivalents according to `format` string, - /// constructs the error string and report it. - /// Returns number of characters written or zero on success, and negative - /// number on error. - virtual int Report(const char* format, va_list args) = 0; - - /// Converts arguments to character equivalents according to `format` string, - /// constructs the error string and report it. - /// Returns number of characters written or zero on success, and negative - /// number on error. - int Report(const char* format, ...); - - /// Equivalent to `Report` above. The additional `void*` parameter is unused. - /// This method is for compatibility with macros that takes `TfLiteContext`, - /// like TF_LITE_ENSURE and related macros. - int ReportError(void*, const char* format, ...); -}; - -} // namespace tflite - -// You should not make bare calls to the error reporter, instead use the -// TF_LITE_REPORT_ERROR macro, since this allows message strings to be -// stripped when the binary size has to be optimized. If you are looking to -// reduce binary size, define TF_LITE_STRIP_ERROR_STRINGS when compiling and -// every call will be stubbed out, taking no memory. -#ifndef TF_LITE_STRIP_ERROR_STRINGS -#define TF_LITE_REPORT_ERROR(reporter, ...) \ - do { \ - static_cast<::tflite::ErrorReporter*>(reporter)->Report(__VA_ARGS__); \ - } while (false) -#else // TF_LITE_STRIP_ERROR_STRINGS -#define TF_LITE_REPORT_ERROR(reporter, ...) -#endif // TF_LITE_STRIP_ERROR_STRINGS +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_CORE_API_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 35268103be8792..c27e4e6f8b82a9 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -20,9 +20,8 @@ limitations under the License. #include #include -#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "flatbuffers/vector.h" // from @flatbuffers -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.h b/tensorflow/lite/core/api/flatbuffer_conversions.h index c01e8875813f93..de287af21c8a5b 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.h +++ b/tensorflow/lite/core/api/flatbuffer_conversions.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc index 87c897dfc0928e..98c8c910ac1d84 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions_test.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/core/api/op_resolver.cc b/tensorflow/lite/core/api/op_resolver.cc index ce5ae4f406eb6a..214490c874d7ad 100644 --- a/tensorflow/lite/core/api/op_resolver.cc +++ b/tensorflow/lite/core/api/op_resolver.cc @@ -15,10 +15,10 @@ limitations under the License. #include "tensorflow/lite/core/api/op_resolver.h" -#include "flatbuffers/flatbuffers.h" // from @flatbuffers -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h" #include "tensorflow/lite/core/c/common.h" -#include "tensorflow/lite/schema/schema_utils.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/lite/core/api/op_resolver.h b/tensorflow/lite/core/api/op_resolver.h index 7aff7cafea1783..f6f5fd214d187a 100644 --- a/tensorflow/lite/core/api/op_resolver.h +++ b/tensorflow/lite/core/api/op_resolver.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/schema/schema_generated.h" diff --git a/tensorflow/lite/core/api/op_resolver_internal_test.cc b/tensorflow/lite/core/api/op_resolver_internal_test.cc index d052e9c7bab8ee..b62df374c483ef 100644 --- a/tensorflow/lite/core/api/op_resolver_internal_test.cc +++ b/tensorflow/lite/core/api/op_resolver_internal_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/lite/core/kernels/builtin_op_kernels.h" #include "tensorflow/lite/core/kernels/register.h" #include "tensorflow/lite/mutable_op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { diff --git a/tensorflow/lite/core/api/op_resolver_test.cc b/tensorflow/lite/core/api/op_resolver_test.cc index 45fcdcf81dac18..59b08ad21864dc 100644 --- a/tensorflow/lite/core/api/op_resolver_test.cc +++ b/tensorflow/lite/core/api/op_resolver_test.cc @@ -18,7 +18,13 @@ limitations under the License. #include #include -#include "tensorflow/lite/schema/schema_conversion_utils.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" +#include "tensorflow/lite/schema/schema_generated.h" namespace tflite { namespace { diff --git a/tensorflow/lite/core/api/verifier.h b/tensorflow/lite/core/api/verifier.h index 8128ff31e1ea85..dcb1d029b5678a 100644 --- a/tensorflow/lite/core/api/verifier.h +++ b/tensorflow/lite/core/api/verifier.h @@ -18,22 +18,6 @@ limitations under the License. #ifndef TENSORFLOW_LITE_CORE_API_VERIFIER_H_ #define TENSORFLOW_LITE_CORE_API_VERIFIER_H_ -#include "tensorflow/lite/core/api/error_reporter.h" - -namespace tflite { - -/// Abstract interface that verifies whether a given model is legit. -/// It facilitates the use-case to verify and build a model without loading it -/// twice. -/// (See also "tensorflow/lite/tools/verifier.h".) -class TfLiteVerifier { - public: - /// Returns true if the model is legit. - virtual bool Verify(const char* data, int length, - ErrorReporter* reporter) = 0; - virtual ~TfLiteVerifier() {} -}; - -} // namespace tflite +#include "tensorflow/compiler/mlir/lite/core/api/verifier.h" // IWYU pragma: export #endif // TENSORFLOW_LITE_CORE_API_VERIFIER_H_ diff --git a/tensorflow/lite/core/async/BUILD b/tensorflow/lite/core/async/BUILD index 625104252899a1..ca2f3caac2906a 100644 --- a/tensorflow/lite/core/async/BUILD +++ b/tensorflow/lite/core/async/BUILD @@ -38,8 +38,9 @@ cc_test( name = "task_internal_test", srcs = ["task_internal_test.cc"], deps = [ - ":async_kernel_internal", ":task_internal", + "//tensorflow/lite/c:c_api_types", + "//tensorflow/lite/c:common", "//tensorflow/lite/core/async/c:types", "//tensorflow/lite/core/async/interop/c:types", "@com_google_googletest//:gtest_main", diff --git a/tensorflow/lite/core/async/async_signature_runner_test.cc b/tensorflow/lite/core/async/async_signature_runner_test.cc index bb5e23b31111d1..3eb075ac143b35 100644 --- a/tensorflow/lite/core/async/async_signature_runner_test.cc +++ b/tensorflow/lite/core/async/async_signature_runner_test.cc @@ -183,7 +183,7 @@ TEST_F(AsyncSignatureRunnerNoSignatureDefTest, GetAsyncSignatureRunner) { TEST_F(AsyncSignatureRunnerNoSignatureDefTest, InputsTest) { signature_runner_ = interpreter_->GetAsyncSignatureRunner(nullptr); EXPECT_EQ(1, signature_runner_->input_size()); - EXPECT_EQ(0, signature_runner_->input_names().size()); + EXPECT_EQ(1, signature_runner_->input_names().size()); EXPECT_EQ(1, signature_runner_->inputs().size()); EXPECT_NE(nullptr, signature_runner_->tensor(signature_runner_->inputs()[0])); @@ -192,7 +192,7 @@ TEST_F(AsyncSignatureRunnerNoSignatureDefTest, InputsTest) { TEST_F(AsyncSignatureRunnerNoSignatureDefTest, OutputsTest) { signature_runner_ = interpreter_->GetAsyncSignatureRunner(nullptr); EXPECT_EQ(1, signature_runner_->output_size()); - EXPECT_EQ(0, signature_runner_->output_names().size()); + EXPECT_EQ(1, signature_runner_->output_names().size()); EXPECT_EQ(1, signature_runner_->outputs().size()); EXPECT_NE(nullptr, diff --git a/tensorflow/lite/core/async/c/BUILD b/tensorflow/lite/core/async/c/BUILD index e9a8bf9ae6c7cc..0f6bb9c62bc2d8 100644 --- a/tensorflow/lite/core/async/c/BUILD +++ b/tensorflow/lite/core/async/c/BUILD @@ -118,6 +118,9 @@ cc_test( name = "async_signature_runner_test", srcs = ["async_signature_runner_test.cc"], copts = tflite_copts() + tflite_copts_warnings(), + data = [ + "//tensorflow/lite:testdata/no_signatures.bin", + ], deps = [ ":async_signature_runner", ":internal", diff --git a/tensorflow/lite/core/async/c/async_signature_runner_test.cc b/tensorflow/lite/core/async/c/async_signature_runner_test.cc index 2648e5028ed84b..1e2b54dacd55f3 100644 --- a/tensorflow/lite/core/async/c/async_signature_runner_test.cc +++ b/tensorflow/lite/core/async/c/async_signature_runner_test.cc @@ -182,9 +182,10 @@ TEST_P(AsyncSignatureRunnerTest, InputsTest) { "x", TfLiteOpaqueTensorName( TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "input"))); } else { - EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetInputName(runner_, 0)); - EXPECT_EQ(nullptr, - TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "input")); + EXPECT_STREQ("x", TfLiteAsyncSignatureRunnerGetInputName(runner_, 0)); + EXPECT_STREQ("x", + TfLiteOpaqueTensorName( + TfLiteAsyncSignatureRunnerGetInputTensor(runner_, "x"))); } } @@ -198,9 +199,10 @@ TEST_P(AsyncSignatureRunnerTest, OutputsTest) { "a", TfLiteOpaqueTensorName( TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "output"))); } else { - EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetOutputName(runner_, 0)); - EXPECT_EQ(nullptr, - TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "output")); + EXPECT_STREQ("a", TfLiteAsyncSignatureRunnerGetOutputName(runner_, 0)); + EXPECT_STREQ("a", + TfLiteOpaqueTensorName( + TfLiteAsyncSignatureRunnerGetOutputTensor(runner_, "a"))); } } @@ -229,5 +231,93 @@ TEST_P(AsyncSignatureRunnerTest, IndexOutOfBound) { EXPECT_EQ(nullptr, TfLiteAsyncSignatureRunnerGetTensor(runner_, 42)); } +TEST(AsyncSignatureRunnerTest, TestNoSignatures) { + TfLiteModel* model = TfLiteModelCreateFromFile( + "third_party/tensorflow/lite/testdata/no_signatures.bin"); + ASSERT_NE(model, nullptr); + + TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate(); + ASSERT_NE(options, nullptr); + auto kernel = + std::make_unique<::testing::StrictMock>(); + auto backend = std::make_unique(kernel->kernel()); + TfLiteInterpreterOptionsAddDelegate(options, backend->get_delegate()); + + TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options); + ASSERT_NE(interpreter, nullptr); + + TfLiteInterpreterOptionsDelete(options); + + int nun_signatures = TfLiteInterpreterGetSignatureCount(interpreter); + ASSERT_EQ(nun_signatures, 0); + + ASSERT_EQ(TfLiteInterpreterGetAsyncSignatureRunner(interpreter, "foo"), + nullptr); + + TfLiteAsyncSignatureRunner* runner = + TfLiteInterpreterGetAsyncSignatureRunner(interpreter, nullptr); + ASSERT_NE(runner, nullptr); + + int num_interpreter_inputs = + TfLiteInterpreterGetInputTensorCount(interpreter); + int num_runner_inputs = TfLiteAsyncSignatureRunnerGetInputCount(runner); + ASSERT_EQ(num_runner_inputs, num_interpreter_inputs); + + for (int i = 0; i < num_interpreter_inputs; ++i) { + auto* interpreter_input_tensor = + TfLiteInterpreterGetInputTensor(interpreter, i); + ASSERT_NE(interpreter_input_tensor, nullptr); + auto* interpreter_input_name = TfLiteTensorName(interpreter_input_tensor); + ASSERT_NE(interpreter_input_name, nullptr); + auto* runner_input_name = TfLiteAsyncSignatureRunnerGetInputName(runner, i); + ASSERT_NE(runner_input_name, nullptr); + EXPECT_STREQ(runner_input_name, interpreter_input_name); + auto* runner_input_tensor = TfLiteAsyncSignatureRunnerGetInputTensor( + runner, interpreter_input_name); + ASSERT_NE(runner_input_tensor, nullptr); + ASSERT_EQ(runner_input_tensor, reinterpret_cast( + interpreter_input_tensor)); + } + + int num_interpreter_outputs = + TfLiteInterpreterGetOutputTensorCount(interpreter); + int num_runner_outputs = TfLiteAsyncSignatureRunnerGetOutputCount(runner); + ASSERT_EQ(num_runner_outputs, num_interpreter_outputs); + + for (int i = 0; i < num_interpreter_outputs; ++i) { + auto* interpreter_output_tensor = + TfLiteInterpreterGetOutputTensor(interpreter, i); + ASSERT_NE(interpreter_output_tensor, nullptr); + auto* interpreter_output_name = TfLiteTensorName(interpreter_output_tensor); + ASSERT_NE(interpreter_output_name, nullptr); + auto* runner_output_name = + TfLiteAsyncSignatureRunnerGetOutputName(runner, i); + ASSERT_NE(runner_output_name, nullptr); + EXPECT_STREQ(runner_output_name, interpreter_output_name); + auto* runner_output_tensor = TfLiteAsyncSignatureRunnerGetOutputTensor( + runner, interpreter_output_name); + ASSERT_NE(runner_output_tensor, nullptr); + ASSERT_EQ(runner_output_tensor, reinterpret_cast( + interpreter_output_tensor)); + } + + EXPECT_CALL(*kernel, Prepare(_, _)).WillOnce(Return(kTfLiteOk)); + EXPECT_CALL(*kernel, Eval(_, _, _)).WillOnce(Return(kTfLiteOk)); + EXPECT_CALL(*kernel, Wait(_, _)).WillOnce(Return(kTfLiteOk)); + EXPECT_CALL(*kernel, Finish(_, _)).WillOnce(Return(kTfLiteOk)); + + EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerPrepareBackends(runner)); + + auto* task = TfLiteAsyncSignatureRunnerCreateTask(runner); + + EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerInvokeAsync(runner, task)); + EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerWait(runner, task)); + EXPECT_EQ(kTfLiteOk, TfLiteAsyncSignatureRunnerFinish(runner, task)); + + TfLiteAsyncSignatureRunnerDelete(runner); + TfLiteInterpreterDelete(interpreter); + TfLiteModelDelete(model); +} + } // namespace async } // namespace tflite diff --git a/tensorflow/lite/core/async/task_internal_test.cc b/tensorflow/lite/core/async/task_internal_test.cc index d63eb03e89767f..b0dc1ae385917f 100644 --- a/tensorflow/lite/core/async/task_internal_test.cc +++ b/tensorflow/lite/core/async/task_internal_test.cc @@ -17,7 +17,8 @@ limitations under the License. #include #include -#include "tensorflow/lite/core/async/async_kernel_internal.h" +#include "tensorflow/lite/c/c_api_types.h" +#include "tensorflow/lite/c/common.h" #include "tensorflow/lite/core/async/c/types.h" #include "tensorflow/lite/core/async/interop/c/types.h" diff --git a/tensorflow/lite/core/c/common.h b/tensorflow/lite/core/c/common.h index 96f19f12336bc4..648b8623960f92 100644 --- a/tensorflow/lite/core/c/common.h +++ b/tensorflow/lite/core/c/common.h @@ -100,7 +100,9 @@ typedef struct TfLiteExternalContext { TfLiteStatus (*Refresh)(struct TfLiteContext* context); } TfLiteExternalContext; +// LINT.IfChange(optional_tensor) #define kTfLiteOptionalTensor (-1) +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/flatbuffer_export.cc:optional_tensor) /// Fixed size list of integers. Used for dimensions and inputs/outputs tensor /// indices diff --git a/tensorflow/lite/core/interpreter.cc b/tensorflow/lite/core/interpreter.cc index 9d1623c5b4821f..dd9feb9b90c3be 100644 --- a/tensorflow/lite/core/interpreter.cc +++ b/tensorflow/lite/core/interpreter.cc @@ -520,7 +520,13 @@ void Interpreter::AddProfiler(std::unique_ptr profiler) { } impl::SignatureRunner* Interpreter::GetSignatureRunner( - const char* signature_key) { + const char* signature_key_) { + auto [signature_key, empty_signature_fallback] = + ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_); + if (!signature_key) { + return nullptr; + } + auto iter = signature_runner_map_.find(signature_key); if (iter != signature_runner_map_.end()) { return &(iter->second); @@ -533,6 +539,14 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner( return nullptr; } + if (empty_signature_fallback) { + placeholder_signature_def_ = CreatePlaceholderSignatureDef(); + auto status = signature_runner_map_.insert( + {signature_key, SignatureRunner(placeholder_signature_def_.get(), + &primary_subgraph())}); + return &(status.first->second); + } + for (const auto& signature : signature_defs_) { if (signature.signature_key == signature_key) { auto status = signature_runner_map_.insert( @@ -541,7 +555,56 @@ impl::SignatureRunner* Interpreter::GetSignatureRunner( return &(status.first->second); } } + return nullptr; } +std::unique_ptr +Interpreter::CreatePlaceholderSignatureDef() { + auto placeholder_signature_def = std::make_unique(); + for (auto i = 0; i < inputs().size(); ++i) { + auto* name = GetInputName(i); + placeholder_signature_def->inputs[name] = inputs()[i]; + } + for (auto i = 0; i < outputs().size(); ++i) { + auto* name = GetOutputName(i); + placeholder_signature_def->outputs[name] = outputs()[i]; + } + placeholder_signature_def->signature_key = kPlaceholderSignatureDefKey; + placeholder_signature_def->subgraph_index = 0; + return placeholder_signature_def; +} + +std::pair +Interpreter::ReplaceWithPlaceholderSignatureKeyIfNeeded( + const char* signature_key) { + // Handles nullptr signature key. + // If the model does not have signature def, use default name as placeholder. + // Otherwise use the first signature key that points to primary subgraph. + bool empty_signature_fallback = false; + if (signature_key == nullptr) { + if (signature_defs_.empty()) { + signature_key = kPlaceholderSignatureDefKey; + empty_signature_fallback = true; + } else { + for (const auto& signature : signature_defs_) { + if (signature.subgraph_index == 0) { + signature_key = signature.signature_key.c_str(); + break; + } + } + } + } + + if (signature_key == nullptr) { + // The model has signature def but none of those points to primary subgraph. + TF_LITE_REPORT_ERROR(error_reporter_, + "The model has signature def but none of those points " + "to primary subgraph."); + return {nullptr, empty_signature_fallback}; + } else { + return {signature_key, empty_signature_fallback}; + } +} + } // namespace tflite diff --git a/tensorflow/lite/core/interpreter.h b/tensorflow/lite/core/interpreter.h index 4a3fb131da3c14..f26a15dcd0b9b8 100644 --- a/tensorflow/lite/core/interpreter.h +++ b/tensorflow/lite/core/interpreter.h @@ -335,21 +335,25 @@ class Interpreter { } /// \brief Returns a pointer to the SignatureRunner instance to run the part - /// of the graph identified by a SignatureDef. The nullptr is returned if the - /// given signature key is not valid. + /// of the graph identified by a SignatureDef. If the model does not have any + /// signature defs, pass nullptr as signature_key and a SignatureRunner will + /// be created using the primary subgraph (0). A nullptr is returned if the + /// given signature_key is not valid. Note, the returned SignatureRunner + /// instance is owned by and has the same lifetime as the Interpreter object; + /// additionally, class SignatureRunner is *not* thread-safe. /// If you need to specify delegates, you have to do that before calling this /// function. This function will additionally apply default delegates. Thus, /// applying delegates after that might lead to undesirable behaviors. - /// Note, the pointed instance has lifetime same as the Interpreter object - /// and the SignatureRunner class is *not* thread-safe. SignatureRunner* GetSignatureRunner(const char* signature_key); - /// \warning Experimental interface, subject to change. \n - /// \brief Returns a pointer to the AsyncSignatureRunner instance to run the - /// part of the graph identified by a SignatureDef. The nullptr is returned if - /// the given signature key is not valid. - /// if the model does not have signature def, pass nullptr to signature_key - /// and AsyncSignatureRunner will be created using primary subgraph (0). + /// \warning Experimental interface, subject to change. \n \brief Returns a + /// pointer to the AsyncSignatureRunner instance to run the part of the graph + /// identified by a SignatureDef. If the model does not have any signature + /// defs, pass nullptr as signature_key and an AsyncSignatureRunner will be + /// created using the primary subgraph (0). A nullptr is returned if the + /// given signature_key is not valid. Note, the returned AsyncSignatureRunner + /// instance is owned by and has the same lifetime as the Interpreter object; + /// additionally, class AsyncSignatureRunner is *not* thread-safe. /// The async delegate should be applied before calling this function. async::AsyncSignatureRunner* GetAsyncSignatureRunner( const char* signature_key); @@ -905,6 +909,10 @@ class Interpreter { TfLiteStatus ApplyOptionsImpl(InterpreterOptions* options); + std::unique_ptr CreatePlaceholderSignatureDef(); + std::pair ReplaceWithPlaceholderSignatureKeyIfNeeded( + const char* signature_key); + // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. @@ -964,6 +972,13 @@ class Interpreter { // List of SignatureDefs obtained from the model. std::vector signature_defs_; + // Default signature key to use when the model has no signatures. + static constexpr char kPlaceholderSignatureDefKey[] = + ""; + + // Placeholder SignatureDef for legacy models with no signatures. + std::unique_ptr placeholder_signature_def_; + // Map of signature key to its corresponding SignatureRunner object. // A SignatureRunner is basically a wrapper of the Subgraph corresponding to // its SignatureDef. diff --git a/tensorflow/lite/core/interpreter_experimental.cc b/tensorflow/lite/core/interpreter_experimental.cc index 7eef090791df8f..4a7bca720d8239 100644 --- a/tensorflow/lite/core/interpreter_experimental.cc +++ b/tensorflow/lite/core/interpreter_experimental.cc @@ -34,10 +34,6 @@ limitations under the License. namespace tflite { -namespace { -static constexpr char kDefaultServingSignatureDefKey[] = "serving_default"; -} // namespace - TfLiteStatus Interpreter::SetCustomAllocationForTensor( int tensor_index, const TfLiteCustomAllocation& allocation, int64_t flags) { return primary_subgraph().SetCustomAllocationForTensor(tensor_index, @@ -145,27 +141,10 @@ TfLiteStatus Interpreter::ApplyOptions(InterpreterOptions* options) { } async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner( - const char* signature_key) { - // Handles nullptr signature key. - // If the model does not have signature def, use default name as placeholder. - // Otherwise use the first signature key that points to primary subgraph. - bool empty_signature_fallback = false; - if (signature_key == nullptr) { - if (signature_defs_.empty()) { - signature_key = kDefaultServingSignatureDefKey; - empty_signature_fallback = true; - } else { - for (const auto& signature : signature_defs_) { - if (signature.subgraph_index == 0) { - signature_key = signature.signature_key.c_str(); - break; - } - } - } - } - - if (signature_key == nullptr) { - // The model has signature def but none of those points to primary subgraph. + const char* signature_key_) { + auto [signature_key, empty_signature_fallback] = + ReplaceWithPlaceholderSignatureKeyIfNeeded(signature_key_); + if (!signature_key) { return nullptr; } @@ -175,11 +154,14 @@ async::AsyncSignatureRunner* Interpreter::GetAsyncSignatureRunner( } if (empty_signature_fallback) { + placeholder_signature_def_ = CreatePlaceholderSignatureDef(); auto status = async_signature_runner_map_.insert( {signature_key, - async::AsyncSignatureRunner(nullptr, &primary_subgraph())}); + async::AsyncSignatureRunner(placeholder_signature_def_.get(), + &primary_subgraph())}); return &(status.first->second); } + for (const auto& signature : signature_defs_) { if (signature.signature_key == signature_key) { auto status = async_signature_runner_map_.insert( diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 87f701f5adc427..d5ae8d056bde0d 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -364,6 +364,24 @@ tf_cc_test( ], ) +tf_cc_test( + name = "allowlisted_flex_ops_test", + size = "small", + srcs = [ + "allowlisted_flex_ops_test.cc", + ], + features = tf_features_nolayering_check_if_ios(), + deps = [ + ":delegate", + "//tensorflow/compiler/mlir/lite/delegates/flex:allowlisted_flex_ops_lib", + "@com_google_googletest//:gtest_main", + ] + if_mobile([ + "//tensorflow/core:portable_tensorflow_lib_lite", + ]) + if_not_mobile([ + "//tensorflow/core:framework", + ]), +) + # Alias to support selective build of image ops. # TODO(b/163285312): Remove after tensorflow/core refactoring completed. cc_library( diff --git a/tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops_test.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops_test.cc similarity index 100% rename from tensorflow/compiler/mlir/lite/delegates/flex/allowlisted_flex_ops_test.cc rename to tensorflow/lite/delegates/flex/allowlisted_flex_ops_test.cc diff --git a/tensorflow/lite/delegates/gpu/BUILD b/tensorflow/lite/delegates/gpu/BUILD index 2be6504b9a5878..d66d66b544a608 100644 --- a/tensorflow/lite/delegates/gpu/BUILD +++ b/tensorflow/lite/delegates/gpu/BUILD @@ -85,6 +85,7 @@ cc_library( "//tensorflow/lite:minimal_logging", "//tensorflow/lite/core/c:common", "//tensorflow/lite/delegates/gpu/common:convert", + "//tensorflow/lite/delegates/gpu/common:gpu_info", "//tensorflow/lite/delegates/gpu/common:model", "//tensorflow/lite/delegates/gpu/common:model_builder", "//tensorflow/lite/delegates/gpu/common:model_transformer", @@ -95,9 +96,14 @@ cc_library( "//tensorflow/lite/delegates/gpu/gl:api", "//tensorflow/lite/delegates/gpu/gl:command_queue", "//tensorflow/lite/delegates/gpu/gl:compiler", + "//tensorflow/lite/delegates/gpu/gl:compiler_options", "//tensorflow/lite/delegates/gpu/gl:egl_environment", + "//tensorflow/lite/delegates/gpu/gl:gl_buffer", "//tensorflow/lite/delegates/gpu/gl:gl_call", + "//tensorflow/lite/delegates/gpu/gl:object", + "//tensorflow/lite/delegates/gpu/gl:object_manager", "//tensorflow/lite/delegates/gpu/gl:request_gpu_info", + "//tensorflow/lite/delegates/gpu/gl:runtime_options", "//tensorflow/lite/delegates/gpu/gl/converters:bhwc_to_phwc4", "//tensorflow/lite/delegates/gpu/gl/converters:phwc4_to_bhwc", "//tensorflow/lite/delegates/gpu/gl/kernels:registry", diff --git a/tensorflow/lite/delegates/gpu/cl/BUILD b/tensorflow/lite/delegates/gpu/cl/BUILD index 73f192d17ebf0c..b84cb9a71a46f0 100644 --- a/tensorflow/lite/delegates/gpu/cl/BUILD +++ b/tensorflow/lite/delegates/gpu/cl/BUILD @@ -299,6 +299,7 @@ cc_library( ":cl_kernel", ":program_cache", ":tensor", + "//tensorflow/lite/delegates/gpu/common/task:compiler_options", "//tensorflow/lite/delegates/gpu/common/task:gpu_operation", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc index 1cc1738d071d44..8fd94938b57258 100644 --- a/tensorflow/lite/delegates/gpu/cl/cl_operation.cc +++ b/tensorflow/lite/delegates/gpu/cl/cl_operation.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "tensorflow/lite/delegates/gpu/common/task/compiler_options.h" + namespace tflite { namespace gpu { namespace cl { @@ -165,6 +167,10 @@ absl::Status ClOperation::Compile(const CreationContext& creation_context) { creation_context.context, &operation_->args_, &operation_->code_)); operation_->args_.ReleaseCPURepresentation(); + if (creation_context.device->info_.opencl_info.IsCLVK()) { + operation_->compiler_options_.push_back( + CompilerOptions::kClFastRelaxedMath); + } RETURN_IF_ERROR(creation_context.cache->GetOrCreateCLKernel( operation_->code_, "main_function", operation_->compiler_options_, *creation_context.context, *creation_context.device, &kernel_, diff --git a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc index 7a40a609e3fce7..5bd407de4c8b78 100644 --- a/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc +++ b/tensorflow/lite/delegates/gpu/common/memory_management/greedy_by_size_assignment.cc @@ -75,6 +75,7 @@ absl::Status GreedyBySizeAssignment( // Ordered records are to be sorted by size of corresponding tensor. std::vector> ordered_records; + ordered_records.reserve(num_tensors); for (size_t i = 0; i < num_tensors; ++i) { ordered_records.emplace_back(&usage_records[i], i); } diff --git a/tensorflow/lite/delegates/gpu/common/model.cc b/tensorflow/lite/delegates/gpu/common/model.cc index dc68e702ac2328..a7a174f60f54d2 100644 --- a/tensorflow/lite/delegates/gpu/common/model.cc +++ b/tensorflow/lite/delegates/gpu/common/model.cc @@ -333,10 +333,16 @@ absl::Status GraphFloat32::MakeExactCopy(GraphFloat32* model) const { model->nodes_.clear(); model->execution_plan_.clear(); model->values_.clear(); + model->known_graph_outputs_.clear(); for (auto& value_def : values_) { model->values_.push_back({}); if (value_def.value) { model->values_.back().value = std::make_unique(*value_def.value); + if (std::find(known_graph_outputs_.begin(), known_graph_outputs_.end(), + value_def.value.get()) != known_graph_outputs_.end()) { + model->known_graph_outputs_.push_back( + model->values_.back().value.get()); + } } } // Add all nodes first. diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc index ae3e4e5438a5d4..804eac531e26f9 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_bias.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/add_bias.h" +#include #include #include #include @@ -59,12 +60,12 @@ class AddBias : public NodeTransformation { "runtime input."}; } auto& attr = - absl::any_cast(node->operation.attributes); + std::any_cast(node->operation.attributes); return FillBias(attr.weights.shape.o, &attr.bias); } if (node->operation.type == ToString(OperationType::CONVOLUTION_TRANSPOSED)) { - auto& attr = absl::any_cast( + auto& attr = std::any_cast( node->operation.attributes); return FillBias(attr.weights.shape.o, &attr.bias); } @@ -76,17 +77,17 @@ class AddBias : public NodeTransformation { "with one " "runtime input."}; } - auto& attr = absl::any_cast( + auto& attr = std::any_cast( node->operation.attributes); return FillBias(attr.weights.shape.o * attr.weights.shape.i, &attr.bias); } if (node->operation.type == ToString(OperationType::FULLY_CONNECTED)) { auto& attr = - absl::any_cast(node->operation.attributes); + std::any_cast(node->operation.attributes); return FillBias(attr.weights.shape.o, &attr.bias); } if (node->operation.type == ToString(OperationType::FULLY_CONNECTED_INT8)) { - auto& attr = absl::any_cast( + auto& attr = std::any_cast( node->operation.attributes); return FillBias(attr.weights.shape.o, &attr.bias); } @@ -97,7 +98,7 @@ class AddBias : public NodeTransformation { } // namespace std::unique_ptr NewAddBias() { - return absl::make_unique(); + return std::make_unique(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc index 361b6d0ebf1322..66040d03aa8cde 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments_test.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/add_quant_adjustments.h" +#include #include +#include #include #include @@ -34,7 +36,7 @@ namespace tflite { namespace gpu { namespace { -void AddQuantParams(absl::optional* params, float min, +void AddQuantParams(std::optional* params, float min, float max, float scale) { params->emplace(); params->value().min = min; @@ -154,7 +156,7 @@ TEST(AddQuantAdjustments, GeneralCase) { graph.nodes()[2]->operation.type); EXPECT_EQ(quant_node->id, graph.nodes()[2]->id); EXPECT_EQ(ToString(OperationType::ADD), graph.nodes()[3]->operation.type); - auto new_quant_attr = absl::any_cast( + auto new_quant_attr = std::any_cast( graph.nodes()[1]->operation.attributes); EXPECT_EQ(0.0, new_quant_attr.min); EXPECT_EQ(2.0, new_quant_attr.max); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc index 0236bfa4326ce0..4500b0ed50655a 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv.cc @@ -39,8 +39,8 @@ namespace { void FuseBiasWithAddAttributes(const ElementwiseAttributes& add_attr, const int channels, Tensor* bias) { - auto add = absl::get_if>(&add_attr.param); - auto add_scalar = absl::get_if(&add_attr.param); + auto add = std::get_if>(&add_attr.param); + auto add_scalar = std::get_if(&add_attr.param); if (bias->data.empty()) { *bias = MakeZeroTensor(Linear(channels)); } @@ -65,35 +65,35 @@ class MergeConvolutionWithAdd : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } ElementwiseAttributes add_attr = - absl::any_cast(add_node.operation.attributes); - if (!absl::holds_alternative>( + std::any_cast(add_node.operation.attributes); + if (!std::holds_alternative>( add_attr.param) && - !absl::holds_alternative(add_attr.param)) { + !std::holds_alternative(add_attr.param)) { return {TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar addition."}; } if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { Convolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseConvolution2DWithAdd(add_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_TRANSPOSED)) { ConvolutionTransposedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseConvolutionTransposedWithAdd(add_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::DEPTHWISE_CONVOLUTION)) { DepthwiseConvolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseDepthwiseConvolution2DWithAdd(add_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::FULLY_CONNECTED)) { FullyConnectedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseFullyConnectedWithAdd(add_attr, conv_attr); } else { @@ -112,8 +112,8 @@ class MergeConvolutionWithAdd : public SequenceTransformation { void FuseAddWithConvolution2D(const ElementwiseAttributes& add_attr, Convolution2DAttributes* attr) { - auto add = absl::get_if>(&add_attr.param); - auto add_scalar = absl::get_if(&add_attr.param); + auto add = std::get_if>(&add_attr.param); + auto add_scalar = std::get_if(&add_attr.param); if (attr->bias.data.empty()) { attr->bias = MakeZeroTensor( Linear(attr->weights.shape.o)); @@ -149,17 +149,17 @@ class MergeAddWithConvolution : public SequenceTransformation { return {TransformStatus::SKIPPED, ""}; } ElementwiseAttributes add_attr = - absl::any_cast(add_node.operation.attributes); - if (!absl::holds_alternative>( + std::any_cast(add_node.operation.attributes); + if (!std::holds_alternative>( add_attr.param) && - !absl::holds_alternative(add_attr.param)) { + !std::holds_alternative(add_attr.param)) { return {TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar addition."}; } if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { Convolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); if (conv_attr->groups != 1) { return {TransformStatus::DECLINED, @@ -191,11 +191,11 @@ class MergeAddWithConvolution : public SequenceTransformation { } // namespace std::unique_ptr NewMergeConvolutionWithAdd() { - return absl::make_unique(); + return std::make_unique(); } std::unique_ptr NewMergeAddWithConvolution() { - return absl::make_unique(); + return std::make_unique(); } void FuseConvolution2DWithAdd(const ElementwiseAttributes& add_attr, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc index ca2ec7caba7805..fc6c3e2975c98d 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_add_to_conv_test.cc @@ -224,7 +224,7 @@ TEST(MergeAddWithConvolutionTest, Smoke) { graph.nodes()[0]->operation.type); Convolution2DAttributes* conv_attr_new = - absl::any_cast( + std::any_cast( &graph.nodes()[0]->operation.attributes); EXPECT_THAT(conv_attr_new->bias.data, diff --git a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc index 507456a8fefe15..6496c77ac07163 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/fuse_mul_to_conv.cc @@ -55,10 +55,10 @@ class MergeConvolutionWithMul : public SequenceTransformation { } ElementwiseAttributes mul_attr = - absl::any_cast(mul_node.operation.attributes); - if (!absl::holds_alternative>( + std::any_cast(mul_node.operation.attributes); + if (!std::holds_alternative>( mul_attr.param) && - !absl::holds_alternative(mul_attr.param)) { + !std::holds_alternative(mul_attr.param)) { return { TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar multiplication."}; @@ -66,25 +66,25 @@ class MergeConvolutionWithMul : public SequenceTransformation { if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { Convolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseConvolution2DWithMultiply(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_TRANSPOSED)) { ConvolutionTransposedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseConvolutionTransposedWithMultiply(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::DEPTHWISE_CONVOLUTION)) { DepthwiseConvolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseDepthwiseConvolution2DWithMultiply(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::FULLY_CONNECTED)) { FullyConnectedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseFullyConnectedWithMultiply(mul_attr, conv_attr); } else { @@ -119,10 +119,10 @@ class MergeMulWithConvolution : public SequenceTransformation { } ElementwiseAttributes mul_attr = - absl::any_cast(mul_node.operation.attributes); - if (!absl::holds_alternative>( + std::any_cast(mul_node.operation.attributes); + if (!std::holds_alternative>( mul_attr.param) && - !absl::holds_alternative(mul_attr.param)) { + !std::holds_alternative(mul_attr.param)) { return { TransformStatus::DECLINED, "This fuse applicable only for broadcast or scalar multiplication."}; @@ -130,25 +130,25 @@ class MergeMulWithConvolution : public SequenceTransformation { if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_2D)) { Convolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseMultiplyWithConvolution2D(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::CONVOLUTION_TRANSPOSED)) { ConvolutionTransposedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseMultiplyWithConvolutionTransposed(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::DEPTHWISE_CONVOLUTION)) { DepthwiseConvolution2DAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseMultiplyWithDepthwiseConvolution2D(mul_attr, conv_attr); } else if (conv_node.operation.type == ToString(OperationType::FULLY_CONNECTED)) { FullyConnectedAttributes* conv_attr = - absl::any_cast( + std::any_cast( &conv_node.operation.attributes); FuseMultiplyWithFullyConnected(mul_attr, conv_attr); } else { @@ -168,17 +168,17 @@ class MergeMulWithConvolution : public SequenceTransformation { } // namespace std::unique_ptr NewMergeConvolutionWithMul() { - return absl::make_unique(); + return std::make_unique(); } std::unique_ptr NewMergeMulWithConvolution() { - return absl::make_unique(); + return std::make_unique(); } void FuseConvolution2DWithMultiply(const ElementwiseAttributes& mul_attr, Convolution2DAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int d = 0; d < attr->weights.shape.o; ++d) { const float multiplier = mul ? mul->data[d] : *mul_scalar; for (int s = 0; s < attr->weights.shape.i; ++s) { @@ -198,8 +198,8 @@ void FuseConvolution2DWithMultiply(const ElementwiseAttributes& mul_attr, void FuseDepthwiseConvolution2DWithMultiply( const ElementwiseAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int g = 0; g < attr->weights.shape.o; ++g) { for (int s = 0; s < attr->weights.shape.i; ++s) { const int d = s * attr->weights.shape.o + g; @@ -220,8 +220,8 @@ void FuseDepthwiseConvolution2DWithMultiply( void FuseConvolutionTransposedWithMultiply( const ElementwiseAttributes& mul_attr, ConvolutionTransposedAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int d = 0; d < attr->weights.shape.o; ++d) { const float multiplier = mul ? mul->data[d] : *mul_scalar; for (int s = 0; s < attr->weights.shape.i; ++s) { @@ -240,8 +240,8 @@ void FuseConvolutionTransposedWithMultiply( void FuseFullyConnectedWithMultiply(const ElementwiseAttributes& mul_attr, FullyConnectedAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int d = 0; d < attr->weights.shape.o; ++d) { const float multiplier = mul ? mul->data[d] : *mul_scalar; for (int s = 0; s < attr->weights.shape.i; ++s) { @@ -256,8 +256,8 @@ void FuseFullyConnectedWithMultiply(const ElementwiseAttributes& mul_attr, void FuseMultiplyWithConvolution2D(const ElementwiseAttributes& mul_attr, Convolution2DAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int s = 0; s < attr->weights.shape.i; ++s) { const float multiplier = mul ? mul->data[s] : *mul_scalar; for (int d = 0; d < attr->weights.shape.o; ++d) { @@ -274,8 +274,8 @@ void FuseMultiplyWithConvolution2D(const ElementwiseAttributes& mul_attr, void FuseMultiplyWithDepthwiseConvolution2D( const ElementwiseAttributes& mul_attr, DepthwiseConvolution2DAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int s = 0; s < attr->weights.shape.i; ++s) { const float multiplier = mul ? mul->data[s] : *mul_scalar; for (int g = 0; g < attr->weights.shape.o; ++g) { @@ -292,8 +292,8 @@ void FuseMultiplyWithDepthwiseConvolution2D( void FuseMultiplyWithConvolutionTransposed( const ElementwiseAttributes& mul_attr, ConvolutionTransposedAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int s = 0; s < attr->weights.shape.i; ++s) { const float multiplier = mul ? mul->data[s] : *mul_scalar; for (int d = 0; d < attr->weights.shape.o; ++d) { @@ -309,8 +309,8 @@ void FuseMultiplyWithConvolutionTransposed( void FuseMultiplyWithFullyConnected(const ElementwiseAttributes& mul_attr, FullyConnectedAttributes* attr) { - auto mul = absl::get_if>(&mul_attr.param); - auto mul_scalar = absl::get_if(&mul_attr.param); + auto mul = std::get_if>(&mul_attr.param); + auto mul_scalar = std::get_if(&mul_attr.param); for (int s = 0; s < attr->weights.shape.i; ++s) { const float multiplier = mul ? mul->data[s] : *mul_scalar; for (int d = 0; d < attr->weights.shape.o; ++d) { diff --git a/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc index 3034c91c0929d3..fc3dec545a21b3 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/global_pooling_to_reduce_op.h" +#include #include #include #include @@ -56,7 +57,7 @@ class GlobalPoolingToReduceOp : public NodeTransformation { auto inputs = graph->FindInputs(node->id); auto outputs = graph->FindOutputs(node->id); const auto& pool_attr = - absl::any_cast(node->operation.attributes); + std::any_cast(node->operation.attributes); if (!IsGlobalAveragePooling(pool_attr, inputs[0]->tensor.shape, outputs[0]->tensor.shape)) { return {TransformStatus::SKIPPED, ""}; @@ -75,7 +76,7 @@ class GlobalPoolingToReduceOp : public NodeTransformation { } // namespace std::unique_ptr NewGlobalPoolingToReduceOp() { - return absl::make_unique(); + return std::make_unique(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc index 226e7d4b2a9696..d8e7aebb2a8960 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" +#include #include #include #include @@ -56,7 +57,7 @@ class MakeFullyConnectedFromConvolution : public NodeTransformation { return {TransformStatus::SKIPPED, ""}; } - const auto& conv_attr = absl::any_cast( + const auto& conv_attr = std::any_cast( node->operation.attributes); if (!IsConvEquivalentToFullyConnected(conv_attr)) { return {TransformStatus::SKIPPED, ""}; @@ -76,7 +77,7 @@ class MakeFullyConnectedFromConvolution : public NodeTransformation { } // namespace std::unique_ptr NewMakeFullyConnectedFromConvolution() { - return absl::make_unique(); + return std::make_unique(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc index 783dcb02aa7d1a..24ae7894949cf6 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_fully_connected.h" +#include #include #include #include @@ -102,7 +103,7 @@ TEST(MakeFullyConnected, Smoke) { graph.nodes()[1]->operation.type); ASSERT_EQ(ToString(OperationType::FULLY_CONNECTED), graph.nodes()[2]->operation.type); - auto fc_attr = absl::any_cast( + auto fc_attr = std::any_cast( graph.nodes()[2]->operation.attributes); EXPECT_EQ(OHWI(32, 1, 1, 16), fc_attr.weights.shape); EXPECT_EQ(Linear(32), fc_attr.bias.shape); diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc index 6245f82289a6bb..865024002929f0 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" +#include #include #include #include @@ -37,7 +38,7 @@ bool IsConstZeros(const Node& node) { return false; } auto& attr = - absl::any_cast(node.operation.attributes); + std::any_cast(node.operation.attributes); for (auto f : attr.tensor.data) { if (f != 0) { return false; @@ -62,7 +63,7 @@ class MakePaddingFromZerosConcat : public NodeTransformation { auto dep = graph->FindProducer(input->id); if (dep != nullptr && IsConstZeros(*dep)) { auto& concat_attr = - absl::any_cast(node->operation.attributes); + std::any_cast(node->operation.attributes); PadAttributes pad_attr; pad_attr.type = PaddingContentType::ZEROS; pad_attr.appended = BHWC(0, 0, 0, 0); @@ -101,7 +102,7 @@ class MakePaddingFromZerosConcat : public NodeTransformation { } // namespace std::unique_ptr NewMakePaddingFromConcat() { - return absl::make_unique(); + return std::make_unique(); } } // namespace gpu diff --git a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc index c33960c21d0eac..abe3594d0cdbd1 100644 --- a/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc +++ b/tensorflow/lite/delegates/gpu/common/transformations/make_padding_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/make_padding.h" +#include #include #include #include @@ -71,7 +72,7 @@ TEST(MakePadding, Smoke) { ASSERT_EQ(2, graph.values().size()); auto pad_node = graph.nodes()[0]; ASSERT_EQ(ToString(OperationType::PAD), pad_node->operation.type); - auto pad_attr = absl::any_cast(pad_node->operation.attributes); + auto pad_attr = std::any_cast(pad_node->operation.attributes); EXPECT_EQ(BHWC(0, 0, 0, 0), pad_attr.prepended); EXPECT_EQ(BHWC(0, 5, 0, 0), pad_attr.appended); } diff --git a/tensorflow/lite/delegates/gpu/gl_delegate.cc b/tensorflow/lite/delegates/gpu/gl_delegate.cc index 995cbd17af470c..7703de58f51330 100644 --- a/tensorflow/lite/delegates/gpu/gl_delegate.cc +++ b/tensorflow/lite/delegates/gpu/gl_delegate.cc @@ -26,9 +26,11 @@ limitations under the License. #include #include "absl/types/span.h" -#include "tensorflow/lite/builtin_ops.h" +#include "flatbuffers/buffer.h" // from @flatbuffers +#include "flatbuffers/verifier.h" // from @flatbuffers #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/delegates/gpu/common/convert.h" +#include "tensorflow/lite/delegates/gpu/common/gpu_info.h" #include "tensorflow/lite/delegates/gpu/common/model.h" #include "tensorflow/lite/delegates/gpu/common/model_builder.h" #include "tensorflow/lite/delegates/gpu/common/model_transformer.h" @@ -38,18 +40,21 @@ limitations under the License. #include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h" #include "tensorflow/lite/delegates/gpu/gl/api.h" #include "tensorflow/lite/delegates/gpu/gl/command_queue.h" -#include "tensorflow/lite/delegates/gpu/gl/compiler.h" +#include "tensorflow/lite/delegates/gpu/gl/compiler_options.h" #include "tensorflow/lite/delegates/gpu/gl/converters/bhwc_to_phwc4.h" #include "tensorflow/lite/delegates/gpu/gl/converters/phwc4_to_bhwc.h" #include "tensorflow/lite/delegates/gpu/gl/egl_environment.h" -#include "tensorflow/lite/delegates/gpu/gl/gl_call.h" +#include "tensorflow/lite/delegates/gpu/gl/gl_buffer.h" #include "tensorflow/lite/delegates/gpu/gl/kernels/registry.h" +#include "tensorflow/lite/delegates/gpu/gl/object.h" +#include "tensorflow/lite/delegates/gpu/gl/object_manager.h" #include "tensorflow/lite/delegates/gpu/gl/request_gpu_info.h" +#include "tensorflow/lite/delegates/gpu/gl/runtime_options.h" #include "tensorflow/lite/delegates/gpu/gl/workgroups/best_effort_calculator.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #ifndef TFLITE_GPU_BINARY_RELEASE -#include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "tensorflow/lite/delegates/gpu/gl/metadata_generated.h" #include "tensorflow/lite/schema/schema_generated.h" #endif // TFLITE_GPU_BINARY_RELEASE diff --git a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc index f80a52d0896742..85541247c6d1d9 100644 --- a/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc +++ b/tensorflow/lite/delegates/utils/experimental/sample_stable_delegate/sample_stable_delegate_external_test.cc @@ -56,7 +56,7 @@ TEST(SampleStableDelegate, LoadFromSharedLibraryTestFile) { LoadDelegateFromSharedLibrary( "tensorflow/lite/delegates/utils/experimental/" "sample_stable_delegate/" - "libtensorflowlite_sample_stable_delegate_for_test.so"); + "libtensorflowlite_sample_stable_delegate.so"); ASSERT_NE(stable_delegate_handle, nullptr); EXPECT_STREQ(stable_delegate_handle->delegate_abi_version, TFL_STABLE_DELEGATE_ABI_VERSION); diff --git a/tensorflow/lite/delegates/xnnpack/BUILD b/tensorflow/lite/delegates/xnnpack/BUILD index 714ea029210c60..43ff934dbdf758 100644 --- a/tensorflow/lite/delegates/xnnpack/BUILD +++ b/tensorflow/lite/delegates/xnnpack/BUILD @@ -234,9 +234,10 @@ cc_library( }) + select({ ":xnnpack_use_transient_indirection_buffers_explicit": ["-DXNNPACK_DELEGATE_USE_TRANSIENT_INDIRECTION_BUFFERS=1"], "//conditions:default": [], - }), + }) + ["-DFLATBUFFERS_LOCALE_INDEPENDENT=0"], linkstatic = True, deps = [ + ":flexbuffers_util", ":quantization_util", ":tflite_with_xnnpack_dynamic_fully_connected", ":tflite_with_xnnpack_logging", @@ -260,6 +261,7 @@ cc_library( "//tensorflow/lite/tools/optimize:reduced_precision_support", "@XNNPACK", "@eigen_archive//:eigen3", + "@flatbuffers//:runtime_cc", "@pthreadpool", ], ) @@ -278,9 +280,10 @@ cc_library( name = "xnnpack_delegate_test_mode", srcs = ["xnnpack_delegate.cc"], hdrs = ["xnnpack_delegate.h"], - copts = tflite_copts() + ["-DXNNPACK_DELEGATE_TEST_MODE=1"], + copts = tflite_copts() + ["-DXNNPACK_DELEGATE_TEST_MODE=1"] + ["-DFLATBUFFERS_LOCALE_INDEPENDENT=0"], linkstatic = True, deps = [ + ":flexbuffers_util", ":quantization_util", ":weight_cache", "//tensorflow/lite:kernel_api", @@ -299,6 +302,7 @@ cc_library( "//tensorflow/lite/tools/optimize:reduced_precision_support", "@XNNPACK", "@eigen_archive//:eigen3", + "@flatbuffers//:runtime_cc", "@pthreadpool", ], ) @@ -341,6 +345,15 @@ cc_library( ], ) +cc_library( + name = "flexbuffers_util", + hdrs = ["flexbuffers_util.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@flatbuffers//:runtime_cc", + ], +) + ################################ Tester classes ################################ cc_library( @@ -1686,9 +1699,9 @@ cc_test( name = "odml_sdpa_test", srcs = ["odml_sdpa_test.cc"], data = [ - ":odml_sdpa_composite_gqa.tflite", - ":odml_sdpa_composite_mha.tflite", - ":odml_sdpa_composite_mqa.tflite", + ":odml_sdpa_composite_gqa.tflite.bin", + ":odml_sdpa_composite_mha.tflite.bin", + ":odml_sdpa_composite_mqa.tflite.bin", ], linkopts = select({ "//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS, @@ -2937,4 +2950,14 @@ cc_test( ], ) +cc_test( + name = "flexbuffers_util_test", + srcs = ["flexbuffers_util_test.cc"], + deps = [ + ":flexbuffers_util", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:runtime_cc", + ], +) + tflite_portable_test_suite_combined(combine_conditions = {"deps": [":test_main"]}) diff --git a/tensorflow/lite/delegates/xnnpack/concatenation_tester.cc b/tensorflow/lite/delegates/xnnpack/concatenation_tester.cc index 4383ab6ee5e367..d3bf54c145f975 100644 --- a/tensorflow/lite/delegates/xnnpack/concatenation_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/concatenation_tester.cc @@ -166,6 +166,7 @@ std::vector ConcatenationTester::CreateTfLiteModel( }}; std::vector> tensors; + tensors.reserve(NumInputs()); for (size_t i = 0; i < NumInputs(); i++) { tensors.push_back(CreateTensor( builder, @@ -190,6 +191,7 @@ std::vector ConcatenationTester::CreateTfLiteModel( builder.CreateVector({output_zero_point_})))); std::vector op_inputs; + op_inputs.reserve(NumInputs()); for (size_t i = 0; i < NumInputs(); i++) { op_inputs.push_back(static_cast(i)); } diff --git a/tensorflow/lite/delegates/xnnpack/flexbuffers_util.h b/tensorflow/lite/delegates/xnnpack/flexbuffers_util.h new file mode 100644 index 00000000000000..6f303c8a92a2da --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/flexbuffers_util.h @@ -0,0 +1,59 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_FLEXBUFFERS_UTIL_H_ +#define TENSORFLOW_LITE_DELEGATES_XNNPACK_FLEXBUFFERS_UTIL_H_ + +#include "flatbuffers/base.h" // from @flatbuffers +#include "flatbuffers/flexbuffers.h" // from @flatbuffers + +namespace tflite::xnnpack { +// We use this class defined with internal linkage as a key to prevent the +// following workaround to leak into other translation units. +struct FloatPointer { + const float* ptr = nullptr; +}; +} // namespace tflite::xnnpack + +namespace flexbuffers { + +// TODO(b/359351192): switch to xnnpack builtin. This is a workaround until we +// are able to use just the value. +// +// We go around the access policy of the `Reference` class by specializing a +// template function that was not specialized for our use case. +// +// This is weakly tolerant to an update to the `Reference` class because: +// - THIS IS MEANT TO BE TEMPORARY until we actually use the XNNPack +// implementation of SDPA (and dependent on not needing data ptr). +// - The flexbuffer spec is public and set, so the layout should not evolve +// much. +// +// The alternative was to copy/paste the code to get to the map data and grab +// the pointer which basically means rewriting flexbuffer.h. +template <> +tflite::xnnpack::FloatPointer inline flexbuffers::Reference::As< + tflite::xnnpack::FloatPointer>() const { +#if !FLATBUFFERS_LITTLEENDIAN + // Flexbuffers are always stored in little endian order. Returning a pointer + // to the float data on a big endian architecture is meaningless. + return nullptr; +#else + return {IsFloat() ? reinterpret_cast(data_) : nullptr}; +#endif +} + +} // namespace flexbuffers + +#endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_FLEXBUFFERS_UTIL_H_ diff --git a/tensorflow/lite/delegates/xnnpack/flexbuffers_util_test.cc b/tensorflow/lite/delegates/xnnpack/flexbuffers_util_test.cc new file mode 100644 index 00000000000000..d3e112bea1547c --- /dev/null +++ b/tensorflow/lite/delegates/xnnpack/flexbuffers_util_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/lite/delegates/xnnpack/flexbuffers_util.h" + +#include +#include +#include "flatbuffers/flexbuffers.h" // from @flatbuffers + +namespace tflite::xnnpack { +namespace { + +using ::testing::Pointee; + +TEST(FlexbuffersUtilTest, FloatPointer) { + constexpr float kAValue = 3.14; + constexpr float kBValue = 56; + + flexbuffers::Builder fbb; + fbb.Map([&] { + fbb.Float("a", kAValue); + fbb.Float("b", kBValue); + }); + fbb.Finish(); + + const flexbuffers::Map map = flexbuffers::GetRoot(fbb.GetBuffer()).AsMap(); + + const flexbuffers::Reference a = map["a"]; + EXPECT_TRUE(a.IsFloat()); + EXPECT_THAT(a.As().ptr, Pointee(kAValue)); + + const flexbuffers::Reference b = map["b"]; + EXPECT_TRUE(b.IsFloat()); + EXPECT_THAT(b.As().ptr, Pointee(kBValue)); + + const flexbuffers::Reference c = map["c"]; + ASSERT_TRUE(c.IsNull()); + EXPECT_EQ(c.As().ptr, nullptr); +} + +} // namespace +} // namespace tflite::xnnpack diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_gqa.tflite.bin b/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_gqa.tflite.bin new file mode 100644 index 00000000000000..5dd33bc206a287 Binary files /dev/null and b/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_gqa.tflite.bin differ diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_mha.tflite.bin b/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_mha.tflite.bin new file mode 100644 index 00000000000000..eb03e9e424cbe6 Binary files /dev/null and b/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_mha.tflite.bin differ diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_mqa.tflite.bin b/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_mqa.tflite.bin new file mode 100644 index 00000000000000..4c12a0d7ce0230 Binary files /dev/null and b/tensorflow/lite/delegates/xnnpack/odml_sdpa_composite_mqa.tflite.bin differ diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc b/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc index bf54f45cf04233..0a2c6d85cfa02d 100644 --- a/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc +++ b/tensorflow/lite/delegates/xnnpack/odml_sdpa_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include @@ -35,6 +36,17 @@ struct SDPATestParams { int head_dim; // embedding_dim//q_heads }; +void PrintTo(const SDPATestParams& p, std::ostream* os) { + if (p.model_name != kOdmlSdpaCustom) { + *os << "{ TFLite file: " << p.model_name << ".tflite.bin }"; + } else { + *os << "{ Custom test: " << p.custom_test_name << ", b:" << p.batch + << ", isl:" << p.input_seq_len << ", msl:" << p.max_seq_len + << ", q:" << p.q_heads << ", k:" << p.kv_heads << "h:" << p.head_dim + << " }"; + } +} + std::string TestName(const testing::TestParamInfo& info) { if (info.param.model_name != kOdmlSdpaCustom) { return info.param.model_name; diff --git a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc index c7714a816bf953..0af79ba33cb2ab 100644 --- a/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc +++ b/tensorflow/lite/delegates/xnnpack/odml_sdpa_tester.cc @@ -119,8 +119,9 @@ void ODMLSDPATester::Test(TfLiteDelegate* delegate) const { std::vector ODMLSDPATester::CreateTfLiteModel() const { if (!model_name_.empty() && model_name_ != kOdmlSdpaCustom) { const char kTestModelFolder[] = - "third_party/tensorflow/lite/delegates/xnnpack/"; - const std::string test_model = kTestModelFolder + model_name_ + ".tflite"; + "tensorflow/lite/delegates/xnnpack/"; + const std::string test_model = + kTestModelFolder + model_name_ + ".tflite.bin"; std::string model_data; if (!flatbuffers::LoadFile(test_model.c_str(), /*binary=*/true, &model_data)) { diff --git a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc index db5cf51f6845ee..65b3475b75552a 100644 --- a/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc +++ b/tensorflow/lite/delegates/xnnpack/xnnpack_delegate.cc @@ -34,6 +34,7 @@ limitations under the License. #include "xnnpack.h" // from @XNNPACK #include "Eigen/Core" // from @eigen_archive +#include "flatbuffers/flexbuffers.h" // from @flatbuffers #include "pthreadpool.h" // from @pthreadpool #include "tensorflow/lite/builtin_ops.h" #include "tensorflow/lite/c/c_api_types.h" @@ -41,6 +42,7 @@ limitations under the License. #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/subgraph.h" +#include "tensorflow/lite/delegates/xnnpack/flexbuffers_util.h" #include "tensorflow/lite/delegates/xnnpack/quantization_util.h" #include "tensorflow/lite/delegates/xnnpack/weight_cache.h" #include "tensorflow/lite/kernels/cpu_backend_context.h" @@ -6701,22 +6703,14 @@ class Subgraph { const TfLiteTensor* tensors, const uint8_t* buffer, const size_t buffer_size, const std::unordered_map& input_output_tensors) { - const float* scale_val = nullptr; - // ensure 28 bytes as we expect - // TODO(b/339106680): this reading method may not work for every case. - if (buffer_size == 28 && sizeof(float) == 4) { - // Custom data here is a flexbuffer map. - // byte_width is 4 for our map. - // First 5 values are "scale", then is the float value, and last is - // flexbuffer metadata. - if (strcmp("scale", reinterpret_cast(buffer)) == 0) { - constexpr size_t kScaleValOffset = 20; - scale_val = reinterpret_cast(buffer + kScaleValOffset); - } - } - + flexbuffers::Map flexbuffer_map = + flexbuffers::GetRoot(buffer, buffer_size).AsMap(); + const float* const scale_ptr = + flexbuffer_map["scale"].As().ptr; + const float* const cap_ptr = + flexbuffer_map["logit_cap"].As().ptr; return VisitDotAttentionNode(subgraph, delegate, logging_context, - node_index, node, tensors, scale_val, + node_index, node, tensors, scale_ptr, cap_ptr, input_output_tensors); } @@ -6724,6 +6718,7 @@ class Subgraph { xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* logging_context, int node_index, TfLiteNode* node, const TfLiteTensor* tensors, const float* scale_param, + const float* cap_param, const std::unordered_map& input_output_tensors) { const TfLiteTensor& query_proj = tensors[node->inputs->data[0]]; TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type( @@ -6946,7 +6941,45 @@ class Subgraph { permute_q_out_id, reshape_dims_k_out_id, XNN_INVALID_VALUE_ID, fc_out_id, /*flags=*/0)); } - // TODO(b/323195341): add CapTanh support. + if (cap_param != nullptr) { + uint32_t cap_val_id = XNN_INVALID_VALUE_ID; + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, + /*dims=*/nullptr, cap_param, + XNN_INVALID_VALUE_ID, 0, &cap_val_id)); + uint32_t cap_div_out_id = XNN_INVALID_VALUE_ID; + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, + /*dims=*/nullptr, nullptr, + XNN_INVALID_VALUE_ID, 0, &cap_div_out_id)); + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_divide(subgraph, default_out_min, default_out_max, + fc_out_id, cap_val_id, cap_div_out_id, + /*flags=*/0)); + uint32_t cap_tanh_out_id = XNN_INVALID_VALUE_ID; + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, + /*dims=*/nullptr, nullptr, + XNN_INVALID_VALUE_ID, 0, &cap_tanh_out_id)); + TF_LITE_ENSURE_EQ(logging_context, xnn_status_success, + xnn_define_tanh(subgraph, cap_div_out_id, + cap_tanh_out_id, /*flags=*/0)); + uint32_t cap_logits_id = XNN_INVALID_VALUE_ID; + TF_LITE_ENSURE_EQ( + logging_context, xnn_status_success, + xnn_define_tensor_value(subgraph, xnn_datatype_fp32, /*num_dims=*/0, + /*dims=*/nullptr, nullptr, + XNN_INVALID_VALUE_ID, 0, &cap_logits_id)); + TF_LITE_ENSURE_EQ(logging_context, xnn_status_success, + xnn_define_multiply2(subgraph, default_out_min, + default_out_max, cap_tanh_out_id, + cap_val_id, cap_logits_id, 0)); + fc_out_id = cap_logits_id; + } // element_add atten_mask and matmul_out uint32_t padded_logits_id = XNN_INVALID_VALUE_ID; TF_LITE_ENSURE_EQ( diff --git a/tensorflow/lite/examples/label_image/CMakeLists.txt b/tensorflow/lite/examples/label_image/CMakeLists.txt index 9874801f34fa31..2fcb09ce96e990 100644 --- a/tensorflow/lite/examples/label_image/CMakeLists.txt +++ b/tensorflow/lite/examples/label_image/CMakeLists.txt @@ -61,6 +61,11 @@ if(TFLITE_ENABLE_EXTERNAL_DELEGATE) ${TFLITE_SOURCE_DIR}/tools/delegates/external_delegate_provider.cc) endif() +include_directories(label_image + PUBLIC + ${CMAKE_BINARY_DIR} +) + add_executable(label_image ${TFLITE_LABEL_IMAGE_SRCS} ) @@ -78,4 +83,6 @@ target_compile_options(label_image ) target_link_libraries(label_image tensorflow-lite + profiling_info_proto + protobuf ) diff --git a/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb b/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb index c3d7257aca3a2d..6ec3922157583f 100644 --- a/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb +++ b/tensorflow/lite/g3doc/examples/on_device_training/overview.ipynb @@ -68,7 +68,7 @@ "id": "9ee074e4" }, "source": [ - "When deploying TensorFlow Lite machine learning model to device or mobile app, you may want to enable the model to be improved or personalized based on input from the device or end user. Using on-device training techniques allows you to update a model *without* data leaving your users' devices, improving user privacy, and without requiring users to update the device software.\n", + "When deploying a TensorFlow Lite machine learning model on an edge device or mobile app, you may want to enable the model to be improved or personalized based on input from the device or end user. Using on-device training techniques allows you to update a model *without* data leaving your users' devices, improving user privacy, and without requiring users to update the device software.\n", "\n", "For example, you may have a model in your mobile app that recognizes fashion items, but you want users to get improved recognition performance over time based on their interests. Enabling on-device training allows users who are interested in shoes to get better at recognizing a particular style of shoe or shoe brand the more often they use your app.\n", "\n", @@ -315,7 +315,7 @@ "id": "79f5f372fb0e" }, "source": [ - "Note: Make sure you preprocess your *training* and *testing* datasets in the same way, so that your testing accurately evaluate your model's performance." + "Note: Make sure to preprocess your *training* and *testing* datasets in the same way so that your testing can accurately evaluate your model's performance." ] }, { @@ -428,7 +428,7 @@ "id": "LaMMDLLewAaX" }, "source": [ - "Note: You should complete initial training of your model before converting it to TensorFlow Lite format, so that the model has an initial set of weights, and is able to perform reasonable inferences *before* you start collecting data and conducting training runs on the device." + "Note: You should complete the initial training of your model before converting it to TensorFlow Lite format, so that the model has an initial set of weights, and is able to perform reasonable inferences *before* you start collecting data and conducting training runs on the device." ] }, { @@ -439,7 +439,7 @@ "source": [ "## Convert model to TensorFlow Lite format\n", "\n", - "After you have extended your TensorFlow model to enable additional functions for on-device training and completed initial training of the model, you can convert it to TensorFlow Lite format. The following code converts and saves your model to that format, including the set of signatures that you use with the TensorFlow Lite model on a device: `train, infer, save, restore`." + "After you have extended your TensorFlow model to enable additional functions for on-device training and completed the initial training of the model, you can convert it to TensorFlow Lite format. The following code converts and saves your model to that format, including the set of signatures that you use with the TensorFlow Lite model on a device: `train, infer, save, restore`." ] }, { @@ -1062,7 +1062,7 @@ "source": [ "Congratulations! You now have built a TensorFlow Lite model that supports on-device training. For more coding details, check out the example implementation in the [model personalization demo app](https://github.com/tensorflow/examples/tree/master/lite/examples/model_personalization).\n", "\n", - "If you are interested in learning more about image classification, check [Keras classification tutorial](https://www.tensorflow.org/tutorials/keras/classification) in the TensorFlow official guide page. This tutorial is based on that exercise and provides more depth on the subject of classification.\n" + "If you are interested in learning more about image classification, check [Keras classification tutorial](https://www.tensorflow.org/tutorials/keras/classification) on the TensorFlow official guide page. This tutorial is based on that exercise and provides more depth on the subject of classification.\n" ] } ], diff --git a/tensorflow/lite/java/aar_with_jni.bzl b/tensorflow/lite/java/aar_with_jni.bzl index 808183ad93b16b..f2770119daaadf 100644 --- a/tensorflow/lite/java/aar_with_jni.bzl +++ b/tensorflow/lite/java/aar_with_jni.bzl @@ -6,7 +6,8 @@ def aar_with_jni( name, android_library, headers = None, - flatten_headers = False): + flatten_headers = False, + strip_headers_prefix = ""): """Generates an Android AAR with repo root license given an Android library target. Args: @@ -18,6 +19,7 @@ def aar_with_jni( generated .aar file. This is useful for distributing self-contained .aars with native libs that can be used directly by native clients. flatten_headers: Whether to flatten the output paths of included headers. + strip_headers_prefix: The prefix to strip from the output paths of included headers. """ # Generate dummy AndroidManifest.xml for dummy apk usage @@ -83,9 +85,14 @@ zip $$origdir/$(location :{1}.aar) LICENSE """.format(src) else: cmd += """ - mkdir -p headers/$$(dirname $(location {0})) - cp -RL $$origdir/$(location {0}) headers/$(location {0}) - """.format(src) + default_dir=$$(dirname $(rootpath {0})) + modified_dir=$$(echo $$default_dir | sed -e 's/^{1}//g') + mkdir -p headers/$$modified_dir + cp -RL $$origdir/$(location {0}) headers/$$modified_dir + if [ -n "{1}" ]; then + sed -i -e 's/^#include \"{1}/#include \"/g' headers/$$modified_dir/$$(basename $(location {0})) + fi + """.format(src, strip_headers_prefix.replace("/", "\\/")) cmd += "zip -r $$origdir/$(location :{0}.aar) headers".format(name) native.genrule( diff --git a/tensorflow/lite/java/src/main/native/BUILD b/tensorflow/lite/java/src/main/native/BUILD index 8ce5e0c075b86e..b6c518d0f1ff1d 100644 --- a/tensorflow/lite/java/src/main/native/BUILD +++ b/tensorflow/lite/java/src/main/native/BUILD @@ -22,6 +22,7 @@ cc_library_with_tflite( visibility = jni_utils_visibility_allowlist(), deps = [ "//tensorflow/lite:error_reporter", + "//tensorflow/lite/core/c:common", "//tensorflow/lite/java/jni", ], ) diff --git a/tensorflow/lite/java/src/main/native/jni_utils.h b/tensorflow/lite/java/src/main/native/jni_utils.h index 1602d77d95a4e7..1796a388ec01fc 100644 --- a/tensorflow/lite/java/src/main/native/jni_utils.h +++ b/tensorflow/lite/java/src/main/native/jni_utils.h @@ -21,6 +21,7 @@ limitations under the License. #include +#include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/error_reporter.h" namespace tflite { diff --git a/tensorflow/lite/kernels/CMakeLists.txt b/tensorflow/lite/kernels/CMakeLists.txt index ae2523e738bc1c..946b56353e6c15 100644 --- a/tensorflow/lite/kernels/CMakeLists.txt +++ b/tensorflow/lite/kernels/CMakeLists.txt @@ -90,6 +90,8 @@ set(TEST_FRAMEWORK_SRC ${DELEGATE_PROVIDERS} ${TFLITE_SOURCE_DIR}/tools/optimize/model_utils.cc ${TF_SOURCE_DIR}/compiler/mlir/lite/tools/optimize/operator_property.cc + ${TF_SOURCE_DIR}/compiler/mlir/tools/optimize/quantization_utils.cc + ${TF_SOURCE_DIR}/compiler/mlir/lite/kernels/internal/runtime_shape.cc ${TFLITE_SOURCE_DIR}/tools/optimize/quantization_utils.cc ${TFLITE_SOURCE_DIR}/tools/tool_params.cc ${TFLITE_SOURCE_DIR}/tools/versioning/op_version.cc diff --git a/tensorflow/lite/kernels/batch_matmul.cc b/tensorflow/lite/kernels/batch_matmul.cc index 6c72d9003ea76c..d1eb3130c78c04 100644 --- a/tensorflow/lite/kernels/batch_matmul.cc +++ b/tensorflow/lite/kernels/batch_matmul.cc @@ -27,7 +27,6 @@ limitations under the License. #include "tensorflow/lite/kernels/internal/compatibility.h" #include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h" #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" -#include "tensorflow/lite/kernels/internal/reference/reference_ops.h" #include "tensorflow/lite/kernels/internal/tensor.h" #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/internal/tensor_utils.h" @@ -441,7 +440,6 @@ RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) { return swapped_shape; } -template TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, const RuntimeShape& input_shape, const TfLiteTensor* input, @@ -494,18 +492,10 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, output_size *= output_shape.Dims(i); } std::fill_n(GetTensorData(output), output_size, 0.0f); - if (kernel_type == kGenericOptimized) { - optimized_ops::BatchMatMul( - filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, - input_offset_ptr, row_sums_ptr, GetTensorShape(output), - GetTensorData(accum_scratch), GetTensorData(output), - &(data->compute_row_sums), CpuBackendContext::GetFromContext(context)); - } else { - reference_ops::BatchMatMul( - filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, - input_offset_ptr, row_sums_ptr, GetTensorShape(output), - GetTensorData(output), &(data->compute_row_sums)); - } + reference_ops::BatchMatMul( + filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, + input_offset_ptr, row_sums_ptr, GetTensorShape(output), + GetTensorData(output), &(data->compute_row_sums)); return kTfLiteOk; } @@ -638,9 +628,9 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* row_sums; TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/6, &row_sums)); - return EvalHybrid( - context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized, - scaling_factors, accum_scratch, row_sums, input_offsets, output); + return EvalHybrid(context, node, data, lhs_shape, lhs, rhs_shape, rhs, + input_quantized, scaling_factors, accum_scratch, row_sums, + input_offsets, output); } else if (lhs->type == kTfLiteInt8 && rhs->type == kTfLiteInt8) { if (output->type == kTfLiteInt8) { return EvalInt8Int8(context, data, lhs_shape, lhs, rhs_shape, diff --git a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc index d08592faec5856..40f3b812825497 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_lstm.cc @@ -126,7 +126,7 @@ constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional constexpr int kFwOutputTensor = 0; constexpr int kBwOutputTensor = 1; // Ignored if merge_outputs is set. -// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc) +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc) // Temporary tensors. enum TemporaryTensor { diff --git a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc index 4813b7c84204e9..e58c1471457318 100644 --- a/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/lite/kernels/bidirectional_sequence_rnn.cc @@ -61,7 +61,7 @@ constexpr int kBwAuxWeightsTensor = 11; // Optional. constexpr int kFwOutputTensor = 0; constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false. -// LINT.ThenChange(//tensorflow/lite/tools/optimize/quantize_weights.cc) +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc) // Temporary tensors. enum TemporaryTensor { diff --git a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc index 1abad4b50a2b89..05f242397f44d9 100644 --- a/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc +++ b/tensorflow/lite/kernels/ctc/ctc_beam_search_decoder.cc @@ -203,6 +203,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // tensorflow/core/kernels/ctc_decoder_ops.cc std::vector::UnalignedConstMatrix> input_list_t; + input_list_t.reserve(max_time); for (std::size_t t = 0; t < max_time; ++t) { input_list_t.emplace_back( GetTensorData(inputs) + t * batch_size * num_classes, batch_size, diff --git a/tensorflow/lite/kernels/embedding_lookup.cc b/tensorflow/lite/kernels/embedding_lookup.cc index 4190fd7121c30f..d92701059822f6 100644 --- a/tensorflow/lite/kernels/embedding_lookup.cc +++ b/tensorflow/lite/kernels/embedding_lookup.cc @@ -104,13 +104,13 @@ TfLiteStatus EvalSimple(TfLiteContext* context, TfLiteNode* node, // Propagate empty tensor if input is empty return kTfLiteOk; } - const int row_bytes = value->bytes / row_size; + const int64_t row_bytes = value->bytes / row_size; char* output_raw = GetTensorData(output); const char* value_raw = GetTensorData(value); const int32_t* lookup_data = GetTensorData(lookup); for (int i = 0; i < SizeOfDimension(lookup, 0); i++) { - int idx = lookup_data[i]; + int64_t idx = lookup_data[i]; if (idx >= row_size || idx < 0) { TF_LITE_KERNEL_LOG(context, "Embedding Lookup: index out of bounds. " diff --git a/tensorflow/lite/kernels/embedding_lookup_test.cc b/tensorflow/lite/kernels/embedding_lookup_test.cc index d13ddd443f6891..493d086aa50804 100644 --- a/tensorflow/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/lite/kernels/embedding_lookup_test.cc @@ -92,6 +92,19 @@ class EmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { } } } + + template + void Set2DWeightMatrix(const std::function& function) { + TfLiteTensor* tensor = interpreter_->tensor(weight_); + int64_t rows = tensor->dims->data[0]; + int64_t columns = tensor->dims->data[1]; + T* data = GetTensorData(tensor); + for (int64_t i = 0; i < rows; i++) { + for (int64_t j = 0; j < columns; j++) { + data[i * columns + j] = function(i, j); + } + } + } }; class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel { @@ -144,6 +157,28 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { }))); } +#if !defined(MEMORY_SANITIZER) && !defined(GOOGLE_UNSUPPORTED_OS_LOONIX) && \ + defined(__LP64__) +TEST(EmbeddingLookupOpTest, LargeTableTest) { + EmbeddingLookupOpModel m({1}, {256000, 9216}); + // Choose a value specifically designed to overflow int32.max + m.SetInput({235248}); + m.Set2DWeightMatrix( + [](int i, int j) -> float { return j + i / 100.; }); + + // This will cause a lookup at index 235248 in a buffer where every row + // has 9216 entries * 4 bytes per entry, which will overflow unless + // the Op is using a 64-bit offset for address calculation. + ASSERT_EQ(m.Invoke(), kTfLiteOk); + std::vector exp(9216); + + for (int s = 0; s < exp.size(); s++) { + exp[s] = static_cast(s) + 2352.48f; + } + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(exp))); +} +#endif + TEST(HybridEmbeddingLookupHybridOpTest, Simple2DTestUint8) { HybridEmbeddingLookupOpModel m({3}, {3, 8}, TensorType_UINT8); m.SetInput({1, 0, 2}); diff --git a/tensorflow/lite/kernels/internal/averagepool_quantized_test.cc b/tensorflow/lite/kernels/internal/averagepool_quantized_test.cc index fea343ae6b8824..5173586d423ab5 100644 --- a/tensorflow/lite/kernels/internal/averagepool_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/averagepool_quantized_test.cc @@ -34,11 +34,11 @@ namespace { // are the same. void RunOneAveragePoolTest(const PoolParams& params, const RuntimeShape& input_shape, - const int8* input_data, + const int8_t* input_data, const RuntimeShape& output_shape) { const int buffer_size = output_shape.FlatSize(); - std::vector optimized_averagePool_output(buffer_size); - std::vector reference_averagePool_output(buffer_size); + std::vector optimized_averagePool_output(buffer_size); + std::vector reference_averagePool_output(buffer_size); bool reference_success = reference_integer_ops::AveragePool( params, input_shape, input_data, output_shape, @@ -86,7 +86,7 @@ void CreateDataAndRunAveragePool(bool padding_same) { auto output_shape = RuntimeShape({batch, output_height, output_width, output_depth}); const int buffer_size = input_shape.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandom(&input_data); PoolParams params; @@ -172,17 +172,17 @@ void CreateExtremalDataAndRunAveragePool(bool padding_same) { filter_height, output_height); const int buffer_size = input_shape.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); // Test small values - int8 min = std::numeric_limits::min(); - int8 max = std::numeric_limits::min() + 10; + int8_t min = std::numeric_limits::min(); + int8_t max = std::numeric_limits::min() + 10; FillRandom(&input_data, min, max); RunOneAveragePoolTest(params, input_shape, input_data.data(), output_shape); // Test large values - min = std::numeric_limits::max() - 10; - max = std::numeric_limits::max(); + min = std::numeric_limits::max() - 10; + max = std::numeric_limits::max(); FillRandom(&input_data, min, max); RunOneAveragePoolTest(params, input_shape, input_data.data(), output_shape); } diff --git a/tensorflow/lite/kernels/internal/conv_per_channel_quantized_16x8_test.cc b/tensorflow/lite/kernels/internal/conv_per_channel_quantized_16x8_test.cc index 562797bfffeb0e..f0ad42b2cd100f 100644 --- a/tensorflow/lite/kernels/internal/conv_per_channel_quantized_16x8_test.cc +++ b/tensorflow/lite/kernels/internal/conv_per_channel_quantized_16x8_test.cc @@ -38,8 +38,8 @@ namespace { void PickOutputMultiplier( const ConvParams& params, const RuntimeShape& input_shape, - const int16* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, + const int16_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, const std::int64_t* bias_data, const RuntimeShape& output_shape, float* output_multiplier) { const int stride_width = params.stride_width; @@ -81,9 +81,9 @@ void PickOutputMultiplier( (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height); if (is_point_inside_image) { - int32 input_val = input_data[Offset(input_shape, batch, in_y, - in_x, in_channel)]; - int32 filter_val = + int32_t input_val = input_data[Offset( + input_shape, batch, in_y, in_x, in_channel)]; + int32_t filter_val = filter_data[Offset(filter_shape, output_channel, filter_y, filter_x, in_channel)]; acc += static_cast(filter_val) * @@ -296,8 +296,8 @@ void TryTestOneConvFilter(int test_num) { for (int c = 0; c < output_shape_inference.Dims(3); c++) { int offset = Offset(output_shape_inference, n, h, w, c); float float_res = output_data_float.data()[offset]; - int16 int16_res = reference_output_data.data()[offset]; - int32 output_mul = output_multiplier.data()[c]; + int16_t int16_res = reference_output_data.data()[offset]; + int32_t output_mul = output_multiplier.data()[c]; int shift = output_shift.data()[c]; float scale = (float)output_mul / (float)(1ULL << 31); if (shift > 0) scale = scale * (float)(1 << shift); diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_16x8_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_16x8_test.cc index 7d586c5ac94430..f395cdd13ff18b 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_16x8_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_16x8_test.cc @@ -38,8 +38,8 @@ namespace { void PickOutputMultiplier( const DepthwiseParams& params, const RuntimeShape& input_shape, - const int16* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, + const int16_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, const std::int64_t* bias_data, const RuntimeShape& output_shape, float* output_multiplier) { const int stride_width = params.stride_width; @@ -81,9 +81,9 @@ void PickOutputMultiplier( (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height); if (is_point_inside_image) { - int32 input_val = input_data[Offset(input_shape, batch, in_y, - in_x, in_channel)]; - int32 filter_val = filter_data[Offset( + int32_t input_val = input_data[Offset( + input_shape, batch, in_y, in_x, in_channel)]; + int32_t filter_val = filter_data[Offset( filter_shape, 0, filter_y, filter_x, output_channel)]; acc += static_cast(filter_val) * static_cast(input_val); @@ -286,8 +286,8 @@ void TryTestOneDepthwiseConv3x3Filter() { for (int c = 0; c < output_shape_inference.Dims(3); c++) { int offset = Offset(output_shape_inference, n, h, w, c); float float_res = output_data_float.data()[offset]; - int16 int16_res = reference_output_data.data()[offset]; - int32 output_mul = output_multiplier.data()[c]; + int16_t int16_res = reference_output_data.data()[offset]; + int32_t output_mul = output_multiplier.data()[c]; int shift = output_shift.data()[c]; float scale = (float)output_mul / (float)(1ULL << 31); if (shift > 0) scale = scale * (float)(1 << shift); diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc index 8336b63b0ba48e..716b0fce731298 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_per_channel_quantized_test.cc @@ -39,9 +39,9 @@ namespace { void PickOutputMultiplier( const DepthwiseParams& params, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, float* output_multiplier) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; @@ -50,7 +50,7 @@ void PickOutputMultiplier( const int pad_width = params.padding_values.width; const int pad_height = params.padding_values.height; const int depth_multiplier = params.depth_multiplier; - const int32 input_offset = params.input_offset; + const int32_t input_offset = params.input_offset; const int batches = MatchingDim(input_shape, 0, output_shape, 0); const int input_height = input_shape.Dims(1); @@ -72,7 +72,7 @@ void PickOutputMultiplier( const int output_channel = m + in_channel * depth_multiplier; const int in_x_origin = (out_x * stride_width) - pad_width; const int in_y_origin = (out_y * stride_height) - pad_height; - int32 acc = 0; + int32_t acc = 0; for (int filter_y = 0; filter_y < filter_height; ++filter_y) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { const int in_x = in_x_origin + dilation_width_factor * filter_x; @@ -83,9 +83,9 @@ void PickOutputMultiplier( (in_x >= 0) && (in_x < input_width) && (in_y >= 0) && (in_y < input_height); if (is_point_inside_image) { - int32 input_val = input_data[Offset(input_shape, batch, in_y, - in_x, in_channel)]; - int32 filter_val = filter_data[Offset( + int32_t input_val = input_data[Offset( + input_shape, batch, in_y, in_x, in_channel)]; + int32_t filter_val = filter_data[Offset( filter_shape, 0, filter_y, filter_x, output_channel)]; acc += filter_val * (input_val + input_offset); } diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc index c9d301ab9564c3..d5a2da2b9d58f8 100644 --- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc @@ -124,7 +124,7 @@ inline void DispatchDepthwiseConvGeneral( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const std::int32_t* output_shift_adjust, const std::int32_t* output_multiplier_adjust, const RuntimeShape& output_shape, @@ -139,11 +139,11 @@ inline void DispatchDepthwiseConvGeneral( template <> inline void DispatchDepthwiseConvGeneral( const DepthwiseParams& params, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const std::int32_t* output_shift_adjust, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const std::int32_t* output_shift_adjust, const std::int32_t* output_multiplier_adjust, - const RuntimeShape& output_shape, int8* output_data, int thread_start, + const RuntimeShape& output_shape, int8_t* output_data, int thread_start, int thread_end, int thread_dim) { optimized_integer_ops::depthwise_conv::DepthwiseConvGeneral( params, output_multiplier_adjust, output_shift_adjust, input_shape, @@ -160,7 +160,7 @@ inline void DispatchDepthwiseConvImpl( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, typename QuantizationTypeImpl::ExternalType* output_data) { @@ -349,7 +349,7 @@ inline void DispatchDepthwiseConvImpl( CpuBackendContext backend_context; backend_context.SetMaxNumThreads(test_param.num_threads); optimized_ops::DepthwiseConv< - typename QuantizationTypeImpl::ExternalType, int32>( + typename QuantizationTypeImpl::ExternalType, int32_t>( params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data, &backend_context); } @@ -363,7 +363,7 @@ inline void DispatchDepthwiseConvImpl( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl< QuantizationType::kPerChannelInt8>::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, typename QuantizationTypeImpl< QuantizationType::kPerChannelInt8>::ExternalType* output_data) { @@ -530,7 +530,7 @@ inline void DispatchDepthwiseConv( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, typename QuantizationTypeImpl::ExternalType* output_data) { @@ -546,10 +546,10 @@ template <> struct ReferenceRunner { static inline void Run( const TestParam& test_param, const tflite::DepthwiseParams& op_params, - const uint8* input_data, const RuntimeShape& input_shape, - const uint8* filter_data, const RuntimeShape& filter_shape, + const uint8_t* input_data, const RuntimeShape& input_shape, + const uint8_t* filter_data, const RuntimeShape& filter_shape, const std::int32_t* bias_data, const RuntimeShape& bias_shape, - const RuntimeShape& output_shape, uint8* reference_output_data) { + const RuntimeShape& output_shape, uint8_t* reference_output_data) { switch (test_param.output_rounding) { case DepthwiseConvOutputRounding::kUpward: reference_ops::depthwise_conv::DepthwiseConvBasicKernel< @@ -577,10 +577,10 @@ template <> struct ReferenceRunner { static inline void Run( const TestParam& test_param, const tflite::DepthwiseParams& op_params, - const int8* input_data, const RuntimeShape& input_shape, - const int8* filter_data, const RuntimeShape& filter_shape, + const int8_t* input_data, const RuntimeShape& input_shape, + const int8_t* filter_data, const RuntimeShape& filter_shape, const std::int32_t* bias_data, const RuntimeShape& bias_shape, - const RuntimeShape& output_shape, int8* reference_output_data) { + const RuntimeShape& output_shape, int8_t* reference_output_data) { switch (test_param.output_rounding) { case DepthwiseConvOutputRounding::kUpward: reference_ops::depthwise_conv::DepthwiseConvBasicKernel< @@ -646,8 +646,8 @@ int TestOneDepthwiseConvWithGivenOutputShift( op_params.output_shift = -output_shift; const int depth = output_shape.Dims(3); - std::vector output_multiplier_per_channel(depth, output_multiplier); - std::vector output_shift_per_channel(depth, -output_shift); + std::vector output_multiplier_per_channel(depth, output_multiplier); + std::vector output_shift_per_channel(depth, -output_shift); if (output_multiplier_adjust != nullptr) { for (int i = 0; i < depth; ++i) { output_multiplier_per_channel[i] += output_multiplier_adjust[i]; @@ -898,8 +898,10 @@ bool TryTestDepthwiseConv(const TestParam& test_param, if (test_param.quantization_type == QuantizationType::kPerChannelInt8) { std::vector input_data(input_buffer_size); std::vector filter_data(filter_buffer_size); - FillRandom(&input_data, static_cast(-127), static_cast(127)); - FillRandom(&filter_data, static_cast(-127), static_cast(127)); + FillRandom(&input_data, static_cast(-127), + static_cast(127)); + FillRandom(&filter_data, static_cast(-127), + static_cast(127)); std::int32_t filter_offset = 0; EXPECT_TRUE(params_specialization == ParamsSpecialization::kSymmetric); diff --git a/tensorflow/lite/kernels/internal/log_quantized_test.cc b/tensorflow/lite/kernels/internal/log_quantized_test.cc index 2a27a097d2ab4c..7d0a549cbe180c 100644 --- a/tensorflow/lite/kernels/internal/log_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/log_quantized_test.cc @@ -53,23 +53,24 @@ class LogQuantizedTest : public ::testing::Test { }; // input_integer_bits <= 30. output_integer_bits > 0. -inline int32 LogPositiveValuesViaFloat(int32 input_val, int input_integer_bits, - int output_integer_bits) { +inline int32_t LogPositiveValuesViaFloat(int32_t input_val, + int input_integer_bits, + int output_integer_bits) { const double float_log_sum_of_exps = std::log( static_cast(input_val) * 0.5 / (1 << (30 - input_integer_bits))); static constexpr double min_int = - static_cast(std::numeric_limits::min()); + static_cast(std::numeric_limits::min()); static constexpr double max_int = - static_cast(std::numeric_limits::max()); + static_cast(std::numeric_limits::max()); double double_result = tflite::TfLiteRound(float_log_sum_of_exps * (1 << (31 - output_integer_bits))); return static_cast( std::min(max_int, std::max(min_int, double_result))); } -void CheckOutputData(const std::vector& test_output, - const std::vector& reference_output, - const std::vector& test_input, +void CheckOutputData(const std::vector& test_output, + const std::vector& reference_output, + const std::vector& test_input, const string& check_label, int input_integer_bits, int output_integer_bits, int tolerance) { // In the special case of small input, specifically raw value of 5, a rounding @@ -107,8 +108,8 @@ void CheckOutputData(const std::vector& test_output, } } -void RightShiftVector(const std::vector& shifts, - std::vector* vec) { +void RightShiftVector(const std::vector& shifts, + std::vector* vec) { const int n = vec->size(); ASSERT_EQ(n, shifts.size()); for (int i = 0; i < n; ++i) { @@ -117,15 +118,15 @@ void RightShiftVector(const std::vector& shifts, } template -void RunSingleTest(const std::vector& test_input, +void RunSingleTest(const std::vector& test_input, const string& check_label, int tolerance) { const int n = test_input.size(); - std::vector float_gen_output(n, 0); - std::vector quantized_output(n, 0); + std::vector float_gen_output(n, 0); + std::vector quantized_output(n, 0); // Workaround the stupid things that intelligent humans do. // Consequence of __builtin_clz(0u) may equal 31 instead of 32. - std::vector fudged_input(n, 0); + std::vector fudged_input(n, 0); for (int i = 0; i < n; ++i) { fudged_input[i] = std::max(test_input[i], 2); } @@ -134,7 +135,7 @@ void RunSingleTest(const std::vector& test_input, quantized_output[i] = tflite::log_x_for_x_greater_than_or_equal_to_1_impl( - gemmlowp::FixedPoint::FromRaw( + gemmlowp::FixedPoint::FromRaw( fudged_input[i])) .raw(); float_gen_output[i] = LogPositiveValuesViaFloat( @@ -151,8 +152,9 @@ void RunSingleTest(const std::vector& test_input, } template -void RunSingleTest(const std::vector& test_input, int input_integer_bits, - const string& check_label, int tolerance) { +void RunSingleTest(const std::vector& test_input, + int input_integer_bits, const string& check_label, + int tolerance) { #define INPUT_CASE(K) \ case K: \ return RunSingleTest(test_input, check_label, \ @@ -195,9 +197,9 @@ void RunSingleTest(const std::vector& test_input, int input_integer_bits, #undef INPUT_CASE } -void RunSingleTest(const std::vector& test_input, int input_integer_bits, - int output_integer_bits, const string& check_label, - int tolerance) { +void RunSingleTest(const std::vector& test_input, + int input_integer_bits, int output_integer_bits, + const string& check_label, int tolerance) { #define OUTPUT_CASE(K) \ case K: \ return RunSingleTest(test_input, input_integer_bits, check_label, \ @@ -248,9 +250,9 @@ void RunUniformTest(int test_size, int input_integer_bits, test_data[0] = 2; test_data[1] = 3; test_data[2] = 4; - test_data[3] = std::numeric_limits::max() - 2; - test_data[4] = std::numeric_limits::max() - 1; - test_data[5] = std::numeric_limits::max(); + test_data[3] = std::numeric_limits::max() - 2; + test_data[4] = std::numeric_limits::max() - 1; + test_data[5] = std::numeric_limits::max(); RunSingleTest(test_data, input_integer_bits, output_integer_bits, check_label + " / uniform test", tolerance); diff --git a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc index 72e4685d1e949a..3dfbd6930fe8c8 100644 --- a/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/logsoftmax_quantized_test.cc @@ -34,11 +34,11 @@ limitations under the License. namespace tflite { namespace { -void RunLogSoftmaxFloatReference(const uint8* input_data, +void RunLogSoftmaxFloatReference(const uint8_t* input_data, const RuntimeShape& shape_common, - int32 input_offset, const double input_scale, + int32_t input_offset, const double input_scale, int stride, float beta, - uint8* reference_output_data) { + uint8_t* reference_output_data) { const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); @@ -67,11 +67,11 @@ void RunLogSoftmaxFloatReference(const uint8* input_data, // - input and output data type // - Dequnatize function // - clamping values -void RunLogSoftmaxFloatReference(const int8* input_data, +void RunLogSoftmaxFloatReference(const int8_t* input_data, const RuntimeShape& shape_common, - int32 input_offset, const double input_scale, + int32_t input_offset, const double input_scale, int stride, float beta, - int8* reference_output_data) { + int8_t* reference_output_data) { const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); @@ -143,21 +143,22 @@ void CheckOutputData(const T* test_output, const T* reference_output, // Runs the LogSoftmax and compares against the float reference implementation // and the quantized reference implementation. -void RunOneLogSoftmaxTest(const uint8* input_data, - const RuntimeShape& shape_common, int32 input_offset, - const double input_scale, int stride, float beta) { +void RunOneLogSoftmaxTest(const uint8_t* input_data, + const RuntimeShape& shape_common, + int32_t input_offset, const double input_scale, + int stride, float beta) { const int buffer_size = shape_common.FlatSize(); - std::vector optimized_logsoftmax_output(buffer_size); - std::vector reference_float_logsoftmax_output(buffer_size); - std::vector reference_quant_logsoftmax_output(buffer_size); + std::vector optimized_logsoftmax_output(buffer_size); + std::vector reference_float_logsoftmax_output(buffer_size); + std::vector reference_quant_logsoftmax_output(buffer_size); RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_logsoftmax_output.data()); - int32 input_beta_multiplier; + int32_t input_beta_multiplier; int input_beta_left_shift; - int32 reverse_scaling_divisor; + int32_t reverse_scaling_divisor; int reverse_scaling_right_shift; static const int kScaledDiffIntegerBits = 5; tflite::PreprocessLogSoftmaxScalingExp( @@ -201,20 +202,22 @@ void RunOneLogSoftmaxTest(const uint8* input_data, // Runs the LogSoftmax and compares against the float reference implementation // and the int8 quantized reference implementation. -void RunOneLogSoftmaxTest(const int8* input_data, - const RuntimeShape& shape_common, int32 input_offset, - const double input_scale, int stride, float beta) { +void RunOneLogSoftmaxTest(const int8_t* input_data, + const RuntimeShape& shape_common, + int32_t input_offset, const double input_scale, + int stride, float beta) { const int buffer_size = shape_common.FlatSize(); - std::vector quantized_logsoftmax_reference_implementation(buffer_size); - std::vector float_logsoftmax_optimized_implementation(buffer_size); + std::vector quantized_logsoftmax_reference_implementation( + buffer_size); + std::vector float_logsoftmax_optimized_implementation(buffer_size); RunLogSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, float_logsoftmax_optimized_implementation.data()); - int32 input_beta_multiplier; + int32_t input_beta_multiplier; int input_beta_left_shift; - int32 reverse_scaling_divisor; + int32_t reverse_scaling_divisor; int reverse_scaling_right_shift; static const int kScaledDiffIntegerBits = 5; tflite::PreprocessLogSoftmaxScalingExp( @@ -258,7 +261,7 @@ bool TryOneUniformLogSoftmax() { const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); - const int32 input_offset = UniformRandomInt(-256, 0); + const int32_t input_offset = UniformRandomInt(-256, 0); static constexpr float beta = 1.0f; auto shape_common = @@ -291,7 +294,7 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); - const int32 input_offset = UniformRandomInt(-256, 0); + const int32_t input_offset = UniformRandomInt(-256, 0); static constexpr float beta = 1.0f; // Extra parameters for skyscraper input patterns. const double middle_proportion = @@ -303,7 +306,7 @@ bool TryOneSkyscraperLogSoftmax(bool small_depth) { RuntimeShape({batch, input_height, input_width, input_depth}); const int buffer_size = shape_common.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); RunOneLogSoftmaxTest(input_data.data(), shape_common, input_offset, diff --git a/tensorflow/lite/kernels/internal/maxpool_quantized_test.cc b/tensorflow/lite/kernels/internal/maxpool_quantized_test.cc index 84afd3ddd52211..50b39085387b1b 100644 --- a/tensorflow/lite/kernels/internal/maxpool_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/maxpool_quantized_test.cc @@ -33,11 +33,12 @@ namespace { // Runs the reference and optimized MaxPool functions and asserts the values // are the same. void RunOneMaxPoolTest(const PoolParams& params, - const RuntimeShape& input_shape, const int8* input_data, + const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& output_shape) { const int buffer_size = output_shape.FlatSize(); - std::vector optimized_maxpool_output(buffer_size); - std::vector reference_maxpool_output(buffer_size); + std::vector optimized_maxpool_output(buffer_size); + std::vector reference_maxpool_output(buffer_size); reference_integer_ops::MaxPool(params, input_shape, input_data, output_shape, reference_maxpool_output.data()); @@ -80,7 +81,7 @@ void CreateDataAndRunMaxPool(bool padding_same) { auto output_shape = RuntimeShape({batch, output_height, output_width, output_depth}); const int buffer_size = input_shape.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandom(&input_data); PoolParams params; diff --git a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h index 502ecf0ee6426e..726a279bfaef13 100644 --- a/tensorflow/lite/kernels/internal/optimized/batch_matmul.h +++ b/tensorflow/lite/kernels/internal/optimized/batch_matmul.h @@ -117,111 +117,6 @@ inline void BatchMatMul(const RuntimeShape& lhs_shape, const float* lhs_data, } } -inline void BatchMatMul(const RuntimeShape& lhs_shape, const int8_t* lhs_data, - const RuntimeShape& rhs_shape, const int8_t* rhs_data, - const float* scaling_factors, - const int32_t* input_offset, int32_t* row_sums, - const RuntimeShape& output_shape, - int32_t* accum_scratch, float* output_data, - bool* compute_row_sums, CpuBackendContext* context) { - using ::tflite::cpu_backend_gemm::Gemm; - using ::tflite::cpu_backend_gemm::GemmParams; - using ::tflite::cpu_backend_gemm::MatrixParams; - - const RuntimeShape extended_lhs_shape = - RuntimeShape::ExtendedShape(5, lhs_shape); - const RuntimeShape extended_rhs_shape = - RuntimeShape::ExtendedShape(5, rhs_shape); - - // Determine which dimension is the broadcast dimension. - auto broadcast_dim = [](int lhs_dim, int rhs_dim) { - if (lhs_dim == rhs_dim) return lhs_dim; - if (lhs_dim == 1) return rhs_dim; - TFLITE_DCHECK_EQ(rhs_dim, 1); - return lhs_dim; - }; - - // Compute the "extent" for iterating on this dimension. - // If we are broadcasting, then don't advance (i.e return 0). - auto extent = [](const RuntimeShape& shape, int x) { - if (shape.Dims(x) == 1) { - return 0; - } - int prod = 1; - for (int i = x + 1; i < shape.DimensionsCount(); ++i) { - prod *= shape.Dims(i); - } - return prod; - }; - - const int batch_dim0 = - broadcast_dim(extended_lhs_shape.Dims(0), extended_rhs_shape.Dims(0)); - const int batch_dim1 = - broadcast_dim(extended_lhs_shape.Dims(1), extended_rhs_shape.Dims(1)); - const int batch_dim2 = - broadcast_dim(extended_lhs_shape.Dims(2), extended_rhs_shape.Dims(2)); - - const int lhs_ext0 = extent(extended_lhs_shape, 0); - const int lhs_ext1 = extent(extended_lhs_shape, 1); - const int lhs_ext2 = extent(extended_lhs_shape, 2); - const int rhs_ext0 = extent(extended_rhs_shape, 0); - const int rhs_ext1 = extent(extended_rhs_shape, 1); - const int rhs_ext2 = extent(extended_rhs_shape, 2); - - // Set params for each matrix multiply. - const int lhs_rows = extended_lhs_shape.Dims(3); - const int rhs_cols = extended_rhs_shape.Dims(4); - const int accum_depth = extended_lhs_shape.Dims(4); - - const int ioff_ext0 = rhs_ext0 == 0 ? 0 : rhs_cols; - const int ioff_ext1 = rhs_ext1 == 0 ? 0 : rhs_cols; - const int ioff_ext2 = rhs_ext2 == 0 ? 0 : rhs_cols; - const int woff_ext0 = lhs_ext0 == 0 ? 0 : lhs_rows; - const int woff_ext1 = lhs_ext1 == 0 ? 0 : lhs_rows; - const int woff_ext2 = lhs_ext2 == 0 ? 0 : lhs_rows; - - if (!compute_row_sums || *compute_row_sums) { - int num_weights_matrices = 1; - for (int i = 1; i < extended_lhs_shape.DimensionsCount() - 2; ++i) { - num_weights_matrices *= extended_lhs_shape.Dims(i); - } - tensor_utils::ReductionSumVector( - lhs_data, row_sums, num_weights_matrices * lhs_rows, accum_depth); - if (compute_row_sums) { - *compute_row_sums = false; - } - } - - for (int b0 = 0; b0 < batch_dim0; ++b0) { - const int8_t* lhs_ptr0 = lhs_data + (b0 * lhs_ext0); - const int8_t* rhs_ptr0 = rhs_data + (b0 * rhs_ext0); - const int32_t* ioff_ptr0 = input_offset + (b0 * ioff_ext0); - const float* scale_ptr0 = scaling_factors + (b0 * ioff_ext0); - int32_t* woff_ptr0 = row_sums + (b0 * woff_ext0); - for (int b1 = 0; b1 < batch_dim1; ++b1) { - const int8_t* lhs_ptr1 = lhs_ptr0 + b1 * lhs_ext1; - const int8_t* rhs_ptr1 = rhs_ptr0 + b1 * rhs_ext1; - const int32_t* ioff_ptr1 = ioff_ptr0 + (b1 * ioff_ext1); - const float* scale_ptr1 = scale_ptr0 + (b1 * ioff_ext1); - int32_t* woff_ptr1 = woff_ptr0 + (b1 * woff_ext1); - for (int b2 = 0; b2 < batch_dim2; ++b2) { - const int8_t* lhs_ptr2 = lhs_ptr1 + b2 * lhs_ext2; - const int8_t* rhs_ptr2 = rhs_ptr1 + b2 * rhs_ext2; - const int32_t* ioff_ptr2 = ioff_ptr1 + (b2 * ioff_ext2); - const float* scale_ptr2 = scale_ptr1 + (b2 * ioff_ext2); - int32_t* woff_ptr2 = woff_ptr1 + (b2 * woff_ext2); - float* out_ptr = output_data + ((b0 * batch_dim1 * batch_dim2) + - b1 * batch_dim2 + b2) * - lhs_rows * rhs_cols; - tensor_utils::MatrixBatchVectorMultiplyAccumulate( - lhs_ptr2, lhs_rows, accum_depth, rhs_ptr2, scale_ptr2, rhs_cols, - out_ptr, /*per_channel_scale=*/nullptr, ioff_ptr2, accum_scratch, - woff_ptr2, compute_row_sums, context); - } - } - } -} - inline void BatchMatMul(const FullyConnectedParams& params, const RuntimeShape& lhs_shape, const int8_t* lhs_data, const RuntimeShape& rhs_shape, const int8_t* rhs_data, diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h index ccd8d3e758ee9a..841b38050bc854 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_3x3_filter_common.h @@ -199,7 +199,7 @@ template <> struct DivideByPOT { template static inline IntegerType Run(IntegerType x, int exponent) { - return vqrshlq_s32(x, vdupq_n_s32(static_cast(-exponent))); + return vqrshlq_s32(x, vdupq_n_s32(static_cast(-exponent))); } template static inline IntegerType RunMult(IntegerType x, IntegerType exponent) { @@ -207,7 +207,7 @@ struct DivideByPOT { } template static inline IntegerType RunMult(IntegerType x, int exponent) { - return vqrshlq_s32(x, vdupq_n_s32(static_cast(exponent))); + return vqrshlq_s32(x, vdupq_n_s32(static_cast(exponent))); } }; #endif // ARM NEON @@ -231,18 +231,18 @@ struct QuantizationTypeImpl {}; template <> struct QuantizationTypeImpl { - typedef uint8 ExternalType; + typedef uint8_t ExternalType; static constexpr int kIntSymmetricZeroPoint = 128; - static constexpr uint8 kUint8SignBit = 0x80; + static constexpr uint8_t kUint8SignBit = 0x80; }; template <> struct QuantizationTypeImpl { - typedef int8 ExternalType; + typedef int8_t ExternalType; static constexpr int kIntSymmetricZeroPoint = 0; - static constexpr uint8 kUint8SignBit = 0x0; + static constexpr uint8_t kUint8SignBit = 0x0; }; template < @@ -250,16 +250,16 @@ template < inline DotProduct3x3KernelType CategorizeDotProductKernel( const RuntimeShape& input_shape, const RuntimeShape& filter_shape, const RuntimeShape& output_shape, const DepthwiseParams& params, - const int32* output_shift_ptr = nullptr) { + const int32_t* output_shift_ptr = nullptr) { constexpr int kSymmetricZeroPoint = QuantizationTypeImpl::kIntSymmetricZeroPoint; const int padding = std::max(params.padding_values.width, params.padding_values.height); const int stride = params.stride_width; - const int32 input_depth = input_shape.Dims(3); - const int32 depth_multiplier = params.depth_multiplier; - const int32 filter_height = filter_shape.Dims(1); - const int32 filter_width = filter_shape.Dims(2); + const int32_t input_depth = input_shape.Dims(3); + const int32_t depth_multiplier = params.depth_multiplier; + const int32_t filter_height = filter_shape.Dims(1); + const int32_t filter_width = filter_shape.Dims(2); bool supported = stride == params.stride_height && stride <= 2 && padding <= 1 && filter_width == 3 && filter_height == 3 && @@ -311,19 +311,19 @@ struct DepthwiseConvParams { int64_t output_depth; int64_t output_row_size; int64_t filter_row_size; - int32 input_offset; - int32 output_offset; - int32 filter_offset; - int32 output_multiplier; - int32 output_activation_min; - int32 output_activation_max; - int32 output_right_shift; - int32 input_width; - int32 input_height; - int32 stride_width; - int32 stride_height; - int32 output_width; - int32 output_height; + int32_t input_offset; + int32_t output_offset; + int32_t filter_offset; + int32_t output_multiplier; + int32_t output_activation_min; + int32_t output_activation_max; + int32_t output_right_shift; + int32_t input_width; + int32_t input_height; + int32_t stride_width; + int32_t stride_height; + int32_t output_width; + int32_t output_height; float float_output_activation_min; float float_output_activation_max; }; @@ -335,51 +335,51 @@ struct DepthwiseConvParams { struct DepthwiseConvDotProdParams { int64_t input_depth; int64_t output_depth; - int32 stride; - int32 bias_increment; + int32_t stride; + int32_t bias_increment; // - int32 input_offset; - int32 output_offset; - int32 output_multiplier; - int32 output_shift; - int32 quantized_activation_min; - int32 quantized_activation_max; + int32_t input_offset; + int32_t output_offset; + int32_t output_multiplier; + int32_t output_shift; + int32_t quantized_activation_min; + int32_t quantized_activation_max; // - int32 padding_left; - int32 padding_right; - int32 padding_top; - int32 padding_bottom; + int32_t padding_left; + int32_t padding_right; + int32_t padding_top; + int32_t padding_bottom; // - int32 depth_micro_repeats; + int32_t depth_micro_repeats; // - int32 width_macro_count; - int32 input_width_overall_micro_repeats; - int32 input_width_micro_repeats; - int32 residual_width; - int32 output_width_overall_micro_repeats; - int32 output_width_micro_repeats; - int32 output_residual_width; - int32 workspace_width_micro_repeats; + int32_t width_macro_count; + int32_t input_width_overall_micro_repeats; + int32_t input_width_micro_repeats; + int32_t residual_width; + int32_t output_width_overall_micro_repeats; + int32_t output_width_micro_repeats; + int32_t output_residual_width; + int32_t workspace_width_micro_repeats; // - int32 height_macro_count; - int32 inbound_block_height; - int32 outbound_block_height; - int32 input_height_stride; - int32 output_height_stride; - int32 workspace_height_stride; + int32_t height_macro_count; + int32_t inbound_block_height; + int32_t outbound_block_height; + int32_t input_height_stride; + int32_t output_height_stride; + int32_t workspace_height_stride; // - int32 four_over_stride; + int32_t four_over_stride; // - const int32* output_multiplier_per_channel; - const int32* output_shift_per_channel; + const int32_t* output_multiplier_per_channel; + const int32_t* output_shift_per_channel; }; -template +template struct DepthwiseConvWindow {}; -template +template struct DepthwiseConvWindowPerChannel {}; enum class EdgeType { kCorner, kHorizontal, kVertical, kCenter }; @@ -397,13 +397,13 @@ struct DepthwiseConvPartialPerChannel {}; // this is the cache line size. template inline void ShuffleInput(const T* input_ptr, int64_t input_depth, - int32 input_width, int32 input_height, - int64_t output_depth, int32 output_width, - int32 output_height, T* output_ptr) { + int32_t input_width, int32_t input_height, + int64_t output_depth, int32_t output_width, + int32_t output_height, T* output_ptr) { const int64_t input_row_size = input_depth * input_width; - for (int32 y = 0; y < output_height; y++) { + for (int32_t y = 0; y < output_height; y++) { const T* ptr = input_ptr; - for (int32 x = 0; x < output_width; x++) { + for (int32_t x = 0; x < output_width; x++) { memcpy(output_ptr, ptr, output_depth); output_ptr += output_depth; ptr += input_depth; @@ -413,21 +413,21 @@ inline void ShuffleInput(const T* input_ptr, int64_t input_depth, } // Calculates the input size depending on stride and output. -inline int32 get_shuffle_input_size(int32 stride, int32 output) { +inline int32_t get_shuffle_input_size(int32_t stride, int32_t output) { return stride * (output - 1) + 3; } // Indicates the input and output dimensions used when shuffling input // activations. struct ShuffleParams { - int32 output_width; - int32 output_height; - int32 input_width; - int32 input_height; + int32_t output_width; + int32_t output_height; + int32_t input_width; + int32_t input_height; ShuffleParams() = default; - ShuffleParams(int32 output_width, int32 output_height, int32 stride_width, - int32 stride_height) + ShuffleParams(int32_t output_width, int32_t output_height, + int32_t stride_width, int32_t stride_height) : output_width(output_width), output_height(output_height), input_width(get_shuffle_input_size(stride_width, output_width)), @@ -438,17 +438,17 @@ template < QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8> inline bool Fast3x3FilterKernelSupported( const RuntimeShape& input_shape, const RuntimeShape& filter_shape, - int32 stride_width, int32 stride_height, int32 dilation_width_factor, - int32 dilation_height_factor, int32 pad_width, int32 pad_height, - int32 depth_multiplier, const RuntimeShape& output_shape, - int32 output_shift, const int32* output_shift_ptr = nullptr) { - const int32 input_height = input_shape.Dims(1); - const int32 input_width = input_shape.Dims(2); - const int32 input_depth = input_shape.Dims(3); - const int32 filter_height = filter_shape.Dims(1); - const int32 filter_width = filter_shape.Dims(2); - const int32 output_height = output_shape.Dims(1); - const int32 output_width = output_shape.Dims(2); + int32_t stride_width, int32_t stride_height, int32_t dilation_width_factor, + int32_t dilation_height_factor, int32_t pad_width, int32_t pad_height, + int32_t depth_multiplier, const RuntimeShape& output_shape, + int32_t output_shift, const int32_t* output_shift_ptr = nullptr) { + const int32_t input_height = input_shape.Dims(1); + const int32_t input_width = input_shape.Dims(2); + const int32_t input_depth = input_shape.Dims(3); + const int32_t filter_height = filter_shape.Dims(1); + const int32_t filter_width = filter_shape.Dims(2); + const int32_t output_height = output_shape.Dims(1); + const int32_t output_width = output_shape.Dims(2); bool supported = filter_width == 3 && filter_height == 3 && depth_multiplier == 1 && @@ -466,14 +466,14 @@ inline bool Fast3x3FilterKernelSupported( // Handle case where padding is zero but padding type is not kValid. // This would require special boundary case handling that is not supported. - const int32 out_x = output_width - 1; - const int32 out_y = output_height - 1; + const int32_t out_x = output_width - 1; + const int32_t out_y = output_height - 1; - const int32 in_x_origin = (out_x * stride_width) - pad_width; - const int32 in_y_origin = (out_y * stride_height) - pad_height; + const int32_t in_x_origin = (out_x * stride_width) - pad_width; + const int32_t in_y_origin = (out_y * stride_height) - pad_height; - const int32 in_x_end = in_x_origin + filter_width; - const int32 in_y_end = in_y_origin + filter_height; + const int32_t in_x_end = in_x_origin + filter_width; + const int32_t in_y_end = in_y_origin + filter_height; // Supported only if filter on the right and bottom boundary lies completely // within the input if padding is zero. @@ -525,7 +525,7 @@ struct ProcessPerDepth { template + int32_t max_padding> struct PackMacroBlock { // Routine is contained in a static Run() method. No default template version // is supplied, so that all implementations are deliberate choices of template @@ -543,7 +543,7 @@ struct PackMacroBlock { // See the comments preceding DepthwiseConvDotProduct3x3() for further notes. template + DepthwiseConvDepthMultiplication depth_multiplication, int32_t stride> struct KernelMacroBlock { // Routine is contained in a static Run() method. No default template version // is supplied, so that all implementations are deliberate choices of template diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h index 4c631fc8b45ae1..68f3da1a2936c1 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -41,9 +41,9 @@ struct QuantizedDepthwiseConvKernel {}; template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8x2_t filter_u8; filter_u8.val[0] = vld1_u8(filter_ptr); @@ -88,9 +88,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. const uint8x8_t filter_u8 = vld1_u8(filter_ptr); const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); @@ -156,9 +156,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. const uint8x8_t filter_u8 = vld1_u8(filter_ptr); const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); @@ -226,9 +226,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. int16x8_t filter[2]; for (int i = 0; i < 2; i++) { @@ -303,9 +303,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8 = vdup_n_u8(0); filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); @@ -369,9 +369,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8 = vdup_n_u8(0); filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); @@ -483,9 +483,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8 = vdup_n_u8(0); filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); @@ -529,7 +529,7 @@ struct QuantizedDepthwiseConvKernel { int32x2_t acc = vld1_s32(acc_buffer_ptr); // Load the inputs, add input_offset. - const uint32 input = *input_ptr++ + input_offset; + const uint32_t input = *input_ptr++ + input_offset; // Multiply-accumulate acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input)); @@ -543,9 +543,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8 = vdup_n_u8(0); filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); @@ -624,7 +624,7 @@ struct QuantizedDepthwiseConvKernel { int32x4_t acc = vld1q_s32(acc_buffer_ptr); // Load the inputs, add input_offset. - const uint32 input = *input_ptr++ + input_offset; + const uint32_t input = *input_ptr++ + input_offset; // Multiply-accumulate acc = vmlal_n_s16(acc, filter, input); @@ -638,9 +638,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8 = vdup_n_u8(0); filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); @@ -708,9 +708,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. int16x8_t filter[2]; for (int i = 0; i < 2; i++) { @@ -793,15 +793,15 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // We will have to duplicate bytes in a NEON register, 3-fold. // We will do that by register-level table-look-up using VTBL instructions. // Here we prepare the registers containing the table-lookup indices. - static const uint8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2}, - {2, 3, 3, 3, 4, 4, 4, 5}, - {5, 5, 6, 6, 6, 7, 7, 7}}; + static const uint8_t dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2}, + {2, 3, 3, 3, 4, 4, 4, 5}, + {5, 5, 6, 6, 6, 7, 7, 7}}; uint8x8_t dup3_indices[3]; for (int i = 0; i < 3; i++) { dup3_indices[i] = vld1_u8(dup3_indices_array[i]); @@ -809,8 +809,8 @@ struct QuantizedDepthwiseConvKernel { // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - const uint8* local_filter_ptr = filter_ptr; - const uint8* local_input_ptr = input_ptr; + const uint8_t* local_filter_ptr = filter_ptr; + const uint8_t* local_input_ptr = input_ptr; int ic = 0; // Handle 8 input channels at a time. for (; ic <= input_depth - 8; ic += 8) { @@ -864,10 +864,10 @@ struct QuantizedDepthwiseConvKernel { } // Handle one input channel at a time. for (; ic < input_depth; ic++) { - const int16 input_val = *local_input_ptr++ + input_offset; + const int16_t input_val = *local_input_ptr++ + input_offset; for (int i = 0; i < 3; i++) { - const int16 filter_val = local_filter_ptr[i] + filter_offset; - *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + const int16_t filter_val = local_filter_ptr[i] + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; } local_filter_ptr += 3; } @@ -879,13 +879,13 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - const uint8* local_filter_ptr = filter_ptr; - const uint8* local_input_ptr = input_ptr; + const uint8_t* local_filter_ptr = filter_ptr; + const uint8_t* local_input_ptr = input_ptr; int ic = 0; // Handle 8 input channels at a time. for (; ic <= input_depth - 8; ic += 8) { @@ -929,10 +929,10 @@ struct QuantizedDepthwiseConvKernel { // Handle one input channel at a time. for (; ic < input_depth; ic++) { // Load the inputs. - const int16 input_val = *local_input_ptr++ + input_offset; + const int16_t input_val = *local_input_ptr++ + input_offset; for (int i = 0; i < 2; i++) { - const int16 filter_val = local_filter_ptr[i] + filter_offset; - *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + const int16_t filter_val = local_filter_ptr[i] + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; } local_filter_ptr += 2; } @@ -944,13 +944,13 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - const uint8* local_filter_ptr = filter_ptr; - const uint8* local_input_ptr = input_ptr; + const uint8_t* local_filter_ptr = filter_ptr; + const uint8_t* local_input_ptr = input_ptr; int ic = 0; // Handle 16 input channels at a time. for (; ic <= input_depth - 16; ic += 16) { @@ -1055,9 +1055,9 @@ struct QuantizedDepthwiseConvKernel { } // Handle one input channel at a time. for (; ic < input_depth; ic++) { - const int16 input_val = *local_input_ptr++ + input_offset; - const int16 filter_val = *local_filter_ptr++ + filter_offset; - *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + const int16_t input_val = *local_input_ptr++ + input_offset; + const int16_t filter_val = *local_filter_ptr++ + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; } input_ptr += input_ptr_increment; } @@ -1067,9 +1067,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8[2]; for (int i = 0; i < 2; i++) { @@ -1121,9 +1121,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. const uint8x8_t filter_u8 = vld1_u8(filter_ptr); const int16x8_t filter_s16 = vreinterpretq_s16_u16(vmovl_u8(filter_u8)); @@ -1155,9 +1155,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8[2]; for (int i = 0; i < 2; i++) { @@ -1172,9 +1172,9 @@ struct QuantizedDepthwiseConvKernel { } // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - uint8 input_u8 = *input_ptr; + uint8_t input_u8 = *input_ptr; input_ptr += input_ptr_increment; - int16 input = static_cast(input_u8 + input_offset); + int16_t input = static_cast(input_u8 + input_offset); // Load the accumulators from acc_buffer int32x4_t acc[4]; for (int i = 0; i < 4; i++) { @@ -1199,9 +1199,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0); uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1); @@ -1217,9 +1217,9 @@ struct QuantizedDepthwiseConvKernel { filter_3 = vaddq_s16(filter_3, vdupq_n_s16(filter_offset)); // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - uint8 input_u8 = *input_ptr; + uint8_t input_u8 = *input_ptr; input_ptr += input_ptr_increment; - int16 input = static_cast(input_u8 + input_offset); + int16_t input = static_cast(input_u8 + input_offset); // Load the accumulators from acc_buffer int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); @@ -1255,9 +1255,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. // NEON wants to load 8 bytes at a time, but 20 is not divisible by 8. // We load the first 16 bytes into filter_u8_{0,1} as usual. @@ -1275,9 +1275,9 @@ struct QuantizedDepthwiseConvKernel { filter_x = vaddq_s16(filter_x, vdupq_n_s16(filter_offset)); // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - uint8 input_u8 = *input_ptr; + uint8_t input_u8 = *input_ptr; input_ptr += input_ptr_increment; - int16 input = static_cast(input_u8 + input_offset); + int16_t input = static_cast(input_u8 + input_offset); // Load the accumulators from acc_buffer int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); @@ -1304,18 +1304,18 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. const uint8x8_t filter_u8 = vld1_u8(filter_ptr); const int16x8_t filter = vaddq_s16( vreinterpretq_s16_u16(vmovl_u8(filter_u8)), vdupq_n_s16(filter_offset)); // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - uint8 input_u8 = *input_ptr; + uint8_t input_u8 = *input_ptr; input_ptr += input_ptr_increment; - int16 input = static_cast(input_u8 + input_offset); + int16_t input = static_cast(input_u8 + input_offset); // Load the accumulators from acc_buffer int32x4_t acc[2]; for (int i = 0; i < 2; i++) { @@ -1336,9 +1336,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8 = vdup_n_u8(0); filter_u8 = vset_lane_u8(filter_ptr[0], filter_u8, 0); @@ -1357,11 +1357,11 @@ struct QuantizedDepthwiseConvKernel { int32x4_t acc = vld1q_s32(acc_buffer_ptr); // Load the inputs, add input_offset. uint16x4_t input_u16 = vdup_n_u16(0); - input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], - input_u16, 0); + input_u16 = vset_lane_u16( + (reinterpret_cast(input_ptr))[0], input_u16, 0); input_ptr += input_ptr_increment; - input_u16 = vset_lane_u16((reinterpret_cast(input_ptr))[0], - input_u16, 1); + input_u16 = vset_lane_u16( + (reinterpret_cast(input_ptr))[0], input_u16, 1); input_ptr += input_ptr_increment; const int16x4_t input_s16 = vreinterpret_s16_u16( vget_low_u16(vmovl_u8(vreinterpret_u8_u16(input_u16)))); @@ -1399,9 +1399,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { if (num_output_pixels <= 0) { return; } @@ -1463,9 +1463,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const uint8* input_ptr, int16 input_offset, - int input_ptr_increment, const uint8* filter_ptr, - int16 filter_offset, int32* acc_buffer_ptr) { + const uint8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const uint8_t* filter_ptr, + int16_t filter_offset, int32_t* acc_buffer_ptr) { // Load the filters, add filter_offset. uint8x8_t filter_u8_0 = vld1_u8(filter_ptr); uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 4); @@ -1512,14 +1512,12 @@ struct QuantizedDepthwiseConvKernel { // Accumulates the effect of one row of the filter, on a segment of one row // of the output, accessing the corresponding one row of the input. template -void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, - int input_depth, int input_width, - const uint8* input_data, int16 input_offset, - int pad_width, int depth_multiplier, - int filter_width, const uint8* filter_data, - int16 filter_offset, int out_x_buffer_start, - int out_x_buffer_end, int output_depth, - int32* acc_buffer) { +void QuantizedDepthwiseConvAccumRow( + int stride, int dilation_factor, int input_depth, int input_width, + const uint8_t* input_data, int16_t input_offset, int pad_width, + int depth_multiplier, int filter_width, const uint8_t* filter_data, + int16_t filter_offset, int out_x_buffer_start, int out_x_buffer_end, + int output_depth, int32_t* acc_buffer) { ruy::profiler::ScopeLabel label(TFLITE_PRETTY_FUNCTION); // Consistency check parameters. This is important in particular to ensure // that we keep the number of template instantiations minimal, so we don't @@ -1535,7 +1533,7 @@ void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, } TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); const int input_ptr_increment = stride * input_depth; - const uint8* filter_base_ptr = filter_data; + const uint8_t* filter_base_ptr = filter_data; for (int filter_x = 0; filter_x < filter_width; ++filter_x) { // For the current (filter_x, filter_y) point in the filter, // compute the boundaries of the corresponding output row segment. @@ -1571,11 +1569,11 @@ void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, const int out_x_loop_end = std::min(out_x_buffer_end, out_x_loop_end_unclamped); - int32* acc_buffer_ptr = + int32_t* acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; const int in_x_origin = (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x; - const uint8* input_ptr = input_data + in_x_origin * input_depth; + const uint8_t* input_ptr = input_data + in_x_origin * input_depth; const int num_output_pixels = out_x_loop_end - out_x_loop_start; QuantizedDepthwiseConvKernel< kAllowStrided, kFixedInputDepth, @@ -1590,12 +1588,12 @@ void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, // generic fallback of DepthwiseConvAccumRow, portable, non-templatized. inline void QuantizedDepthwiseConvAccumRowGeneric( int stride, int dilation_factor, int input_depth, int input_width, - const uint8* input_data, int16 input_offset, int pad_width, - int depth_multiplier, int filter_width, const uint8* filter_data, - int16 filter_offset, int out_x_buffer_start, int out_x_buffer_end, - int output_depth, int32* acc_buffer) { + const uint8_t* input_data, int16_t input_offset, int pad_width, + int depth_multiplier, int filter_width, const uint8_t* filter_data, + int16_t filter_offset, int out_x_buffer_start, int out_x_buffer_end, + int output_depth, int32_t* acc_buffer) { ruy::profiler::ScopeLabel label("DepthwiseConvAccumRowGeneric (slow)"); - const uint8* filter_base_ptr = filter_data; + const uint8_t* filter_base_ptr = filter_data; for (int filter_x = 0; filter_x < filter_width; ++filter_x) { const int out_x_loop_start = std::max( out_x_buffer_start, @@ -1605,19 +1603,19 @@ inline void QuantizedDepthwiseConvAccumRowGeneric( (pad_width + input_width - dilation_factor * filter_x + stride - 1) / stride); - int32* acc_buffer_ptr = + int32_t* acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; const int in_x_origin = (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x; - const uint8* input_ptr = input_data + in_x_origin * input_depth; + const uint8_t* input_ptr = input_data + in_x_origin * input_depth; const int input_ptr_increment = (stride - 1) * input_depth; for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { - const uint8* filter_ptr = filter_base_ptr; + const uint8_t* filter_ptr = filter_base_ptr; for (int ic = 0; ic < input_depth; ++ic) { - const int16 input_val = *input_ptr++ + input_offset; + const int16_t input_val = *input_ptr++ + input_offset; for (int m = 0; m < depth_multiplier; m++) { - const int16 filter_val = *filter_ptr++ + filter_offset; - *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + const int16_t filter_val = *filter_ptr++ + filter_offset; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; } } input_ptr += input_ptr_increment; @@ -1628,8 +1626,8 @@ inline void QuantizedDepthwiseConvAccumRowGeneric( // Initializes the accumulator buffer with bias values. inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, - const int32* bias_data, - int32* acc_buffer) { + const int32_t* bias_data, + int32_t* acc_buffer) { int i = 0; #ifdef USE_NEON if (output_depth == 1) { @@ -1701,21 +1699,21 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, inline void DepthwiseConvGeneral( const DepthwiseParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, int thread_start, int thread_end, int thread_dim) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, int thread_start, int thread_end, int thread_dim) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int pad_width = params.padding_values.width; const int pad_height = params.padding_values.height; const int depth_multiplier = params.depth_multiplier; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; const int dilation_width_factor = params.dilation_width_factor; const int dilation_height_factor = params.dilation_height_factor; @@ -1730,7 +1728,7 @@ inline void DepthwiseConvGeneral( const int output_width = output_shape.Dims(2); #ifdef USE_NEON const bool shift_left = (output_shift > 0); - const int32 multiplier_power_of_two = shift_left ? (1 << output_shift) : 1; + const int32_t multiplier_power_of_two = shift_left ? (1 << output_shift) : 1; #endif // The default Accbuffer size is 2048, will allocate a bigger memory if it's @@ -1739,11 +1737,11 @@ inline void DepthwiseConvGeneral( // a scratch tensor. static const int kStackAccBufferSize = 2048; int acc_buffer_size = kStackAccBufferSize; - int32 stack_acc_buffer[kStackAccBufferSize]; - int32* acc_buffer = stack_acc_buffer; - std::unique_ptr heap_acc_buffer; + int32_t stack_acc_buffer[kStackAccBufferSize]; + int32_t* acc_buffer = stack_acc_buffer; + std::unique_ptr heap_acc_buffer; if (kStackAccBufferSize < output_depth) { - heap_acc_buffer.reset(new int32[output_depth]); + heap_acc_buffer.reset(new int32_t[output_depth]); acc_buffer = heap_acc_buffer.get(); acc_buffer_size = output_depth; } @@ -1846,7 +1844,7 @@ inline void DepthwiseConvGeneral( break; } - uint8* output_ptr = output_data + output_ptr_offset; + uint8_t* output_ptr = output_data + output_ptr_offset; int batch_step = (output_height + row_start - row_end) * output_width * output_depth; for (int b = batch_start; b < batch_end; ++b) { @@ -2016,13 +2014,13 @@ inline void DepthwiseConvGeneral( // Handle leftover values, one by one. This is very slow. for (; i < num_output_values; i++) { - int32 acc = acc_buffer[i]; + int32_t acc = acc_buffer[i]; acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift); acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - *output_ptr++ = static_cast(acc); + *output_ptr++ = static_cast(acc); } } } @@ -2035,15 +2033,15 @@ inline void DepthwiseConvGeneral( template inline void DepthwiseConvWithRounding( const DepthwiseParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, const CpuFlags& cpu_flags, int thread_start, + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, const CpuFlags& cpu_flags, int thread_start, int thread_end, int thread_dim) { ruy::profiler::ScopeLabel label("DepthwiseConv/8bit"); const int depth_multiplier = params.depth_multiplier; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; const int dilation_width_factor = params.dilation_width_factor; const int dilation_height_factor = params.dilation_height_factor; TFLITE_DCHECK_GE(dilation_width_factor, 1); @@ -2112,10 +2110,10 @@ inline void DepthwiseConvWithRounding( inline void DepthwiseConvImpl( const DepthwiseParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, const CpuFlags& cpu_flags, int thread_start, + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, const CpuFlags& cpu_flags, int thread_start, int thread_end, int thread_dim) { return DepthwiseConvWithRounding( params, input_shape, input_data, filter_shape, filter_data, bias_shape, diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h index 7dc040becb2795..633feb53528d0a 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h @@ -31,16 +31,16 @@ namespace optimized_ops { namespace depthwise_conv { #ifdef USE_NEON -inline int8x16_t util_vld1q_x8(const uint8* data_addr) { +inline int8x16_t util_vld1q_x8(const uint8_t* data_addr) { return vreinterpretq_s8_u8(vld1q_u8(data_addr)); } -inline int8x16_t util_vld1q_x8(const int8* data_addr) { +inline int8x16_t util_vld1q_x8(const int8_t* data_addr) { return vld1q_s8(data_addr); } -inline int8x8_t util_vld1_x8(const uint8* data_addr) { +inline int8x8_t util_vld1_x8(const uint8_t* data_addr) { return vreinterpret_s8_u8(vld1_u8(data_addr)); } -inline int8x8_t util_vld1_x8(const int8* data_addr) { +inline int8x8_t util_vld1_x8(const int8_t* data_addr) { return vld1_s8(data_addr); } #endif @@ -5785,7 +5785,7 @@ inline void DepthwiseConv3x3Filter( // Perform any necessary cache hinting and pre-writing. template struct WorkspacePrefetchWrite { - static inline void Run(int8 fill_data, int size, int8* workspace) {} + static inline void Run(int8_t fill_data, int size, int8_t* workspace) {} }; #if defined(__aarch64__) @@ -12867,7 +12867,7 @@ inline void DepthwiseConvDotProduct3x3Impl( const RuntimeShape& filter_shape, const typename QuantizationTypeImpl::ExternalType* filter_data, - const RuntimeShape& bias_shape, const int32* bias_data, + const RuntimeShape& bias_shape, const int32_t* bias_data, const RuntimeShape& output_shape, typename QuantizationTypeImpl::ExternalType* output_data, int thread_start, int thread_end, int thread_dim) { @@ -12937,8 +12937,8 @@ inline void DepthwiseConvDotProduct3x3Impl( // Kernel subroutines need to be able to operate consistently on an bias // array. Where there is no bias, we provide one filled with zeros. constexpr int kMinBiasLoad = 8; - int32 zero_bias_data[kMinBiasLoad]; - int32 bias_increment; + int32_t zero_bias_data[kMinBiasLoad]; + int32_t bias_increment; if (bias_data) { bias_increment = 4; } else { @@ -13103,9 +13103,9 @@ inline void DepthwiseConvDotProduct3x3Impl( // Filter workspace is for shuffle: only first depth/8 is used. // indexed as [depth/8][sub-block][height][depth][width]. TFLITE_DCHECK_EQ(kDepthwiseConvAdjustedBiasLimit % 8, 0); - int8 macroblock_workspace[kDepthwiseConvScratchWorkspaceSize]; - int32 adjusted_bias_data[kDepthwiseConvAdjustedBiasLimit]; - int8 filter_workspace[kDepthwiseConvAdjustedBiasLimit >> 3][3][2][4][4]; + int8_t macroblock_workspace[kDepthwiseConvScratchWorkspaceSize]; + int32_t adjusted_bias_data[kDepthwiseConvAdjustedBiasLimit]; + int8_t filter_workspace[kDepthwiseConvAdjustedBiasLimit >> 3][3][2][4][4]; // Output depth characterization. // @@ -13400,10 +13400,10 @@ inline void DepthwiseConvDotProduct3x3Impl( template inline void DepthwiseConvDotProduct3x3( const DepthwiseParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, int thread_start, int thread_end, int thread_dim) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, int thread_start, int thread_end, int thread_dim) { DepthwiseConvDotProduct3x3Impl< implementation, depthwise_conv::QuantizationType::kNonPerChannelUint8>( params, input_shape, input_data, filter_shape, filter_data, bias_shape, @@ -13414,10 +13414,10 @@ inline void DepthwiseConvDotProduct3x3( template inline void DepthwiseConvDotProduct3x3PerChannel( const DepthwiseParams& params, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, - int thread_start, int thread_end, int thread_dim) { + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int8_t* output_data, int thread_start, int thread_end, int thread_dim) { DepthwiseConvDotProduct3x3Impl< implementation, depthwise_conv::QuantizationType::kPerChannelInt8>( params, input_shape, input_data, filter_shape, filter_data, bias_shape, diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h index 5ad334a6a06814..63dca21cd87dea 100644 --- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h +++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_transitional.h @@ -37,13 +37,13 @@ namespace depthwise_conv { #ifdef USE_NEON -inline void util_vst1_u8(uint8* data_addr, uint8x8_t reg) { +inline void util_vst1_u8(uint8_t* data_addr, uint8x8_t reg) { return vst1_u8(data_addr, reg); } -inline void util_vst1_x8(uint8* data_addr, int8x8_t reg) { +inline void util_vst1_x8(uint8_t* data_addr, int8x8_t reg) { return vst1_u8(data_addr, vreinterpret_u8_s8(reg)); } -inline void util_vst1_x8(int8* data_addr, int8x8_t reg) { +inline void util_vst1_x8(int8_t* data_addr, int8x8_t reg) { return vst1_s8(data_addr, reg); } @@ -94,15 +94,15 @@ struct ProcessPerDepth::kIntSymmetricZeroPoint; // Load filter data in, 8-bytes down depth / sub-block at a time. // // loaded_filter has dimensions height 3, width 4, sub-block 0 or 1, // depth 4. - uint8 loaded_filter[3][4][2][4]; + uint8_t loaded_filter[3][4][2][4]; for (int y = 0; y < 3; ++y) { for (int x = 0; x < 3; ++x) { memcpy(loaded_filter[y][x][0], &filter_block[3 * y * depth + x * depth], @@ -139,16 +139,16 @@ struct ProcessPerDepth::kIntSymmetricZeroPoint; TFLITE_DCHECK_GE(input_offset, -255); TFLITE_DCHECK_LE(input_offset, 0); // For instance, if input_offset == 128, no adjustment is needed. - const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int32_t input_offset_difference = input_offset + kSymmetricZeroPoint; for (int s = 0; s < 2; ++s) { for (int z = 0; z < 4; ++z) { @@ -161,17 +161,17 @@ struct ProcessPerDepthoutput_depth; const int depth_micro_repeats = function_params->depth_micro_repeats; const int bias_increment = function_params->bias_increment; - const int32 input_offset = function_params->input_offset; + const int32_t input_offset = function_params->input_offset; - int8 filter_bank[3][2][4][4]; - int32 adjusted_bias_block[2][4]; + int8_t filter_bank[3][2][4][4]; + int32_t adjusted_bias_block[2][4]; for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { FillFilterBank(depth, filter_data + 8 * j_depth, filter_bank); @@ -191,39 +191,40 @@ struct ProcessPerDepth struct ProcessPerDepth { - static inline void Run(const uint8* filter_data, const int32* bias_data, - int8* shuffled_filter_data, int32* adjusted_bias_data, + static inline void Run(const uint8_t* filter_data, const int32_t* bias_data, + int8_t* shuffled_filter_data, + int32_t* adjusted_bias_data, const DepthwiseConvDotProdParams* function_params) { const int depth = function_params->output_depth; const int depth_micro_repeats = function_params->depth_micro_repeats; const int bias_increment = function_params->bias_increment; // Simulate NEON-register transposition of subset of filter. - int8 filter_bank_a_0[4][4]; // Depth 4, width 4. - int8 filter_bank_a_1[4][4]; - int8 filter_bank_a_2[4][4]; - int8 filter_bank_b_0[4][4]; - int8 filter_bank_b_1[4][4]; - int8 filter_bank_b_2[4][4]; + int8_t filter_bank_a_0[4][4]; // Depth 4, width 4. + int8_t filter_bank_a_1[4][4]; + int8_t filter_bank_a_2[4][4]; + int8_t filter_bank_b_0[4][4]; + int8_t filter_bank_b_1[4][4]; + int8_t filter_bank_b_2[4][4]; // Load filter data in, essentially dropping the [depth/8] dimension, which // is equivalent to loading just the depth needed for one micro-block. // // loaded_filter has dimensions height 3, width 4, sub-block 0 or 1, // depth 4. - uint8 loaded_filter_0[4][2][4]; - uint8 loaded_filter_1[4][2][4]; - uint8 loaded_filter_2[4][2][4]; + uint8_t loaded_filter_0[4][2][4]; + uint8_t loaded_filter_1[4][2][4]; + uint8_t loaded_filter_2[4][2][4]; constexpr int kSymmetricZeroPoint = QuantizationTypeImpl::kIntSymmetricZeroPoint; - const int32 input_offset = function_params->input_offset; + const int32_t input_offset = function_params->input_offset; TFLITE_DCHECK_GE(input_offset, -255); TFLITE_DCHECK_LE(input_offset, 0); - const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int32_t input_offset_difference = input_offset + kSymmetricZeroPoint; for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { - const uint8* filter_block = filter_data + 8 * j_depth; + const uint8_t* filter_block = filter_data + 8 * j_depth; // Filter data is provided as filter_block[3][3][depth/8][2][4]. // height 3, width 3, micro-blocks, sub-block 0 or 1, depth 4. @@ -273,8 +274,8 @@ struct ProcessPerDepth::ExternalType* filter_data, - const int32* bias_data, int8* shuffled_filter_data, - int32* adjusted_bias_data, + const int32_t* bias_data, int8_t* shuffled_filter_data, + int32_t* adjusted_bias_data, const DepthwiseConvDotProdParams* function_params) { const int depth = function_params->output_depth; const int depth_micro_repeats = function_params->depth_micro_repeats; @@ -319,14 +320,14 @@ struct ProcessPerDepth::kIntSymmetricZeroPoint; - constexpr uint8 kSignBit = + constexpr uint8_t kSignBit = QuantizationTypeImpl::kUint8SignBit; - const int32 input_offset = function_params->input_offset; + const int32_t input_offset = function_params->input_offset; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(input_offset, -255); TFLITE_DCHECK_LE(input_offset, 0); } - const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int32_t input_offset_difference = input_offset + kSymmetricZeroPoint; const int8x16_t ones_vector = vdupq_n_s8(1); // Simulate NEON-register transposition of subset of filter. @@ -440,8 +441,8 @@ struct ProcessPerDepth::ExternalType* filter_data, - const int32* bias_data, int8* shuffled_filter_data, - int32* adjusted_bias_data, + const int32_t* bias_data, int8_t* shuffled_filter_data, + int32_t* adjusted_bias_data, const DepthwiseConvDotProdParams* function_params) { ProcessPerDepthIntrinsics(filter_data, bias_data, shuffled_filter_data, adjusted_bias_data, function_params); @@ -449,7 +450,7 @@ struct ProcessPerDepth +template struct PackMacroBlock< DepthwiseConvImplementation::kUseCModel3x3DotProduct, quantization_type, DepthwiseConvDepthMultiplication::kNoMultiplication, max_padding> { @@ -457,11 +458,11 @@ struct PackMacroBlock< // // Requirement: depth_micro_repeats > 0. static inline void CopyMacroBlock( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const DepthwiseConvDotProdParams& function_params, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data) { + int8_t* scratch_block_data) { TFLITE_DCHECK_LE(max_padding, 1); // Strides. @@ -509,13 +510,13 @@ struct PackMacroBlock< constexpr int kSymmetricZeroPoint = QuantizationTypeImpl::kIntSymmetricZeroPoint; - const int32 input_offset_difference = + const int32_t input_offset_difference = function_params.input_offset + kSymmetricZeroPoint; // We load data into a temporary buffer and then save, to match subsequent // processing. This will make it easier to combine stages into one ASM // routine. - int8 tmp_load[4][2][4]; + int8_t tmp_load[4][2][4]; int copy_block_height = block_height; if (leading_height_padding) { @@ -552,7 +553,7 @@ struct PackMacroBlock< // each micro block. // Load, and apply symmetric offset. - int8* scratch_data = + int8_t* scratch_data = scratch_block_data + k_height * workspace_height_stride + j_width * 4 * 8 + i_depth * 4 * 8 * width_overall_micro_repeats; const typename QuantizationTypeImpl::ExternalType* @@ -589,7 +590,7 @@ struct PackMacroBlock< // equivalence of the two approaches. static inline void MicroTransposeBlocks( const DepthwiseConvDotProdParams& function_params, - int8* scratch_block_data) { + int8_t* scratch_block_data) { const int workspace_height_stride = function_params.workspace_height_stride; const int width_overall_micro_repeats = function_params.input_width_overall_micro_repeats; @@ -598,15 +599,15 @@ struct PackMacroBlock< // Transpositions are 4x4, but doing 2 at a time is more efficient in the // NEON code we are simulating. - int8 tmp_load[4][2][4]; // [width][sub-block][depth] - int8 tmp_transposed[4][2][4]; // [depth][sub-block][width] - int8 tmp_interleaved[2][4][4]; // [sub-block][depth][width] + int8_t tmp_load[4][2][4]; // [width][sub-block][depth] + int8_t tmp_transposed[4][2][4]; // [depth][sub-block][width] + int8_t tmp_interleaved[2][4][4]; // [sub-block][depth][width] // The outer 3 loops go through all the micro blocks in a macro block. for (int k_height = 0; k_height < block_height; ++k_height) { for (int j_width = 0; j_width < width_overall_micro_repeats; ++j_width) { for (int i_depth = 0; i_depth < depth_micro_repeats; ++i_depth) { - int8* scratch_data = + int8_t* scratch_data = scratch_block_data + k_height * workspace_height_stride + j_width * 4 * 8 + i_depth * 4 * 8 * width_overall_micro_repeats; // A. Load data @@ -639,10 +640,10 @@ struct PackMacroBlock< } static inline void Run( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { CopyMacroBlock(height_block_number, width_block_number, *function_params, input_block_data, scratch_block_data); @@ -650,15 +651,15 @@ struct PackMacroBlock< } }; -template +template struct PackMacroBlock< DepthwiseConvImplementation::kUseCModel3x3DotProduct, quantization_type, DepthwiseConvDepthMultiplication::kUnitInputDepth, max_padding> { static inline void Run( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { // Currently support for padding is limited to 1 on any side. TFLITE_DCHECK_LE(max_padding, 1); @@ -699,7 +700,7 @@ struct PackMacroBlock< constexpr int kSymmetricZeroPoint = QuantizationTypeImpl::kIntSymmetricZeroPoint; - const int32 input_offset_difference = + const int32_t input_offset_difference = function_params->input_offset + kSymmetricZeroPoint; int copy_block_height = block_height; @@ -741,7 +742,7 @@ struct PackMacroBlock< for (int k_height = 0; k_height < copy_block_height; ++k_height) { const typename QuantizationTypeImpl::ExternalType* input_data = input_block_data + k_height * input_height_stride; - int8* scratch_data = + int8_t* scratch_data = scratch_block_data + k_height * workspace_height_stride; // Handle leading padding. This is overwritten if there is no padding. @@ -775,10 +776,10 @@ struct PackMacroBlock { static inline void Run( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { const int workspace_height_stride = function_params->workspace_height_stride; @@ -807,12 +808,12 @@ struct PackMacroBlock::ExternalType* input_data = input_block_data; @@ -930,10 +931,10 @@ struct PackMacroBlock { static inline void Run( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { // Just use C model code for case of padding. Optimized versions merge the // modifications therein to handle padding. @@ -946,15 +947,15 @@ struct PackMacroBlock +template struct PackMacroBlock< DepthwiseConvImplementation::kUseUnwound3x3DotProduct, quantization_type, DepthwiseConvDepthMultiplication::kUnitInputDepth, max_padding> { static inline void Run( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { const int workspace_height_stride = function_params->workspace_height_stride; @@ -987,11 +988,11 @@ struct PackMacroBlock< padding_bottom > 0 && height_block_number == (function_params->height_macro_count - 1); - const int32 input_offset = function_params->input_offset; - const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int32_t input_offset = function_params->input_offset; + const int32_t input_offset_difference = input_offset + kSymmetricZeroPoint; // Work through one slice, by row, at a time. - int8* scratch_data_base = scratch_block_data; + int8_t* scratch_data_base = scratch_block_data; int copy_block_height = block_height; if (leading_height_padding) { @@ -1031,7 +1032,7 @@ struct PackMacroBlock< TFLITE_DCHECK_GE(copy_size, input_height_stride - 1); // This is used to simulate what should happen in registers. - int8 tmp_data[16]; + int8_t tmp_data[16]; int scratch_data_offset = 0; int input_block_offset = 0; @@ -1039,7 +1040,7 @@ struct PackMacroBlock< if (copy_size >= 16) { for (int k_height = 0; k_height < copy_block_height; ++k_height) { // Work through one slice, by row, at a time. - int8* scratch_data = scratch_data_base + scratch_data_offset; + int8_t* scratch_data = scratch_data_base + scratch_data_offset; int copy_done = 0; @@ -1109,7 +1110,7 @@ struct PackMacroBlock< } else if (copy_size >= 4) { for (int k_height = 0; k_height < copy_block_height; ++k_height) { // Work through one slice, by row, at a time. - int8* scratch_data = scratch_data_base + scratch_data_offset; + int8_t* scratch_data = scratch_data_base + scratch_data_offset; int copy_done = 0; @@ -1269,7 +1270,7 @@ struct PackMacroBlock::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { TFLITE_DCHECK_EQ(function_params->padding_bottom, 0); TFLITE_DCHECK_EQ(function_params->padding_top, 0); @@ -1288,7 +1289,7 @@ struct PackMacroBlockinput_depth; TFLITE_DCHECK_GE(depth_micro_repeats, 0); - constexpr uint8 kSignBit = + constexpr uint8_t kSignBit = QuantizationTypeImpl::kUint8SignBit; const int micro_block_size = 4 * 8; const int depth_advance = width_overall_micro_repeats * micro_block_size; @@ -1308,7 +1309,7 @@ struct PackMacroBlock::ExternalType* @@ -1470,10 +1471,10 @@ struct PackMacroBlock::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { #ifdef __aarch64__ PreloadInputBlock(input_block_data, function_params); @@ -1489,12 +1490,12 @@ struct PackMacroBlock { static inline void PackMacroBlockIntrinsics( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { - constexpr uint8 kSignBit = + constexpr uint8_t kSignBit = QuantizationTypeImpl::kUint8SignBit; const int workspace_height_stride = @@ -1538,8 +1539,8 @@ struct PackMacroBlock 0 && height_block_number == (function_params->height_macro_count - 1); - const int32 input_offset = function_params->input_offset; - const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int32_t input_offset = function_params->input_offset; + const int32_t input_offset_difference = input_offset + kSymmetricZeroPoint; // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON // code. Note the blocks of 4x4 are still interleaved down the depth. @@ -1550,7 +1551,7 @@ struct PackMacroBlock::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { #ifdef __aarch64__ PreloadInputBlock(input_block_data, function_params); @@ -1905,10 +1906,10 @@ struct PackMacroBlock { static inline void PackMacroBlockIntrinsics( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { const int workspace_height_stride = function_params->workspace_height_stride; @@ -1941,11 +1942,11 @@ struct PackMacroBlock 0 && height_block_number == (function_params->height_macro_count - 1); - const int32 input_offset = function_params->input_offset; - const int32 input_offset_difference = input_offset + kSymmetricZeroPoint; + const int32_t input_offset = function_params->input_offset; + const int32_t input_offset_difference = input_offset + kSymmetricZeroPoint; // Work through one slice, by row, at a time. - int8* scratch_data_base = scratch_block_data; + int8_t* scratch_data_base = scratch_block_data; int copy_block_height = block_height; if (leading_height_padding) { @@ -1987,7 +1988,7 @@ struct PackMacroBlock::kUint8SignBit; // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON @@ -2010,7 +2011,7 @@ struct PackMacroBlock( + half_work_reg = vld1_lane_s8(reinterpret_cast( input_block_data + input_block_offset), half_work_reg, 1); half_work_reg = - vld1_lane_s8(reinterpret_cast(input_block_data + - input_block_offset + 1), + vld1_lane_s8(reinterpret_cast( + input_block_data + input_block_offset + 1), half_work_reg, 2); half_work_reg = - vld1_lane_s8(reinterpret_cast(input_block_data + - input_block_offset + 2), + vld1_lane_s8(reinterpret_cast( + input_block_data + input_block_offset + 2), half_work_reg, 3); if (quantization_type == QuantizationType::kNonPerChannelUint8) { @@ -2222,7 +2223,7 @@ struct PackMacroBlock( + reinterpret_cast( input_block_data + input_block_offset + copy_size - 1 - i), half_work_reg, 0); } @@ -2269,10 +2270,10 @@ struct PackMacroBlock::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { #ifdef __aarch64__ PreloadInputBlock(input_block_data, function_params); @@ -2290,10 +2291,10 @@ struct PackMacroBlock { static inline void PackMacroBlockIntrinsics( - int32 height_block_number, int32 width_block_number, + int32_t height_block_number, int32_t width_block_number, const typename QuantizationTypeImpl::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { const int workspace_height_stride = function_params->workspace_height_stride; @@ -2313,7 +2314,7 @@ struct PackMacroBlock::kUint8SignBit; // Transpositions are 4x4, but doing 2 at a time is more efficient in NEON @@ -2350,7 +2351,7 @@ struct PackMacroBlock( + reinterpret_cast( input_block_data + input_block_offset + copy_size - 1 - i), half_work_reg, 0); } @@ -2509,10 +2510,10 @@ struct PackMacroBlock::ExternalType* input_block_data, - int8* scratch_block_data, + int8_t* scratch_block_data, const DepthwiseConvDotProdParams* function_params) { #ifdef __aarch64__ PreloadInputBlock(input_block_data, function_params); @@ -2529,7 +2530,7 @@ struct PackMacroBlock 0 || residual_depth > 0. -template +template struct KernelMacroBlock< DepthwiseConvImplementation::kUseCModel3x3DotProduct, quantization_type, DepthwiseConvDepthMultiplication::kNoMultiplication, stride> { @@ -2556,25 +2557,25 @@ struct KernelMacroBlock< int workspace_height_stride, int width_micro_stride, bool no_right_block, - const int8* input_block, - int8 selected_data[3][4][4]) { + const int8_t* input_block, + int8_t selected_data[3][4][4]) { TFLITE_DCHECK_GE(offset, 0); TFLITE_DCHECK_LT(offset, 4); // The input banks have same format as selected_data. - int8 left_bank[3][4][4]; - int8 right_bank[3][4][4]; + int8_t left_bank[3][4][4]; + int8_t right_bank[3][4][4]; // Work through one slice, by row, at a time. for (int k_height = 0; k_height < 3; ++k_height) { // Simulate demangling of mangled storage arrangement. - const int8* left_input_block = + const int8_t* left_input_block = &input_block[k_height * workspace_height_stride + sub_block * 2 * 8]; memcpy(left_bank[k_height][0], left_input_block, 16); if (no_right_block) { memset(right_bank[k_height][0], 0, 16); } else { - const int8* right_input_block = + const int8_t* right_input_block = &input_block[k_height * workspace_height_stride + sub_block * 2 * 8 + width_micro_stride]; memcpy(right_bank[k_height][0], right_input_block, 16); @@ -2591,19 +2592,19 @@ struct KernelMacroBlock< // Straight implementation of 3x3 filter within sub-micro block. static inline void Calculate3x3FilterOutput( const DepthwiseConvDotProdParams& params, int sub_block, - const int8 selected_data[3][4][4], const int8 filter_bank[3][2][4][4], - const int32* bias_data, uint8 output_values[4]) { - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; - const int32 output_multiplier = params.output_multiplier; - const int32 output_shift = params.output_shift; - const int32 output_offset = params.output_offset; + const int8_t selected_data[3][4][4], const int8_t filter_bank[3][2][4][4], + const int32_t* bias_data, uint8_t output_values[4]) { + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + const int32_t output_multiplier = params.output_multiplier; + const int32_t output_shift = params.output_shift; + const int32_t output_offset = params.output_offset; for (int d = 0; d < 4; ++d) { - int32 acc = 0; + int32_t acc = 0; for (int y = 0; y < 3; ++y) { for (int x = 0; x < 4; ++x) { - int32 input_val = selected_data[y][d][x]; - int32 filter_val = filter_bank[y][sub_block][d][x]; + int32_t input_val = selected_data[y][d][x]; + int32_t filter_val = filter_bank[y][sub_block][d][x]; acc += filter_val * input_val; } } @@ -2614,13 +2615,13 @@ struct KernelMacroBlock< acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - output_values[d] = static_cast(acc); + output_values[d] = static_cast(acc); } } - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - uint8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { const int workspace_height_stride = function_params->workspace_height_stride; @@ -2649,10 +2650,10 @@ struct KernelMacroBlock< constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; // Simulate NEON-register transposition of subset of filter. - int8 filter_bank[3][2][4][4]; // Height 3, sub-block, depth 4, width 4. + int8_t filter_bank[3][2][4][4]; // Height 3, sub-block, depth 4, width 4. // Simulate NEON-register input data concatenation + sub-selection. - int8 sub_selected_input_data[3][4][4]; // Height 3, depth 4, width 4. - uint8 output_values[4]; // Depth 4. + int8_t sub_selected_input_data[3][4][4]; // Height 3, depth 4, width 4. + uint8_t output_values[4]; // Depth 4. // The outer 3 loops go through all the micro blocks in a macro block, and // separately treat the two sub-blocks within each micro block. @@ -2663,11 +2664,11 @@ struct KernelMacroBlock< for (int s = 0; s < 2; ++s) { for (int k_height = 0; k_height < block_height; ++k_height) { - const int8* scratch_data = + const int8_t* scratch_data = scratch_block_data + workspace_height_stride * k_height * stride_val + depth_micro_stride * j_depth; - uint8* output_data = + uint8_t* output_data = output_block_data + output_height_stride * k_height + 8 * j_depth; for (int i_width = 0; i_width < output_width_overall_micro_repeats; @@ -2677,7 +2678,7 @@ struct KernelMacroBlock< : four_over_stride; const bool no_right_block = (output_width - 1) * stride_val < 2; TFLITE_DCHECK_LE(output_width * stride_val, 4); - const int8* input_data = + const int8_t* input_data = scratch_data + width_micro_stride * i_width; // Iterate over input width shifts within sub-micro blocks. for (int x = 0; x < output_width; ++x) { @@ -2706,7 +2707,7 @@ struct KernelMacroBlock< // Parameters for repeats and residual sizes are in terms of outputs. // // Requirement: depth_micro_repeats > 0 || residual_depth > 0. -template +template struct KernelMacroBlock< DepthwiseConvImplementation::kUseCModel3x3DotProduct, quantization_type, DepthwiseConvDepthMultiplication::kUnitInputDepth, stride> { @@ -2732,8 +2733,8 @@ struct KernelMacroBlock< static inline void ConcatenateInputSubBlocks(int offset, int workspace_height_stride, bool no_right_block, - const int8* input_block, - int8 selected_data[3][4]) { + const int8_t* input_block, + int8_t selected_data[3][4]) { TFLITE_DCHECK_GE(offset, 0); TFLITE_DCHECK_LT(offset, 4); if (no_right_block) { @@ -2753,21 +2754,21 @@ struct KernelMacroBlock< // Straight implementation of 3x3 filter within sub-micro block. static inline void Calculate3x3FilterOutput( const DepthwiseConvDotProdParams& function_params, int sub_block, - const int8 selected_data[3][4], const int8 filter_bank[3][2][4][4], - const int32* bias_data, uint8 output_values[4]) { - const int32 output_activation_min = + const int8_t selected_data[3][4], const int8_t filter_bank[3][2][4][4], + const int32_t* bias_data, uint8_t output_values[4]) { + const int32_t output_activation_min = function_params.quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params.quantized_activation_max; - const int32 output_multiplier = function_params.output_multiplier; - const int32 output_shift = function_params.output_shift; - const int32 output_offset = function_params.output_offset; + const int32_t output_multiplier = function_params.output_multiplier; + const int32_t output_shift = function_params.output_shift; + const int32_t output_offset = function_params.output_offset; for (int d = 0; d < 4; ++d) { - int32 acc = 0; + int32_t acc = 0; for (int y = 0; y < 3; ++y) { for (int x = 0; x < 4; ++x) { - int32 input_val = selected_data[y][x]; - int32 filter_val = filter_bank[y][sub_block][d][x]; + int32_t input_val = selected_data[y][x]; + int32_t filter_val = filter_bank[y][sub_block][d][x]; acc += filter_val * input_val; } } @@ -2778,13 +2779,13 @@ struct KernelMacroBlock< acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - output_values[d] = static_cast(acc); + output_values[d] = static_cast(acc); } } - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - uint8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { const int workspace_height_stride = function_params->workspace_height_stride; @@ -2810,10 +2811,10 @@ struct KernelMacroBlock< constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; // Simulate NEON-register transposition of subset of filter. - int8 filter_bank[3][2][4][4]; // Height 3, sub-block, depth 4, width 4. + int8_t filter_bank[3][2][4][4]; // Height 3, sub-block, depth 4, width 4. // Simulate NEON-register input data concatenation + sub-selection. - int8 sub_selected_input_data[3][4]; // Height 3, depth 4, width 4. - uint8 output_values[4]; // Depth 4. + int8_t sub_selected_input_data[3][4]; // Height 3, depth 4, width 4. + uint8_t output_values[4]; // Depth 4. // The outer 3 loops go through all the micro blocks in a macro block, and // separately treat the two sub-blocks within each micro block. @@ -2824,10 +2825,10 @@ struct KernelMacroBlock< for (int s = 0; s < 2; ++s) { for (int k_height = 0; k_height < block_height; ++k_height) { - const int8* scratch_data = + const int8_t* scratch_data = scratch_block_data + workspace_height_stride * k_height * stride_val; - uint8* output_data = + uint8_t* output_data = output_block_data + output_height_stride * k_height + 8 * j_depth; for (int i_width = 0; i_width < output_width_overall_micro_repeats; @@ -2839,7 +2840,7 @@ struct KernelMacroBlock< output_width_overall_micro_repeats == workspace_width_micro_repeats; TFLITE_DCHECK_LE(output_width * stride_val, 4); - const int8* input_data = scratch_data + 4 * i_width; + const int8_t* input_data = scratch_data + 4 * i_width; // Iterate over input width shifts within 4x4 blocks. for (int x = 0; x < output_width; ++x) { ConcatenateInputSubBlocks(x * stride_val, workspace_height_stride, @@ -2865,13 +2866,13 @@ struct KernelMacroBlock< // // This section is only compiled when kUseUnwound3x3DotProduct versions of // templated functions are selected. -template +template struct KernelMacroBlock< DepthwiseConvImplementation::kUseUnwound3x3DotProduct, quantization_type, DepthwiseConvDepthMultiplication::kNoMultiplication, stride> { - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - uint8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { const int workspace_height_stride = function_params->workspace_height_stride; @@ -2896,31 +2897,31 @@ struct KernelMacroBlock< const int depth_micro_stride = width_micro_stride * input_width_overall_micro_repeats; - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_multiplier = function_params->output_multiplier; - const int32 output_shift = function_params->output_shift; - const int32 output_offset = function_params->output_offset; + const int32_t output_multiplier = function_params->output_multiplier; + const int32_t output_shift = function_params->output_shift; + const int32_t output_offset = function_params->output_offset; // Simulate NEON-register transposition of subset of filter. - int8 filter_bank_a_0[4][4]; // Depth 4, width 4. - int8 filter_bank_a_1[4][4]; - int8 filter_bank_a_2[4][4]; - int8 filter_bank_b_0[4][4]; - int8 filter_bank_b_1[4][4]; - int8 filter_bank_b_2[4][4]; + int8_t filter_bank_a_0[4][4]; // Depth 4, width 4. + int8_t filter_bank_a_1[4][4]; + int8_t filter_bank_a_2[4][4]; + int8_t filter_bank_b_0[4][4]; + int8_t filter_bank_b_1[4][4]; + int8_t filter_bank_b_2[4][4]; // Simulate NEON-register input data concatenation + sub-selection. // Also sub-block, height 3, depth 4, width 4. - uint8 output_values[4]; // Sub-block, depth 4. + uint8_t output_values[4]; // Sub-block, depth 4. // selected_data has format Depth 4, width 4. - int8 left_bank_0[4][4]; - int8 left_bank_1[4][4]; - int8 left_bank_2[4][4]; - int8 right_bank_0[4][4]; - int8 right_bank_1[4][4]; - int8 right_bank_2[4][4]; + int8_t left_bank_0[4][4]; + int8_t left_bank_1[4][4]; + int8_t left_bank_2[4][4]; + int8_t right_bank_0[4][4]; + int8_t right_bank_1[4][4]; + int8_t right_bank_2[4][4]; memset(right_bank_0[0], 0, 16); memset(right_bank_1[0], 0, 16); memset(right_bank_2[0], 0, 16); @@ -2928,7 +2929,7 @@ struct KernelMacroBlock< constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { - const int8* filter_block = + const int8_t* filter_block = filter_workspace + shuffled_filter_increment * j_depth; memcpy(filter_bank_a_0, filter_block, 16); @@ -2941,13 +2942,13 @@ struct KernelMacroBlock< for (int s = 0; s < 2; ++s) { // Work through one slice, by row, at a time. for (int k_height = 0; k_height < block_height; ++k_height) { - const int8* scratch_data = + const int8_t* scratch_data = scratch_block_data + workspace_height_stride * k_height * stride_val + depth_micro_stride * j_depth; - uint8* output_data = + uint8_t* output_data = output_block_data + output_height_stride * k_height + 8 * j_depth; - const int8* input_data_0 = scratch_data + s * 2 * 8; + const int8_t* input_data_0 = scratch_data + s * 2 * 8; // Load first sub-micro block of data into operational banks. memcpy(left_bank_0[0], input_data_0, 16); @@ -2961,7 +2962,7 @@ struct KernelMacroBlock< ? residual_width : four_over_stride; TFLITE_DCHECK_LE(output_width * stride_val, 4); - const int8* input_data = + const int8_t* input_data = input_data_0 + width_micro_stride * i_width; const bool no_right_block = (output_width - 1) * stride_val < 2; @@ -2981,20 +2982,20 @@ struct KernelMacroBlock< for (int x = 0; x < output_width; ++x) { // Operate on depth of 4 in batches. for (int d = 0; d < 4; ++d) { - int32 acc = 0; + int32_t acc = 0; for (int x = 0; x < 4; ++x) { - int32 input_val = left_bank_0[d][x]; - int32 filter_val = filter_bank_a_0[d][x]; + int32_t input_val = left_bank_0[d][x]; + int32_t filter_val = filter_bank_a_0[d][x]; acc += filter_val * input_val; } for (int x = 0; x < 4; ++x) { - int32 input_val = left_bank_1[d][x]; - int32 filter_val = filter_bank_a_1[d][x]; + int32_t input_val = left_bank_1[d][x]; + int32_t filter_val = filter_bank_a_1[d][x]; acc += filter_val * input_val; } for (int x = 0; x < 4; ++x) { - int32 input_val = left_bank_2[d][x]; - int32 filter_val = filter_bank_a_2[d][x]; + int32_t input_val = left_bank_2[d][x]; + int32_t filter_val = filter_bank_a_2[d][x]; acc += filter_val * input_val; } acc += bias_data[d]; @@ -3004,7 +3005,7 @@ struct KernelMacroBlock< acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - output_values[d] = static_cast(acc); + output_values[d] = static_cast(acc); } for (int d = 0; d < 4; ++d) { @@ -3079,13 +3080,13 @@ struct KernelMacroBlock< } }; -template +template struct KernelMacroBlock< DepthwiseConvImplementation::kUseUnwound3x3DotProduct, quantization_type, DepthwiseConvDepthMultiplication::kUnitInputDepth, stride> { - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - uint8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { const int workspace_height_stride = function_params->workspace_height_stride; @@ -3103,13 +3104,13 @@ struct KernelMacroBlock< const int output_height_stride = function_params->output_height_stride; const int bias_increment = function_params->bias_increment; - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_multiplier = function_params->output_multiplier; - const int32 output_shift = function_params->output_shift; - const int32 output_offset = function_params->output_offset; + const int32_t output_multiplier = function_params->output_multiplier; + const int32_t output_shift = function_params->output_shift; + const int32_t output_offset = function_params->output_offset; TFLITE_DCHECK(depth_micro_repeats > 0); @@ -3118,22 +3119,22 @@ struct KernelMacroBlock< constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; // Simulate NEON-register transposition of subset of filter. - int8 filter_bank_a_0[4][4]; // Depth 4, width 4. - int8 filter_bank_a_1[4][4]; - int8 filter_bank_a_2[4][4]; - int8 filter_bank_b_0[4][4]; - int8 filter_bank_b_1[4][4]; - int8 filter_bank_b_2[4][4]; + int8_t filter_bank_a_0[4][4]; // Depth 4, width 4. + int8_t filter_bank_a_1[4][4]; + int8_t filter_bank_a_2[4][4]; + int8_t filter_bank_b_0[4][4]; + int8_t filter_bank_b_1[4][4]; + int8_t filter_bank_b_2[4][4]; // Simulate NEON-register input data concatenation + sub-selection. // Also sub-block, height 3, depth 4, width 4. - int8 input_bank_0[8]; - int8 input_bank_1[8]; - int8 input_bank_2[8]; + int8_t input_bank_0[8]; + int8_t input_bank_1[8]; + int8_t input_bank_2[8]; TFLITE_DCHECK_GE(depth_micro_repeats, 1); - uint8 output_values[2][4]; // Sub-block, depth 4. + uint8_t output_values[2][4]; // Sub-block, depth 4. for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { memcpy(filter_bank_a_0, filter_workspace, 16); @@ -3145,10 +3146,10 @@ struct KernelMacroBlock< // Work through one slice, by row, at a time. for (int k_height = 0; k_height < block_height; ++k_height) { - const int8* scratch_data = + const int8_t* scratch_data = scratch_block_data + workspace_height_stride * k_height * stride_val; - uint8* output_data = + uint8_t* output_data = output_block_data + output_height_stride * k_height + 8 * j_depth; memcpy(input_bank_0, scratch_data, 4); @@ -3162,7 +3163,7 @@ struct KernelMacroBlock< : four_over_stride; TFLITE_DCHECK_LE(output_width * stride_val, 4); - const int8* input_data = scratch_data + 4 * i_width; + const int8_t* input_data = scratch_data + 4 * i_width; memcpy(input_bank_0 + 4, input_data + 4, 4); memcpy(input_bank_1 + 4, input_data + workspace_height_stride + 4, 4); @@ -3177,16 +3178,16 @@ struct KernelMacroBlock< { const int s = 0; for (int d = 0; d < 4; ++d) { - int32 acc = bias_data[s * 4 + d]; + int32_t acc = bias_data[s * 4 + d]; for (int x = 0; x < 4; ++x) { - int32 input_val_0 = input_bank_0[offset + x]; - int32 filter_val_0 = filter_bank_a_0[d][x]; + int32_t input_val_0 = input_bank_0[offset + x]; + int32_t filter_val_0 = filter_bank_a_0[d][x]; acc += filter_val_0 * input_val_0; - int32 input_val_1 = input_bank_1[offset + x]; - int32 filter_val_1 = filter_bank_a_1[d][x]; + int32_t input_val_1 = input_bank_1[offset + x]; + int32_t filter_val_1 = filter_bank_a_1[d][x]; acc += filter_val_1 * input_val_1; - int32 input_val_2 = input_bank_2[offset + x]; - int32 filter_val_2 = filter_bank_a_2[d][x]; + int32_t input_val_2 = input_bank_2[offset + x]; + int32_t filter_val_2 = filter_bank_a_2[d][x]; acc += filter_val_2 * input_val_2; } acc = reference_ops::depthwise_conv::DepthwiseConvRound< @@ -3195,7 +3196,7 @@ struct KernelMacroBlock< acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - output_values[s][d] = static_cast(acc); + output_values[s][d] = static_cast(acc); output_data[s * 4 + d] = output_values[s][d]; } @@ -3203,16 +3204,16 @@ struct KernelMacroBlock< { const int s = 1; for (int d = 0; d < 4; ++d) { - int32 acc = bias_data[s * 4 + d]; + int32_t acc = bias_data[s * 4 + d]; for (int x = 0; x < 4; ++x) { - int32 input_val_0 = input_bank_0[offset + x]; - int32 filter_val_0 = filter_bank_b_0[d][x]; + int32_t input_val_0 = input_bank_0[offset + x]; + int32_t filter_val_0 = filter_bank_b_0[d][x]; acc += filter_val_0 * input_val_0; - int32 input_val_1 = input_bank_1[offset + x]; - int32 filter_val_1 = filter_bank_b_1[d][x]; + int32_t input_val_1 = input_bank_1[offset + x]; + int32_t filter_val_1 = filter_bank_b_1[d][x]; acc += filter_val_1 * input_val_1; - int32 input_val_2 = input_bank_2[offset + x]; - int32 filter_val_2 = filter_bank_b_2[d][x]; + int32_t input_val_2 = input_bank_2[offset + x]; + int32_t filter_val_2 = filter_bank_b_2[d][x]; acc += filter_val_2 * input_val_2; } acc = reference_ops::depthwise_conv::DepthwiseConvRound< @@ -3221,7 +3222,7 @@ struct KernelMacroBlock< acc += output_offset; acc = std::max(acc, output_activation_min); acc = std::min(acc, output_activation_max); - output_values[s][d] = static_cast(acc); + output_values[s][d] = static_cast(acc); output_data[s * 4 + d] = output_values[s][d]; } @@ -3270,8 +3271,8 @@ struct KernelMacroBlock< } static inline void KernelMacroBlockIntrinsics( - const int8* scratch_block_data, const int8* filter_workspace, - const int32* bias_data, uint8* output_block_data, + const int8_t* scratch_block_data, const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { static constexpr QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8; @@ -3297,13 +3298,13 @@ struct KernelMacroBlock< const int depth_micro_stride = width_micro_stride * input_width_overall_micro_repeats; - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_multiplier = function_params->output_multiplier; - const int32 output_shift = function_params->output_shift; - const int32 output_offset = function_params->output_offset; + const int32_t output_multiplier = function_params->output_multiplier; + const int32_t output_shift = function_params->output_shift; + const int32_t output_offset = function_params->output_offset; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(output_activation_min, 0); TFLITE_DCHECK_LT(output_activation_min, 256); @@ -3319,13 +3320,13 @@ struct KernelMacroBlock< TFLITE_DCHECK_LT(output_offset, 32768); const int16x8_t output_offset_vec = - vdupq_n_s16(static_cast(output_offset)); + vdupq_n_s16(static_cast(output_offset)); const uint8x16_t output_activation_min_vec = - vdupq_n_u8(static_cast(output_activation_min)); + vdupq_n_u8(static_cast(output_activation_min)); const uint8x16_t output_activation_max_vec = - vdupq_n_u8(static_cast(output_activation_max)); + vdupq_n_u8(static_cast(output_activation_max)); - const int8* input_data_depthwise = scratch_block_data; + const int8_t* input_data_depthwise = scratch_block_data; typename QuantizationTypeImpl::ExternalType* output_data_depthwise = output_block_data; for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { @@ -3363,11 +3364,11 @@ struct KernelMacroBlock< if (block_height == 4) { for (int s = 0; s < 2; ++s) { // Work through one slice, by row, at a time. - const int8* input_data_base = input_data_depthwise + 2 * 8 * s; + const int8_t* input_data_base = input_data_depthwise + 2 * 8 * s; typename QuantizationTypeImpl::ExternalType* output_data_base = output_data_depthwise + 4 * s; - const int8* next_input_data = input_data_base; + const int8_t* next_input_data = input_data_base; typename QuantizationTypeImpl::ExternalType* output_data = output_data_base; @@ -3796,7 +3797,7 @@ struct KernelMacroBlock< vshlq_n_u32(vreinterpretq_u32_s8(filter_reg_2_a), 8)); } } else { - const int8* input_data_base = input_data_depthwise; + const int8_t* input_data_base = input_data_depthwise; typename QuantizationTypeImpl::ExternalType* output_data_base = output_data_depthwise; @@ -3806,7 +3807,7 @@ struct KernelMacroBlock< bias_data += kBiasIncrement; for (int k_height = 0; k_height < block_height; ++k_height) { - const int8* next_input_data = input_data_base; + const int8_t* next_input_data = input_data_base; typename QuantizationTypeImpl::ExternalType* output_data = output_data_base; @@ -3909,9 +3910,9 @@ struct KernelMacroBlock< } } // NOLINT(readability/fn_size) Manually unrolled. - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - uint8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, output_block_data, function_params); @@ -3933,8 +3934,8 @@ struct KernelMacroBlock< } static inline void KernelMacroBlockIntrinsics( - const int8* scratch_block_data, const int8* filter_workspace, - const int32* bias_data, uint8* output_block_data, + const int8_t* scratch_block_data, const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { static constexpr QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8; @@ -3966,13 +3967,13 @@ struct KernelMacroBlock< const int depth_micro_stride = width_micro_stride * input_width_overall_micro_repeats; - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_multiplier = function_params->output_multiplier; - const int32 output_shift = function_params->output_shift; - const int32 output_offset = function_params->output_offset; + const int32_t output_multiplier = function_params->output_multiplier; + const int32_t output_shift = function_params->output_shift; + const int32_t output_offset = function_params->output_offset; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(output_activation_min, 0); TFLITE_DCHECK_LT(output_activation_min, 256); @@ -3989,18 +3990,18 @@ struct KernelMacroBlock< // This version only does min/max on 64 bits. const int16x8_t output_offset_vec = - vdupq_n_s16(static_cast(output_offset)); + vdupq_n_s16(static_cast(output_offset)); const uint8x8_t output_activation_min_vec = - vdup_n_u8(static_cast(output_activation_min)); + vdup_n_u8(static_cast(output_activation_min)); const uint8x8_t output_activation_max_vec = - vdup_n_u8(static_cast(output_activation_max)); + vdup_n_u8(static_cast(output_activation_max)); constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; TFLITE_DCHECK_LE(block_height, 2); for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { - const int8* filter_block = + const int8_t* filter_block = filter_workspace + shuffled_filter_increment * j_depth; if (block_height == 2) { @@ -4014,11 +4015,11 @@ struct KernelMacroBlock< filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32); filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64); - const int8* scratch_data = + const int8_t* scratch_data = scratch_block_data + depth_micro_stride * j_depth; typename QuantizationTypeImpl::ExternalType* output_data = output_block_data + 8 * j_depth; - const int8* input_data_0 = scratch_data + s * 2 * 8; + const int8_t* input_data_0 = scratch_data + s * 2 * 8; const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); @@ -4059,7 +4060,7 @@ struct KernelMacroBlock< for (; i_width < adjusted_width_micro_repeats; ++i_width) { const int output_width = kFourOverStride; TFLITE_DCHECK_LE(output_width * kStrideVal, 4); - const int8* input_data = + const int8_t* input_data = input_data_0 + width_micro_stride * i_width; acc0 = adjusted_bias_data; acc1 = adjusted_bias_data; @@ -4226,11 +4227,11 @@ struct KernelMacroBlock< filter_reg_1_b = vld1q_s8(filter_block + 16 + 32); filter_reg_2_b = vld1q_s8(filter_block + 16 + 64); - const int8* scratch_data = + const int8_t* scratch_data = scratch_block_data + depth_micro_stride * j_depth; typename QuantizationTypeImpl::ExternalType* output_data = output_block_data + 8 * j_depth; - const int8* input_data_0 = scratch_data; + const int8_t* input_data_0 = scratch_data; const int32x4_t adjusted_bias_data_a = vld1q_s32(bias_data); bias_data += kBiasIncrement; @@ -4265,7 +4266,8 @@ struct KernelMacroBlock< ? residual_width : kFourOverStride; TFLITE_DCHECK_LE(output_width * kStrideVal, 4); - const int8* input_data = input_data_0 + width_micro_stride * i_width; + const int8_t* input_data = + input_data_0 + width_micro_stride * i_width; const bool no_right_block = i_width == output_width_micro_repeats && output_width_overall_micro_repeats == workspace_width_micro_repeats; @@ -4378,9 +4380,9 @@ struct KernelMacroBlock< } } // NOLINT(readability/fn_size) Manually unrolled. - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - uint8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, output_block_data, function_params); @@ -4408,8 +4410,8 @@ struct KernelMacroBlock< } static inline void KernelMacroBlockIntrinsics( - const int8* scratch_block_data, const int8* filter_workspace, - const int32* bias_data, uint8* output_block_data, + const int8_t* scratch_block_data, const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { static constexpr QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8; @@ -4431,13 +4433,13 @@ struct KernelMacroBlock< TFLITE_DCHECK(depth_micro_repeats > 0); - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_multiplier = function_params->output_multiplier; - const int32 output_shift = function_params->output_shift; - const int32 output_offset = function_params->output_offset; + const int32_t output_multiplier = function_params->output_multiplier; + const int32_t output_shift = function_params->output_shift; + const int32_t output_offset = function_params->output_offset; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(output_activation_min, 0); TFLITE_DCHECK_LT(output_activation_min, 256); @@ -4453,11 +4455,11 @@ struct KernelMacroBlock< TFLITE_DCHECK_LT(output_offset, 32768); const int16x8_t output_offset_vec = - vdupq_n_s16(static_cast(output_offset)); + vdupq_n_s16(static_cast(output_offset)); const uint8x16_t output_activation_min_vec = - vdupq_n_u8(static_cast(output_activation_min)); + vdupq_n_u8(static_cast(output_activation_min)); const uint8x16_t output_activation_max_vec = - vdupq_n_u8(static_cast(output_activation_max)); + vdupq_n_u8(static_cast(output_activation_max)); typename QuantizationTypeImpl::ExternalType* output_data_depthwise = output_block_data; @@ -4508,7 +4510,7 @@ struct KernelMacroBlock< typename QuantizationTypeImpl::ExternalType* output_data_base = output_data_depthwise + 4 * s; - const int8* next_input_data = scratch_block_data; + const int8_t* next_input_data = scratch_block_data; typename QuantizationTypeImpl::ExternalType* output_data = output_data_base; @@ -4989,7 +4991,7 @@ struct KernelMacroBlock< bias_data += kBiasIncrement; for (int k_height = 0; k_height < block_height; ++k_height) { - const int8* next_input_data = + const int8_t* next_input_data = scratch_block_data + k_height * workspace_height_stride; typename QuantizationTypeImpl::ExternalType* output_data = output_data_base; @@ -5074,9 +5076,9 @@ struct KernelMacroBlock< } } // NOLINT(readability/fn_size) Manually unrolled. - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - uint8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, output_block_data, function_params); @@ -5098,8 +5100,8 @@ struct KernelMacroBlock< } static inline void KernelMacroBlockIntrinsics( - const int8* scratch_block_data, const int8* filter_workspace, - const int32* bias_data, uint8* output_block_data, + const int8_t* scratch_block_data, const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { static constexpr QuantizationType quantization_type = QuantizationType::kNonPerChannelUint8; @@ -5120,13 +5122,13 @@ struct KernelMacroBlock< const int output_height_stride = function_params->output_height_stride; constexpr int kBiasIncrement = 4; - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_multiplier = function_params->output_multiplier; - const int32 output_shift = function_params->output_shift; - const int32 output_offset = function_params->output_offset; + const int32_t output_multiplier = function_params->output_multiplier; + const int32_t output_shift = function_params->output_shift; + const int32_t output_offset = function_params->output_offset; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(output_activation_min, 0); TFLITE_DCHECK_LT(output_activation_min, 256); @@ -5144,11 +5146,11 @@ struct KernelMacroBlock< TFLITE_DCHECK_GE(depth_micro_repeats, 1); const int16x8_t output_offset_vec = - vdupq_n_s16(static_cast(output_offset)); + vdupq_n_s16(static_cast(output_offset)); const uint8x16_t output_activation_min_vec = - vdupq_n_u8(static_cast(output_activation_min)); + vdupq_n_u8(static_cast(output_activation_min)); const uint8x16_t output_activation_max_vec = - vdupq_n_u8(static_cast(output_activation_max)); + vdupq_n_u8(static_cast(output_activation_max)); for (int j_depth = 0; j_depth < (depth_micro_repeats * 1 + 0); ++j_depth) { int8x16_t filter_reg_0_a; @@ -5177,7 +5179,7 @@ struct KernelMacroBlock< bias_data += kBiasIncrement; if (block_height == 2) { - const int8* scratch_data = scratch_block_data; + const int8_t* scratch_data = scratch_block_data; typename QuantizationTypeImpl::ExternalType* output_data = output_block_data + 8 * j_depth; @@ -5216,7 +5218,7 @@ struct KernelMacroBlock< int i_width = 0; for (; i_width < adjusted_width_micro_repeats; ++i_width) { - const int8* input_data = scratch_data + 4 + 4 * i_width; + const int8_t* input_data = scratch_data + 4 + 4 * i_width; // Load next sub-micro block of data. input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); @@ -5388,7 +5390,7 @@ struct KernelMacroBlock< } for (; i_width < output_width_overall_micro_repeats; ++i_width) { // output_width == 1. - const int8* input_data = scratch_data + 4 + 4 * i_width; + const int8_t* input_data = scratch_data + 4 + 4 * i_width; // Load next sub-micro block of data. input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); @@ -5491,7 +5493,7 @@ struct KernelMacroBlock< } else { TFLITE_DCHECK_EQ(block_height, 1); // Work through one slice, by row, at a time. - const int8* scratch_data = scratch_block_data; + const int8_t* scratch_data = scratch_block_data; typename QuantizationTypeImpl::ExternalType* output_data = output_block_data + 8 * j_depth; @@ -5520,7 +5522,7 @@ struct KernelMacroBlock< TFLITE_DCHECK_LE(output_width, 2); TFLITE_DCHECK_GE(output_width, 1); TFLITE_DCHECK_LE(output_width * kStrideVal, 4); - const int8* input_data = scratch_data + 4 + 4 * i_width; + const int8_t* input_data = scratch_data + 4 + 4 * i_width; // Load next sub-micro block of data. input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); @@ -5634,9 +5636,9 @@ struct KernelMacroBlock< } } - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - uint8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, uint8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, output_block_data, function_params); @@ -5664,8 +5666,8 @@ struct KernelMacroBlock< } static inline void KernelMacroBlockIntrinsics( - const int8* scratch_block_data, const int8* filter_workspace, - const int32* bias_data, int8* output_block_data, + const int8_t* scratch_block_data, const int8_t* filter_workspace, + const int32_t* bias_data, int8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { static constexpr QuantizationType quantization_type = QuantizationType::kPerChannelInt8; @@ -5691,14 +5693,14 @@ struct KernelMacroBlock< const int depth_micro_stride = width_micro_stride * input_width_overall_micro_repeats; - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_offset = function_params->output_offset; - const int32* output_shift_per_channel = + const int32_t output_offset = function_params->output_offset; + const int32_t* output_shift_per_channel = function_params->output_shift_per_channel; - const int32* output_multiplier_per_channel = + const int32_t* output_multiplier_per_channel = function_params->output_multiplier_per_channel; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(output_activation_min, 0); @@ -5717,13 +5719,13 @@ struct KernelMacroBlock< TFLITE_DCHECK_LT(output_offset, 32768); const int16x8_t output_offset_vec = - vdupq_n_s16(static_cast(output_offset)); + vdupq_n_s16(static_cast(output_offset)); const int8x16_t output_activation_min_vec = - vdupq_n_s8(static_cast(output_activation_min)); + vdupq_n_s8(static_cast(output_activation_min)); const int8x16_t output_activation_max_vec = - vdupq_n_s8(static_cast(output_activation_max)); + vdupq_n_s8(static_cast(output_activation_max)); - const int8* input_data_depthwise = scratch_block_data; + const int8_t* input_data_depthwise = scratch_block_data; typename QuantizationTypeImpl::ExternalType* output_data_depthwise = output_block_data; for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { @@ -5761,11 +5763,11 @@ struct KernelMacroBlock< if (block_height == 4) { for (int s = 0; s < 2; ++s) { // Work through one slice, by row, at a time. - const int8* input_data_base = input_data_depthwise + 2 * 8 * s; + const int8_t* input_data_base = input_data_depthwise + 2 * 8 * s; typename QuantizationTypeImpl::ExternalType* output_data_base = output_data_depthwise + 4 * s; - const int8* next_input_data = input_data_base; + const int8_t* next_input_data = input_data_base; typename QuantizationTypeImpl::ExternalType* output_data = output_data_base; @@ -6199,7 +6201,7 @@ struct KernelMacroBlock< vshlq_n_u32(vreinterpretq_u32_s8(filter_reg_2_a), 8)); } } else { - const int8* input_data_base = input_data_depthwise; + const int8_t* input_data_base = input_data_depthwise; typename QuantizationTypeImpl::ExternalType* output_data_base = output_data_depthwise; @@ -6218,7 +6220,7 @@ struct KernelMacroBlock< vld1q_s32(output_multiplier_per_channel + j_depth * 8 + 4); for (int k_height = 0; k_height < block_height; ++k_height) { - const int8* next_input_data = input_data_base; + const int8_t* next_input_data = input_data_base; typename QuantizationTypeImpl::ExternalType* output_data = output_data_base; @@ -6323,9 +6325,9 @@ struct KernelMacroBlock< } } // NOLINT(readability/fn_size) Manually unrolled. - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - int8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, int8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, output_block_data, function_params); @@ -6347,8 +6349,8 @@ struct KernelMacroBlock< } static inline void KernelMacroBlockIntrinsics( - const int8* scratch_block_data, const int8* filter_workspace, - const int32* bias_data, int8* output_block_data, + const int8_t* scratch_block_data, const int8_t* filter_workspace, + const int32_t* bias_data, int8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { static constexpr QuantizationType quantization_type = QuantizationType::kPerChannelInt8; @@ -6380,14 +6382,14 @@ struct KernelMacroBlock< const int depth_micro_stride = width_micro_stride * input_width_overall_micro_repeats; - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_offset = function_params->output_offset; - const int32* output_shift_per_channel = + const int32_t output_offset = function_params->output_offset; + const int32_t* output_shift_per_channel = function_params->output_shift_per_channel; - const int32* output_multiplier_per_channel = + const int32_t* output_multiplier_per_channel = function_params->output_multiplier_per_channel; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(output_activation_min, 0); @@ -6407,18 +6409,18 @@ struct KernelMacroBlock< // This version only does min/max on 64 bits. const int16x8_t output_offset_vec = - vdupq_n_s16(static_cast(output_offset)); + vdupq_n_s16(static_cast(output_offset)); const int8x8_t output_activation_min_vec = - vdup_n_s8(static_cast(output_activation_min)); + vdup_n_s8(static_cast(output_activation_min)); const int8x8_t output_activation_max_vec = - vdup_n_s8(static_cast(output_activation_max)); + vdup_n_s8(static_cast(output_activation_max)); constexpr int shuffled_filter_increment = 2 * 3 * 4 * 4; TFLITE_DCHECK_LE(block_height, 2); for (int j_depth = 0; j_depth < depth_micro_repeats; ++j_depth) { - const int8* filter_block = + const int8_t* filter_block = filter_workspace + shuffled_filter_increment * j_depth; if (block_height == 2) { @@ -6432,11 +6434,11 @@ struct KernelMacroBlock< filter_reg_1_a = vld1q_s8(filter_block + s * 16 + 32); filter_reg_2_a = vld1q_s8(filter_block + s * 16 + 64); - const int8* scratch_data = + const int8_t* scratch_data = scratch_block_data + depth_micro_stride * j_depth; typename QuantizationTypeImpl::ExternalType* output_data = output_block_data + 8 * j_depth; - const int8* input_data_0 = scratch_data + s * 2 * 8; + const int8_t* input_data_0 = scratch_data + s * 2 * 8; const int32x4_t adjusted_bias_data = vld1q_s32(bias_data); @@ -6482,7 +6484,7 @@ struct KernelMacroBlock< for (; i_width < adjusted_width_micro_repeats; ++i_width) { const int output_width = kFourOverStride; TFLITE_DCHECK_LE(output_width * kStrideVal, 4); - const int8* input_data = + const int8_t* input_data = input_data_0 + width_micro_stride * i_width; acc0 = adjusted_bias_data; acc1 = adjusted_bias_data; @@ -6649,11 +6651,11 @@ struct KernelMacroBlock< filter_reg_1_b = vld1q_s8(filter_block + 16 + 32); filter_reg_2_b = vld1q_s8(filter_block + 16 + 64); - const int8* scratch_data = + const int8_t* scratch_data = scratch_block_data + depth_micro_stride * j_depth; typename QuantizationTypeImpl::ExternalType* output_data = output_block_data + 8 * j_depth; - const int8* input_data_0 = scratch_data; + const int8_t* input_data_0 = scratch_data; const int32x4_t adjusted_bias_data_a = vld1q_s32(bias_data); bias_data += kBiasIncrement; @@ -6697,7 +6699,8 @@ struct KernelMacroBlock< ? residual_width : kFourOverStride; TFLITE_DCHECK_LE(output_width * kStrideVal, 4); - const int8* input_data = input_data_0 + width_micro_stride * i_width; + const int8_t* input_data = + input_data_0 + width_micro_stride * i_width; const bool no_right_block = i_width == output_width_micro_repeats && output_width_overall_micro_repeats == workspace_width_micro_repeats; @@ -6810,9 +6813,9 @@ struct KernelMacroBlock< } } // NOLINT(readability/fn_size) Manually unrolled. - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - int8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, int8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, output_block_data, function_params); @@ -6840,8 +6843,8 @@ struct KernelMacroBlock< } static inline void KernelMacroBlockIntrinsics( - const int8* scratch_block_data, const int8* filter_workspace, - const int32* bias_data, int8* output_block_data, + const int8_t* scratch_block_data, const int8_t* filter_workspace, + const int32_t* bias_data, int8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { static constexpr QuantizationType quantization_type = QuantizationType::kPerChannelInt8; @@ -6863,14 +6866,14 @@ struct KernelMacroBlock< TFLITE_DCHECK(depth_micro_repeats > 0); - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_offset = function_params->output_offset; - const int32* output_shift_per_channel = + const int32_t output_offset = function_params->output_offset; + const int32_t* output_shift_per_channel = function_params->output_shift_per_channel; - const int32* output_multiplier_per_channel = + const int32_t* output_multiplier_per_channel = function_params->output_multiplier_per_channel; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(output_activation_min, 0); @@ -6889,11 +6892,11 @@ struct KernelMacroBlock< TFLITE_DCHECK_LT(output_offset, 32768); const int16x8_t output_offset_vec = - vdupq_n_s16(static_cast(output_offset)); + vdupq_n_s16(static_cast(output_offset)); const int8x16_t output_activation_min_vec = - vdupq_n_s8(static_cast(output_activation_min)); + vdupq_n_s8(static_cast(output_activation_min)); const int8x16_t output_activation_max_vec = - vdupq_n_s8(static_cast(output_activation_max)); + vdupq_n_s8(static_cast(output_activation_max)); typename QuantizationTypeImpl::ExternalType* output_data_depthwise = output_block_data; @@ -6944,7 +6947,7 @@ struct KernelMacroBlock< typename QuantizationTypeImpl::ExternalType* output_data_base = output_data_depthwise + 4 * s; - const int8* next_input_data = scratch_block_data; + const int8_t* next_input_data = scratch_block_data; typename QuantizationTypeImpl::ExternalType* output_data = output_data_base; @@ -7439,7 +7442,7 @@ struct KernelMacroBlock< vld1q_s32(output_multiplier_per_channel + j_depth * 8 + 4); for (int k_height = 0; k_height < block_height; ++k_height) { - const int8* next_input_data = + const int8_t* next_input_data = scratch_block_data + k_height * workspace_height_stride; typename QuantizationTypeImpl::ExternalType* output_data = output_data_base; @@ -7526,9 +7529,9 @@ struct KernelMacroBlock< } } // NOLINT(readability/fn_size) Manually unrolled. - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - int8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, int8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, output_block_data, function_params); @@ -7550,8 +7553,8 @@ struct KernelMacroBlock< } static inline void KernelMacroBlockIntrinsics( - const int8* scratch_block_data, const int8* filter_workspace, - const int32* bias_data, int8* output_block_data, + const int8_t* scratch_block_data, const int8_t* filter_workspace, + const int32_t* bias_data, int8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { static constexpr QuantizationType quantization_type = QuantizationType::kPerChannelInt8; @@ -7572,14 +7575,14 @@ struct KernelMacroBlock< const int output_height_stride = function_params->output_height_stride; constexpr int kBiasIncrement = 4; - const int32 output_activation_min = + const int32_t output_activation_min = function_params->quantized_activation_min; - const int32 output_activation_max = + const int32_t output_activation_max = function_params->quantized_activation_max; - const int32 output_offset = function_params->output_offset; - const int32* output_shift_per_channel = + const int32_t output_offset = function_params->output_offset; + const int32_t* output_shift_per_channel = function_params->output_shift_per_channel; - const int32* output_multiplier_per_channel = + const int32_t* output_multiplier_per_channel = function_params->output_multiplier_per_channel; if (quantization_type == QuantizationType::kNonPerChannelUint8) { TFLITE_DCHECK_GE(output_activation_min, 0); @@ -7600,11 +7603,11 @@ struct KernelMacroBlock< TFLITE_DCHECK_GE(depth_micro_repeats, 1); const int16x8_t output_offset_vec = - vdupq_n_s16(static_cast(output_offset)); + vdupq_n_s16(static_cast(output_offset)); const int8x16_t output_activation_min_vec = - vdupq_n_s8(static_cast(output_activation_min)); + vdupq_n_s8(static_cast(output_activation_min)); const int8x16_t output_activation_max_vec = - vdupq_n_s8(static_cast(output_activation_max)); + vdupq_n_s8(static_cast(output_activation_max)); for (int j_depth = 0; j_depth < (depth_micro_repeats * 1 + 0); ++j_depth) { int8x16_t filter_reg_0_a; @@ -7642,7 +7645,7 @@ struct KernelMacroBlock< vld1q_s32(output_multiplier_per_channel + j_depth * 8 + 4); if (block_height == 2) { - const int8* scratch_data = scratch_block_data; + const int8_t* scratch_data = scratch_block_data; typename QuantizationTypeImpl::ExternalType* output_data = output_block_data + 8 * j_depth; @@ -7681,7 +7684,7 @@ struct KernelMacroBlock< int i_width = 0; for (; i_width < adjusted_width_micro_repeats; ++i_width) { - const int8* input_data = scratch_data + 4 + 4 * i_width; + const int8_t* input_data = scratch_data + 4 + 4 * i_width; // Load next sub-micro block of data. input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); @@ -7853,7 +7856,7 @@ struct KernelMacroBlock< } for (; i_width < output_width_overall_micro_repeats; ++i_width) { // output_width == 1. - const int8* input_data = scratch_data + 4 + 4 * i_width; + const int8_t* input_data = scratch_data + 4 + 4 * i_width; // Load next sub-micro block of data. input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); @@ -7956,7 +7959,7 @@ struct KernelMacroBlock< } else { TFLITE_DCHECK_EQ(block_height, 1); // Work through one slice, by row, at a time. - const int8* scratch_data = scratch_block_data; + const int8_t* scratch_data = scratch_block_data; typename QuantizationTypeImpl::ExternalType* output_data = output_block_data + 8 * j_depth; @@ -7985,7 +7988,7 @@ struct KernelMacroBlock< TFLITE_DCHECK_LE(output_width, 2); TFLITE_DCHECK_GE(output_width, 1); TFLITE_DCHECK_LE(output_width * kStrideVal, 4); - const int8* input_data = scratch_data + 4 + 4 * i_width; + const int8_t* input_data = scratch_data + 4 + 4 * i_width; // Load next sub-micro block of data. input_bank_a_reg = vld1q_lane_8x4(input_data, input_bank_a_reg, 1); @@ -8099,9 +8102,9 @@ struct KernelMacroBlock< } } - static inline void Run(const int8* scratch_block_data, - const int8* filter_workspace, const int32* bias_data, - int8* output_block_data, + static inline void Run(const int8_t* scratch_block_data, + const int8_t* filter_workspace, + const int32_t* bias_data, int8_t* output_block_data, const DepthwiseConvDotProdParams* function_params) { KernelMacroBlockIntrinsics(scratch_block_data, filter_workspace, bias_data, output_block_data, function_params); diff --git a/tensorflow/lite/kernels/internal/optimized/im2col_utils.h b/tensorflow/lite/kernels/internal/optimized/im2col_utils.h index d3c14f14689941..e0da94a6a3cbc0 100644 --- a/tensorflow/lite/kernels/internal/optimized/im2col_utils.h +++ b/tensorflow/lite/kernels/internal/optimized/im2col_utils.h @@ -25,14 +25,12 @@ namespace tflite { namespace optimized_ops { template -inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w, - int h, int b, int kheight, int kwidth, - int stride_width, int stride_height, - int pad_width, int pad_height, - int in_width, int in_height, - int in_depth, int single_buffer_length, - int buffer_id, const T* in_data, - T* conv_buffer_data, uint8 zero_byte) { +inline void ExtractPatchIntoBufferColumn( + const RuntimeShape& input_shape, int w, int h, int b, int kheight, + int kwidth, int stride_width, int stride_height, int pad_width, + int pad_height, int in_width, int in_height, int in_depth, + int single_buffer_length, int buffer_id, const T* in_data, + T* conv_buffer_data, uint8_t zero_byte) { ruy::profiler::ScopeLabel label("ExtractPatchIntoBufferColumn"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); // This chunk of code reshapes all the inputs corresponding to @@ -201,7 +199,7 @@ void DilatedIm2col(const ConvParams& params, const RuntimeShape& input_shape, } template -void DilatedIm2col(const ConvParams& params, uint8 zero_byte, +void DilatedIm2col(const ConvParams& params, uint8_t zero_byte, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& filter_shape, const RuntimeShape& output_shape, T* im2col_data) { @@ -211,9 +209,10 @@ void DilatedIm2col(const ConvParams& params, uint8 zero_byte, } template -void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte, - const RuntimeShape& input_shape, const T* input_data, - const RuntimeShape& output_shape, T* output_data) { +void Im2col(const ConvParams& params, int kheight, int kwidth, + uint8_t zero_byte, const RuntimeShape& input_shape, + const T* input_data, const RuntimeShape& output_shape, + T* output_data) { ruy::profiler::ScopeLabel label("Im2col"); const int stride_width = params.stride_width; const int stride_height = params.stride_height; @@ -291,7 +290,7 @@ inline void ExtractPatchIntoBufferColumn3D( int pad_depth, int pad_height, int pad_width, // Padding params. int in_depth, int in_height, int in_width, int in_channel, // Input shape. int output_row_offset, const T* in_data, T* conv_buffer_data, - uint8 zero_byte) { + uint8_t zero_byte) { ruy::profiler::ScopeLabel label("ExtractPatchIntoBufferColumn3D"); // This chunk of code reshapes all the inputs corresponding to @@ -372,7 +371,7 @@ inline void ExtractPatchIntoBufferColumn3D( template void Im2col3D(const Conv3DParams& params, int kdepth, int kheight, int kwidth, - uint8 zero_byte, const RuntimeShape& input_shape, + uint8_t zero_byte, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& im2col_shape, T* im2col_data) { ruy::profiler::ScopeLabel label("Im2col3D"); @@ -417,7 +416,7 @@ void Im2col3D(const Conv3DParams& params, int kdepth, int kheight, int kwidth, template inline void DilatedIm2col3D(const Conv3DParams& params, int filter_depth, int filter_height, int filter_width, - uint8 zero_byte, const RuntimeShape& input_shape, + uint8_t zero_byte, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& im2col_shape, T* im2col_data) { ruy::profiler::ScopeLabel label("DilatedIm2col3D"); diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h index b9727e4ad40c34..cd2a6148fade6f 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/depthwise_conv.h @@ -46,9 +46,9 @@ struct QuantizedDepthwiseConvKernel {}; template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8x2_t filter_s8; filter_s8.val[0] = vld1_s8(filter_ptr); @@ -92,9 +92,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. const int8x8_t filter_s8 = vld1_s8(filter_ptr); const int16x8_t filter = vmovl_s8(filter_s8); @@ -159,9 +159,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. const int8x8_t filter_s8 = vld1_s8(filter_ptr); const int16x8_t filter = vmovl_s8(filter_s8); @@ -227,9 +227,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int16x8_t filter[2]; for (int i = 0; i < 2; i++) { @@ -301,9 +301,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8 = vdup_n_s8(0); filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0); @@ -364,9 +364,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8 = vdup_n_s8(0); filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0); @@ -474,9 +474,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8 = vdup_n_s8(0); filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0); @@ -518,7 +518,7 @@ struct QuantizedDepthwiseConvKernel { int32x2_t acc = vld1_s32(acc_buffer_ptr); // Load the inputs, add input_offset. - const uint32 input = *input_ptr++ + input_offset; + const uint32_t input = *input_ptr++ + input_offset; // Multiply-accumulate acc = vget_low_s32(vmlal_n_s16(vcombine_s32(acc, acc), filter, input)); @@ -532,9 +532,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8 = vdup_n_s8(0); filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0); @@ -610,7 +610,7 @@ struct QuantizedDepthwiseConvKernel { int32x4_t acc = vld1q_s32(acc_buffer_ptr); // Load the inputs, add input_offset. - const uint32 input = *input_ptr++ + input_offset; + const uint32_t input = *input_ptr++ + input_offset; // Multiply-accumulate acc = vmlal_n_s16(acc, filter, input); @@ -624,9 +624,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8 = vdup_n_s8(0); filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0); @@ -691,9 +691,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int16x8_t filter[2]; for (int i = 0; i < 2; i++) { @@ -774,15 +774,15 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // We will have to duplicate bytes in a NEON register, 3-fold. // We will do that by register-level table-look-up using VTBL instructions. // Here we prepare the registers containing the table-lookup indices. - static const int8 dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2}, - {2, 3, 3, 3, 4, 4, 4, 5}, - {5, 5, 6, 6, 6, 7, 7, 7}}; + static const int8_t dup3_indices_array[3][8] = {{0, 0, 0, 1, 1, 1, 2, 2}, + {2, 3, 3, 3, 4, 4, 4, 5}, + {5, 5, 6, 6, 6, 7, 7, 7}}; int8x8_t dup3_indices[3]; for (int i = 0; i < 3; i++) { dup3_indices[i] = vld1_s8(dup3_indices_array[i]); @@ -790,8 +790,8 @@ struct QuantizedDepthwiseConvKernel { // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - const int8* local_filter_ptr = filter_ptr; - const int8* local_input_ptr = input_ptr; + const int8_t* local_filter_ptr = filter_ptr; + const int8_t* local_input_ptr = input_ptr; int ic = 0; // Handle 8 input channels at a time. for (; ic <= input_depth - 8; ic += 8) { @@ -842,10 +842,10 @@ struct QuantizedDepthwiseConvKernel { } // Handle one input channel at a time. for (; ic < input_depth; ic++) { - const int16 input_val = *local_input_ptr++ + input_offset; + const int16_t input_val = *local_input_ptr++ + input_offset; for (int i = 0; i < 3; i++) { *acc_buffer_ptr++ += - static_cast(local_filter_ptr[i]) * input_val; + static_cast(local_filter_ptr[i]) * input_val; } local_filter_ptr += 3; } @@ -857,13 +857,13 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - const int8* local_filter_ptr = filter_ptr; - const int8* local_input_ptr = input_ptr; + const int8_t* local_filter_ptr = filter_ptr; + const int8_t* local_input_ptr = input_ptr; int ic = 0; // Handle 8 input channels at a time. for (; ic <= input_depth - 8; ic += 8) { @@ -905,10 +905,10 @@ struct QuantizedDepthwiseConvKernel { // Handle one input channel at a time. for (; ic < input_depth; ic++) { // Load the inputs. - const int16 input_val = *local_input_ptr++ + input_offset; + const int16_t input_val = *local_input_ptr++ + input_offset; for (int i = 0; i < 2; i++) { *acc_buffer_ptr++ += - static_cast(local_filter_ptr[i]) * input_val; + static_cast(local_filter_ptr[i]) * input_val; } local_filter_ptr += 2; } @@ -920,13 +920,13 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - const int8* local_filter_ptr = filter_ptr; - const int8* local_input_ptr = input_ptr; + const int8_t* local_filter_ptr = filter_ptr; + const int8_t* local_input_ptr = input_ptr; int ic = 0; // Handle 16 input channels at a time. for (; ic <= input_depth - 16; ic += 16) { @@ -989,9 +989,9 @@ struct QuantizedDepthwiseConvKernel { } // Handle one input channel at a time. for (; ic < input_depth; ic++) { - const int16 input_val = *local_input_ptr++ + input_offset; - const int16 filter_val = *local_filter_ptr++; - *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + const int16_t input_val = *local_input_ptr++ + input_offset; + const int16_t filter_val = *local_filter_ptr++; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; } input_ptr += input_ptr_increment; } @@ -1001,9 +1001,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8[2]; for (int i = 0; i < 2; i++) { @@ -1052,9 +1052,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. const int8x8_t filter_s8 = vld1_s8(filter_ptr); const int16x8_t filter = vmovl_s8(filter_s8); @@ -1085,9 +1085,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8[2]; for (int i = 0; i < 2; i++) { @@ -1099,9 +1099,9 @@ struct QuantizedDepthwiseConvKernel { } // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - int8 input_s8 = *input_ptr; + int8_t input_s8 = *input_ptr; input_ptr += input_ptr_increment; - int16 input = static_cast(input_s8 + input_offset); + int16_t input = static_cast(input_s8 + input_offset); // Load the accumulators from acc_buffer int32x4_t acc[4]; for (int i = 0; i < 4; i++) { @@ -1126,9 +1126,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8_0 = vld1_s8(filter_ptr + 8 * 0); int8x8_t filter_s8_1 = vld1_s8(filter_ptr + 8 * 1); @@ -1140,9 +1140,9 @@ struct QuantizedDepthwiseConvKernel { int16x8_t filter_3 = vmovl_s8(filter_s8_3); // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - int8 input_s8 = *input_ptr; + int8_t input_s8 = *input_ptr; input_ptr += input_ptr_increment; - int16 input = static_cast(input_s8 + input_offset); + int16_t input = static_cast(input_s8 + input_offset); // Load the accumulators from acc_buffer int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); @@ -1178,9 +1178,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. // NEON wants to load 8 bytes at a time, but 20 is not divisible by 8. // We load the first 16 bytes into filter_s8_{0,1} as usual. @@ -1195,9 +1195,9 @@ struct QuantizedDepthwiseConvKernel { int16x8_t filter_x = vmovl_s8(filter_s8_x); // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - int8 input_s8 = *input_ptr; + int8_t input_s8 = *input_ptr; input_ptr += input_ptr_increment; - int16 input = static_cast(input_s8 + input_offset); + int16_t input = static_cast(input_s8 + input_offset); // Load the accumulators from acc_buffer int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); @@ -1224,17 +1224,17 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. const int8x8_t filter_s8 = vld1_s8(filter_ptr); const int16x8_t filter = vmovl_s8(filter_s8); // Handle one output pixel at a time. for (int outp = 0; outp < num_output_pixels; outp++) { - int8 input_s8 = *input_ptr; + int8_t input_s8 = *input_ptr; input_ptr += input_ptr_increment; - int16 input = static_cast(input_s8 + input_offset); + int16_t input = static_cast(input_s8 + input_offset); // Load the accumulators from acc_buffer int32x4_t acc[2]; for (int i = 0; i < 2; i++) { @@ -1255,9 +1255,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8 = vdup_n_s8(0); filter_s8 = vset_lane_s8(filter_ptr[0], filter_s8, 0); @@ -1274,11 +1274,11 @@ struct QuantizedDepthwiseConvKernel { int32x4_t acc = vld1q_s32(acc_buffer_ptr); // Load the inputs, add input_offset. int16x4_t input_s16 = vdup_n_s16(0); - input_s16 = vset_lane_s16((reinterpret_cast(input_ptr))[0], - input_s16, 0); + input_s16 = vset_lane_s16( + (reinterpret_cast(input_ptr))[0], input_s16, 0); input_ptr += input_ptr_increment; - input_s16 = vset_lane_s16((reinterpret_cast(input_ptr))[0], - input_s16, 1); + input_s16 = vset_lane_s16( + (reinterpret_cast(input_ptr))[0], input_s16, 1); input_ptr += input_ptr_increment; input_s16 = vget_low_s16(vmovl_s8(vreinterpret_s8_s16(input_s16))); const int16x4_t input = vadd_s16(input_s16, vdup_n_s16(input_offset)); @@ -1314,9 +1314,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { if (num_output_pixels <= 0) { return; } @@ -1374,9 +1374,9 @@ struct QuantizedDepthwiseConvKernel { template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, - const int8* input_ptr, int16 input_offset, - int input_ptr_increment, const int8* filter_ptr, - int32* acc_buffer_ptr) { + const int8_t* input_ptr, int16_t input_offset, + int input_ptr_increment, const int8_t* filter_ptr, + int32_t* acc_buffer_ptr) { // Load the filters. int8x8_t filter_s8_0 = vld1_s8(filter_ptr); int8x8_t filter_s8_1 = vld1_s8(filter_ptr + 4); @@ -1421,14 +1421,12 @@ struct QuantizedDepthwiseConvKernel { // Accumulates the effect of one row of the filter, on a segment of one row // of the output, accessing the corresponding one row of the input. template -void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, - int input_depth, int input_width, - const int8* input_data, int16 input_offset, - int pad_width, int depth_multiplier, - int filter_width, const int8* filter_data, - int out_x_buffer_start, - int out_x_buffer_end, int output_depth, - int32* acc_buffer) { +void QuantizedDepthwiseConvAccumRow( + int stride, int dilation_factor, int input_depth, int input_width, + const int8_t* input_data, int16_t input_offset, int pad_width, + int depth_multiplier, int filter_width, const int8_t* filter_data, + int out_x_buffer_start, int out_x_buffer_end, int output_depth, + int32_t* acc_buffer) { ruy::profiler::ScopeLabel label(TFLITE_PRETTY_FUNCTION); // Consistency check parameters. This is important in particular to ensure // that we keep the number of template instantiations minimal, so we don't @@ -1444,7 +1442,7 @@ void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, } TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier); const int input_ptr_increment = stride * input_depth; - const int8* filter_base_ptr = filter_data; + const int8_t* filter_base_ptr = filter_data; for (int filter_x = 0; filter_x < filter_width; ++filter_x) { // For the current (filter_x, filter_y) point in the filter, // compute the boundaries of the corresponding output row segment. @@ -1480,11 +1478,11 @@ void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, const int out_x_loop_end = std::min(out_x_buffer_end, out_x_loop_end_unclamped); - int32* acc_buffer_ptr = + int32_t* acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; const int in_x_origin = (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x; - const int8* input_ptr = input_data + in_x_origin * input_depth; + const int8_t* input_ptr = input_data + in_x_origin * input_depth; const int num_output_pixels = out_x_loop_end - out_x_loop_start; QuantizedDepthwiseConvKernel< kAllowStrided, kFixedInputDepth, @@ -1499,12 +1497,12 @@ void QuantizedDepthwiseConvAccumRow(int stride, int dilation_factor, // generic fallback of DepthwiseConvAccumRow, portable, non-templatized. inline void QuantizedDepthwiseConvAccumRowGeneric( int stride, int dilation_factor, int input_depth, int input_width, - const int8* input_data, int16 input_offset, int pad_width, - int depth_multiplier, int filter_width, const int8* filter_data, + const int8_t* input_data, int16_t input_offset, int pad_width, + int depth_multiplier, int filter_width, const int8_t* filter_data, int out_x_buffer_start, int out_x_buffer_end, int output_depth, - int32* acc_buffer) { + int32_t* acc_buffer) { ruy::profiler::ScopeLabel label("DepthwiseConvAccumRowGeneric (slow)"); - const int8* filter_base_ptr = filter_data; + const int8_t* filter_base_ptr = filter_data; for (int filter_x = 0; filter_x < filter_width; ++filter_x) { const int out_x_loop_start = std::max( out_x_buffer_start, @@ -1514,19 +1512,19 @@ inline void QuantizedDepthwiseConvAccumRowGeneric( (pad_width + input_width - dilation_factor * filter_x + stride - 1) / stride); - int32* acc_buffer_ptr = + int32_t* acc_buffer_ptr = acc_buffer + (out_x_loop_start - out_x_buffer_start) * output_depth; const int in_x_origin = (out_x_loop_start * stride) - pad_width + dilation_factor * filter_x; - const int8* input_ptr = input_data + in_x_origin * input_depth; + const int8_t* input_ptr = input_data + in_x_origin * input_depth; const int input_ptr_increment = (stride - 1) * input_depth; for (int out_x = out_x_loop_start; out_x < out_x_loop_end; out_x++) { - const int8* filter_ptr = filter_base_ptr; + const int8_t* filter_ptr = filter_base_ptr; for (int ic = 0; ic < input_depth; ++ic) { - const int16 input_val = *input_ptr++ + input_offset; + const int16_t input_val = *input_ptr++ + input_offset; for (int m = 0; m < depth_multiplier; m++) { - const int16 filter_val = *filter_ptr++; - *acc_buffer_ptr++ += static_cast(filter_val) * input_val; + const int16_t filter_val = *filter_ptr++; + *acc_buffer_ptr++ += static_cast(filter_val) * input_val; } } input_ptr += input_ptr_increment; @@ -1537,8 +1535,8 @@ inline void QuantizedDepthwiseConvAccumRowGeneric( // Initializes the accumulator buffer with bias values. inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, - const int32* bias_data, - int32* acc_buffer) { + const int32_t* bias_data, + int32_t* acc_buffer) { int i = 0; #ifdef USE_NEON if (output_depth == 1) { @@ -1609,21 +1607,21 @@ inline void DepthwiseConvInitAccBuffer(int num_output_pixels, int output_depth, } inline void DepthwiseConvGeneral( - const DepthwiseParams& params, const int32* output_multiplier, - const int32* output_shift, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, - int thread_start, int thread_end, int thread_dim) { + const DepthwiseParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int8_t* output_data, int thread_start, int thread_end, int thread_dim) { const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int pad_width = params.padding_values.width; const int pad_height = params.padding_values.height; const int depth_multiplier = params.depth_multiplier; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; - const int32 input_offset = params.input_offset; - const int32 output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; + const int32_t input_offset = params.input_offset; + const int32_t output_offset = params.output_offset; const int dilation_width_factor = params.dilation_width_factor; const int dilation_height_factor = params.dilation_height_factor; const int batches = MatchingDim(input_shape, 0, output_shape, 0); @@ -1638,12 +1636,12 @@ inline void DepthwiseConvGeneral( static const int kAccBufferMaxSize = 2048; int acc_buffer_size = kAccBufferMaxSize; - int32 stack_acc_buffer[kAccBufferMaxSize]; - int32* acc_buffer = stack_acc_buffer; + int32_t stack_acc_buffer[kAccBufferMaxSize]; + int32_t* acc_buffer = stack_acc_buffer; #ifndef TF_LITE_STATIC_MEMORY - std::unique_ptr heap_acc_buffer; + std::unique_ptr heap_acc_buffer; if (kAccBufferMaxSize < output_depth) { - heap_acc_buffer.reset(new int32[output_depth]); + heap_acc_buffer.reset(new int32_t[output_depth]); acc_buffer = heap_acc_buffer.get(); acc_buffer_size = output_depth; } @@ -1745,7 +1743,7 @@ inline void DepthwiseConvGeneral( break; } - int8* output_ptr = output_data + output_ptr_offset; + int8_t* output_ptr = output_data + output_ptr_offset; int batch_step = (output_rows + row_start - row_end) * output_width * output_depth; for (int b = batch_start; b < batch_end; ++b) { @@ -1802,12 +1800,12 @@ inline void DepthwiseConvGeneral( template inline void DepthwiseConvWithRounding( - const DepthwiseParams& params, const int32* output_multiplier, - const int32* output_shift, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, - int thread_start, int thread_end, int thread_dim, + const DepthwiseParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int8_t* output_data, int thread_start, int thread_end, int thread_dim, const CpuBackendContext& cpu_backend_context) { ruy::profiler::ScopeLabel label("DepthwiseConvInt8/8bit"); const int depth_multiplier = params.depth_multiplier; @@ -1886,12 +1884,12 @@ inline void DepthwiseConvWithRounding( } inline void DepthwiseConvImpl( - const DepthwiseParams& params, const int32* output_multiplier, - const int32* output_shift, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, - int thread_start, int thread_end, int thread_dim, + const DepthwiseParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int8_t* output_data, int thread_start, int thread_end, int thread_dim, const CpuBackendContext& cpu_backend_context) { return DepthwiseConvWithRounding( params, output_multiplier, output_shift, input_shape, input_data, @@ -1902,8 +1900,8 @@ inline void DepthwiseConvImpl( template struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task { DepthwiseConvWorkerTask(const DepthwiseParams& params, - const int32* output_multiplier, - const int32* output_shift, + const int32_t* output_multiplier, + const int32_t* output_shift, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& filter_shape, const T* filter_data, const RuntimeShape& bias_shape, @@ -1936,8 +1934,8 @@ struct DepthwiseConvWorkerTask : cpu_backend_threadpool::Task { private: const DepthwiseParams& params_; - const int32* output_multiplier_; - const int32* output_shift_; + const int32_t* output_multiplier_; + const int32_t* output_shift_; const RuntimeShape& input_shape_; const T* input_data_; const RuntimeShape& filter_shape_; @@ -1967,12 +1965,12 @@ inline int HowManyConvThreads(const RuntimeShape& output_shape, } inline void DepthwiseConvPerChannel( - const DepthwiseParams& params, const int32* output_multiplier, - const int32* output_shift, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, - CpuBackendContext* cpu_backend_context) { + const DepthwiseParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int8_t* output_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("DepthwiseConvInt8"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); @@ -2003,7 +2001,7 @@ inline void DepthwiseConvPerChannel( /*thread_end=*/output_rows, /*thread_dim=*/1, *cpu_backend_context); } else { - std::vector> tasks; + std::vector> tasks; // TODO(b/131746020) don't create new heap allocations every time. // At least we make it a single heap allocation by using reserve(). tasks.reserve(thread_count); diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h index 19f9d4d175f1b5..2df5db9a32fd6a 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h @@ -29,18 +29,18 @@ namespace optimized_integer_ops { template inline void FullyConnectedPerChannel( - const FullyConnectedParams& params, const int32* output_multiplier, + const FullyConnectedParams& params, const int32_t* output_multiplier, const int* output_shift, const RuntimeShape& input_shape, const InputScalar* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, DstScalar* output_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit"); - const int32 input_offset = params.input_offset; - const int32 output_offset = params.output_offset; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t input_offset = params.input_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); // TODO(b/62193649): This really should be: @@ -62,7 +62,7 @@ inline void FullyConnectedPerChannel( const bool use_caching = (cpu_backend_context != nullptr) && cpu_backend_context->use_caching(); - cpu_backend_gemm::MatrixParams lhs_params; + cpu_backend_gemm::MatrixParams lhs_params; lhs_params.rows = filter_rows; lhs_params.cols = filter_cols; lhs_params.order = cpu_backend_gemm::Order::kRowMajor; @@ -84,7 +84,7 @@ inline void FullyConnectedPerChannel( dst_params.order = cpu_backend_gemm::Order::kColMajor; dst_params.zero_point = output_offset; cpu_backend_gemm::GemmParams< - int32, DstScalar, + int32_t, DstScalar, cpu_backend_gemm::QuantizationFlavor::kIntegerWithPerRowMultiplier> gemm_params; gemm_params.bias = bias_data; @@ -101,18 +101,18 @@ template inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, const InputScalar* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, DstScalar* output_data, CpuBackendContext* cpu_backend_context) { ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit"); - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); // TODO(b/62193649): This really should be: @@ -134,7 +134,7 @@ inline void FullyConnected( const bool use_caching = (cpu_backend_context != nullptr) && cpu_backend_context->use_caching(); - cpu_backend_gemm::MatrixParams lhs_params; + cpu_backend_gemm::MatrixParams lhs_params; lhs_params.rows = filter_rows; lhs_params.cols = filter_cols; lhs_params.order = cpu_backend_gemm::Order::kRowMajor; @@ -155,7 +155,7 @@ inline void FullyConnected( dst_params.cols = batches; dst_params.order = cpu_backend_gemm::Order::kColMajor; dst_params.zero_point = output_offset; - cpu_backend_gemm::GemmParams gemm_params; + cpu_backend_gemm::GemmParams gemm_params; gemm_params.bias = bias_data; gemm_params.clamp_min = output_activation_min; gemm_params.clamp_max = output_activation_max; diff --git a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h index 0a6d63d3fabea6..abb59a0208a8fa 100644 --- a/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h +++ b/tensorflow/lite/kernels/internal/optimized/integer_ops/pooling.h @@ -35,8 +35,8 @@ namespace tflite { namespace optimized_integer_ops { inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& output_shape, - int8* output_data) { + const int8_t* input_data, const RuntimeShape& output_shape, + int8_t* output_data) { ruy::profiler::ScopeLabel label("MaxPool/8bit"); // Here, and in other pooling ops, in order to maintain locality of reference, @@ -59,7 +59,7 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, const int stride_height = params.stride_height; const int stride_width = params.stride_width; - int8 acc[kPoolingAccTrancheSize]; + int8_t acc[kPoolingAccTrancheSize]; for (int batch = 0; batch < batches; ++batch) { // We proceed through the depth in tranches (see comment above). The // depth_base is the depth at the beginning of the tranche. The @@ -82,15 +82,15 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, std::min(params.filter_height, input_height - in_y_origin); memset(acc, params.quantized_activation_min, tranche_depth * sizeof(acc[0])); - const int8* input_ptr = + const int8_t* input_ptr = input_data + depth_base + depth * (in_x_origin + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const int8* input_row_ptr = + const int8_t* input_row_ptr = input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { - const int8* input_channel_ptr = input_row_ptr; + const int8_t* input_channel_ptr = input_row_ptr; int channel = 0; #ifdef USE_NEON for (; channel <= tranche_depth - 16; channel += 16) { @@ -115,8 +115,8 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, input_row_ptr += depth; } } - int8* output_ptr = output_data + Offset(output_shape, batch, out_y, - out_x, depth_base); + int8_t* output_ptr = output_data + Offset(output_shape, batch, out_y, + out_x, depth_base); int channel = 0; #ifdef USE_NEON for (; channel <= tranche_depth - 16; channel += 16) { @@ -133,10 +133,10 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, } #endif for (; channel < tranche_depth; ++channel) { - int8 a = acc[channel]; - a = std::max(a, params.quantized_activation_min); - a = std::min(a, params.quantized_activation_max); - output_ptr[channel] = static_cast(a); + int8_t a = acc[channel]; + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); + output_ptr[channel] = static_cast(a); } } } @@ -145,8 +145,9 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape, } inline bool AveragePool(const PoolParams& params, - const RuntimeShape& input_shape, const int8* input_data, - const RuntimeShape& output_shape, int8* output_data) { + const RuntimeShape& input_shape, + const int8_t* input_data, + const RuntimeShape& output_shape, int8_t* output_data) { ruy::profiler::ScopeLabel label("AveragePool/8bitWith32bitAccumulator"); // Here, and in other pooling ops, in order to maintain locality of reference, @@ -169,7 +170,7 @@ inline bool AveragePool(const PoolParams& params, const int stride_height = params.stride_height; const int stride_width = params.stride_width; - int32 acc[kPoolingAccTrancheSize]; + int32_t acc[kPoolingAccTrancheSize]; for (int batch = 0; batch < batches; ++batch) { // We proceed through the depth in tranches (see comment above). The // depth_base is the depth at the beginning of the tranche. The @@ -194,15 +195,15 @@ inline bool AveragePool(const PoolParams& params, (filter_x_end - filter_x_start) * (filter_y_end - filter_y_start); if (filter_count == 0) return false; memset(acc, 0, tranche_depth * sizeof(acc[0])); - const int8* input_ptr = + const int8_t* input_ptr = input_data + depth_base + depth * (in_x_origin + input_width * (in_y_origin + input_height * batch)); for (int fy = filter_y_start; fy < filter_y_end; fy++) { - const int8* input_row_ptr = + const int8_t* input_row_ptr = input_ptr + depth * (fy * input_width + filter_x_start); for (int fx = filter_x_start; fx < filter_x_end; fx++) { - const int8* input_channel_ptr = input_row_ptr; + const int8_t* input_channel_ptr = input_row_ptr; int channel = 0; #ifdef USE_NEON for (; channel <= tranche_depth - 16; channel += 16) { @@ -238,12 +239,12 @@ inline bool AveragePool(const PoolParams& params, input_row_ptr += depth; } } - int8* output_ptr = output_data + Offset(output_shape, batch, out_y, - out_x, depth_base); + int8_t* output_ptr = output_data + Offset(output_shape, batch, out_y, + out_x, depth_base); int channel = 0; #ifdef USE_NEON for (; channel <= tranche_depth - 8; channel += 8) { - int16 buf[8]; + int16_t buf[8]; for (int i = 0; i < 8; i++) { buf[i] = acc[channel + i] > 0 @@ -257,12 +258,12 @@ inline bool AveragePool(const PoolParams& params, } #endif for (; channel < tranche_depth; ++channel) { - int16 a = acc[channel] > 0 - ? (acc[channel] + filter_count / 2) / filter_count - : (acc[channel] - filter_count / 2) / filter_count; - a = std::max(a, params.quantized_activation_min); - a = std::min(a, params.quantized_activation_max); - output_ptr[channel] = static_cast(a); + int16_t a = acc[channel] > 0 + ? (acc[channel] + filter_count / 2) / filter_count + : (acc[channel] - filter_count / 2) / filter_count; + a = std::max(a, params.quantized_activation_min); + a = std::min(a, params.quantized_activation_max); + output_ptr[channel] = static_cast(a); } } } diff --git a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h index eb83c9b54bbd97..8c8c7288143055 100644 --- a/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h +++ b/tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h @@ -213,14 +213,14 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, template inline void LegacyDepthwiseConvWithRounding( const DepthwiseParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, int thread_start, int thread_end, int thread_dim) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, int thread_start, int thread_end, int thread_dim) { ruy::profiler::ScopeLabel label("DepthwiseConv/8bit"); const int depth_multiplier = params.depth_multiplier; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; const int dilation_width_factor = params.dilation_width_factor; const int dilation_height_factor = params.dilation_height_factor; TFLITE_DCHECK_GE(dilation_width_factor, 1); @@ -267,10 +267,10 @@ inline void LegacyDepthwiseConvWithRounding( inline void LegacyDepthwiseConvImpl( const DepthwiseParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, int thread_start, int thread_end, int thread_dim) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, int thread_start, int thread_end, int thread_dim) { return LegacyDepthwiseConvWithRounding< DepthwiseConvOutputRounding::kAwayFromZero>( params, input_shape, input_data, filter_shape, filter_data, bias_shape, @@ -278,16 +278,16 @@ inline void LegacyDepthwiseConvImpl( thread_dim); } -inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void DepthwiseConv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int dilation_width_factor, int dilation_height_factor, int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { tflite::DepthwiseParams op_params; // Padding type is ignored, but still set. @@ -318,15 +318,15 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, /*thread_end=*/output_height, /*thread_dim=*/1); } -inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void DepthwiseConv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride_width, @@ -338,15 +338,15 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // Legacy, for compatibility with old checked-in code. template -void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +void DepthwiseConv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int pad_width, - int pad_height, int depth_multiplier, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + int pad_height, int depth_multiplier, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { if (Ac == FusedActivationFunctionType::kNone) { TFLITE_DCHECK_EQ(output_activation_min, 0); TFLITE_DCHECK_EQ(output_activation_max, 255); @@ -361,15 +361,15 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // Legacy, for compatibility with old checked-in code. template -void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, int stride, - int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { +void DepthwiseConv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, + int stride, int pad_width, int pad_height, + int depth_multiplier, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride, stride, pad_width, pad_height, depth_multiplier, @@ -437,10 +437,10 @@ inline int HowManyConvThreads(const RuntimeShape& output_shape, inline void DepthwiseConv( const DepthwiseParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) { ruy::profiler::ScopeLabel label("DepthwiseConv"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); @@ -477,7 +477,7 @@ inline void DepthwiseConv( for (int i = 0; i < thread_count; ++i) { int thread_end = thread_start + (thread_dim_size - thread_start) / (thread_count - i); - tasks[i] = new LegacyDepthwiseConvWorkerTask( + tasks[i] = new LegacyDepthwiseConvWorkerTask( params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data, thread_start, thread_end, thread_dim); @@ -490,8 +490,8 @@ inline void DepthwiseConv( template struct LegacyPerChannelDepthwiseConvWorkerTask : public gemmlowp::Task { LegacyPerChannelDepthwiseConvWorkerTask( - const DepthwiseParams& params, const int32* output_multiplier, - const int32* output_shift, const RuntimeShape& input_shape, + const DepthwiseParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& filter_shape, const T* filter_data, const RuntimeShape& bias_shape, const TS* bias_data, const RuntimeShape& output_shape, T* output_data, int thread_start, @@ -521,8 +521,8 @@ struct LegacyPerChannelDepthwiseConvWorkerTask : public gemmlowp::Task { private: const DepthwiseParams& params_; - const int32* output_multiplier_; - const int32* output_shift_; + const int32_t* output_multiplier_; + const int32_t* output_shift_; const RuntimeShape& input_shape_; const T* input_data_; const RuntimeShape& filter_shape_; @@ -537,12 +537,12 @@ struct LegacyPerChannelDepthwiseConvWorkerTask : public gemmlowp::Task { }; inline void DepthwiseConvPerChannel( - const DepthwiseParams& params, const int32* output_multiplier, - const int32* output_shift, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, - gemmlowp::GemmContext* gemmlowp_context = nullptr) { + const DepthwiseParams& params, const int32_t* output_multiplier, + const int32_t* output_shift, const RuntimeShape& input_shape, + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context = nullptr) { ruy::profiler::ScopeLabel label("DepthwiseConvInt8"); TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); @@ -581,7 +581,7 @@ inline void DepthwiseConvPerChannel( for (int i = 0; i < thread_count; ++i) { int thread_end = thread_start + (thread_dim_size - thread_start) / (thread_count - i); - tasks[i] = new LegacyPerChannelDepthwiseConvWorkerTask( + tasks[i] = new LegacyPerChannelDepthwiseConvWorkerTask( params, output_multiplier, output_shift, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data, thread_start, thread_end, thread_dim); @@ -713,17 +713,17 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims, } struct GemmlowpOutputPipeline { - typedef gemmlowp::VectorMap + typedef gemmlowp::VectorMap ColVectorMap; typedef std::tuple, gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent, gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToUint8> Pipeline; - static Pipeline MakeExp(const int32* bias_data, int output_rows, - int32 output_offset, int32 output_multiplier, - int output_left_shift, int32 output_activation_min, - int32 output_activation_max) { + static Pipeline MakeExp(const int32_t* bias_data, int output_rows, + int32_t output_offset, int32_t output_multiplier, + int output_left_shift, int32_t output_activation_min, + int32_t output_activation_max) { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; @@ -741,17 +741,17 @@ struct GemmlowpOutputPipeline { }; struct GemmlowpOutputPipelineInt8 { - typedef gemmlowp::VectorMap + typedef gemmlowp::VectorMap ColVectorMap; typedef std::tuple, gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent, gemmlowp::OutputStageClamp, gemmlowp::OutputStageSaturatingCastToInt8> Pipeline; - static Pipeline MakeExp(const int32* bias_data, int output_rows, - int32 output_offset, int32 output_multiplier, - int output_left_shift, int32 output_activation_min, - int32 output_activation_max) { + static Pipeline MakeExp(const int32_t* bias_data, int output_rows, + int32_t output_offset, int32_t output_multiplier, + int output_left_shift, int32_t output_activation_min, + int32_t output_activation_max) { ColVectorMap bias_vector(bias_data, output_rows); gemmlowp::OutputStageBiasAddition bias_addition_stage; bias_addition_stage.bias_vector = bias_vector; @@ -770,13 +770,14 @@ struct GemmlowpOutputPipelineInt8 { #ifdef USE_NEON inline void LegacyFullyConnectedAsGEMVWorkerImpl( - const RuntimeShape& input_shape, const uint8* input_data, - int32 input_offset, const RuntimeShape& filter_shape, - const uint8* filter_data, int32 filter_offset, - const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, int32 output_activation_min, - int32 output_activation_max, const RuntimeShape& output_shape, - uint8* output_data, int row_start, int row_end) { + const RuntimeShape& input_shape, const uint8_t* input_data, + int32_t input_offset, const RuntimeShape& filter_shape, + const uint8_t* filter_data, int32_t filter_offset, + const RuntimeShape& bias_shape, const int32_t* bias_data, + int32_t output_offset, int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + const RuntimeShape& output_shape, uint8_t* output_data, int row_start, + int row_end) { ruy::profiler::ScopeLabel label("FullyConnectedAsGEMV/8bit"); TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); @@ -806,7 +807,7 @@ inline void LegacyFullyConnectedAsGEMVWorkerImpl( int in = 0; for (; in <= input_size - 16; in += 16) { const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); - const uint8* filter_ptr = filter_data + in + out * input_size; + const uint8_t* filter_ptr = filter_data + in + out * input_size; uint8x16_t filter_val_u8_0 = vld1q_u8(filter_ptr); optimized_ops_preload_l1_stream(filter_ptr + 64); filter_ptr += input_size; @@ -884,7 +885,7 @@ inline void LegacyFullyConnectedAsGEMVWorkerImpl( } for (; in <= input_size - 8; in += 8) { const uint8x8_t input_val_u8 = vld1_u8(input_data + in); - const uint8* filter_ptr = filter_data + in + out * input_size; + const uint8_t* filter_ptr = filter_data + in + out * input_size; uint8x8_t filter_val_u8_0 = vld1_u8(filter_ptr); filter_ptr += input_size; uint8x8_t filter_val_u8_1 = vld1_u8(filter_ptr); @@ -920,16 +921,16 @@ inline void LegacyFullyConnectedAsGEMVWorkerImpl( vget_high_s16(input_val)); } if (in < input_size) { - int32 buf[16]; + int32_t buf[16]; vst1q_s32(buf + 0, acc0); vst1q_s32(buf + 4, acc1); vst1q_s32(buf + 8, acc2); vst1q_s32(buf + 12, acc3); for (; in < input_size; in++) { int lane = (in + 8 - input_size) % 4; - const int32 input_val = input_data[in] + input_offset; + const int32_t input_val = input_data[in] + input_offset; for (int k = 0; k < kPeel; k++) { - int32 filter_val = + int32_t filter_val = filter_data[in + (out + k) * input_size] + filter_offset; buf[lane + 4 * k] += filter_val * input_val; } @@ -958,7 +959,7 @@ inline void LegacyFullyConnectedAsGEMVWorkerImpl( int32x4_t bias_vec = vld1q_s32(bias_data + out); reduced = vaddq_s32(reduced, bias_vec); if (shift_left) { - const int32 multiplier_power_of_two = 1 << output_shift; + const int32_t multiplier_power_of_two = 1 << output_shift; reduced = vmulq_n_s32(reduced, multiplier_power_of_two); reduced = vqrdmulhq_n_s32(reduced, output_multiplier); } else { @@ -988,13 +989,13 @@ inline void LegacyFullyConnectedAsGEMVWorkerImpl( struct LegacyFullyConnectedAsGEMVWorkerTask : public gemmlowp::Task { LegacyFullyConnectedAsGEMVWorkerTask( - const RuntimeShape& input_shape, const uint8* input_data, - int32 input_offset, const RuntimeShape& filter_shape, - const uint8* filter_data, int32 filter_offset, - const RuntimeShape& bias_shape, const int32* bias_data, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - const RuntimeShape& output_shape, uint8* output_data, int row_start, + const RuntimeShape& input_shape, const uint8_t* input_data, + int32_t input_offset, const RuntimeShape& filter_shape, + const uint8_t* filter_data, int32_t filter_offset, + const RuntimeShape& bias_shape, const int32_t* bias_data, + int32_t output_offset, int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + const RuntimeShape& output_shape, uint8_t* output_data, int row_start, int row_end) : input_shape_(input_shape), input_data_(input_data), @@ -1024,32 +1025,33 @@ struct LegacyFullyConnectedAsGEMVWorkerTask : public gemmlowp::Task { } const RuntimeShape& input_shape_; - const uint8* input_data_; - int32 input_offset_; + const uint8_t* input_data_; + int32_t input_offset_; const RuntimeShape& filter_shape_; - const uint8* filter_data_; - int32 filter_offset_; + const uint8_t* filter_data_; + int32_t filter_offset_; const RuntimeShape& bias_shape_; - const int32* bias_data_; - int32 output_offset_; - int32 output_multiplier_; + const int32_t* bias_data_; + int32_t output_offset_; + int32_t output_multiplier_; int output_shift_; - int32 output_activation_min_; - int32 output_activation_max_; + int32_t output_activation_min_; + int32_t output_activation_max_; const RuntimeShape& output_shape_; - uint8* output_data_; + uint8_t* output_data_; int row_start_; int row_end_; }; inline void FullyConnectedAsGEMV( - const RuntimeShape& input_shape, const uint8* input_data, - int32 input_offset, const RuntimeShape& filter_shape, - const uint8* filter_data, int32 filter_offset, - const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, int32 output_activation_min, - int32 output_activation_max, const RuntimeShape& output_shape, - uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) { + const RuntimeShape& input_shape, const uint8_t* input_data, + int32_t input_offset, const RuntimeShape& filter_shape, + const uint8_t* filter_data, int32_t filter_offset, + const RuntimeShape& bias_shape, const int32_t* bias_data, + int32_t output_offset, int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + const RuntimeShape& output_shape, uint8_t* output_data, + gemmlowp::GemmContext* gemmlowp_context) { const int output_dim_count = output_shape.DimensionsCount(); const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); const int output_rows = output_shape.Dims(output_dim_count - 1); @@ -1090,18 +1092,18 @@ inline void FullyConnectedAsGEMV( inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, gemmlowp::GemmContext* gemmlowp_context) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) { ruy::profiler::ScopeLabel label("FullyConnected/8bit"); - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); // TODO(b/62193649): This really should be: @@ -1132,16 +1134,16 @@ inline void FullyConnected( TFLITE_DCHECK_EQ(output_rows, filter_rows); TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_rows); - gemmlowp::MatrixMap filter_matrix( - filter_data, output_rows, filter_cols, filter_cols); - gemmlowp::MatrixMap input_matrix( + gemmlowp::MatrixMap + filter_matrix(filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( input_data, filter_cols, batches, filter_cols); - gemmlowp::MatrixMap output_matrix( + gemmlowp::MatrixMap output_matrix( output_data, output_rows, batches, output_rows); const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( bias_data, output_rows, output_offset, output_multiplier, output_shift, output_activation_min, output_activation_max); - gemmlowp::GemmWithOutputPipeline( gemmlowp_context, filter_matrix, input_matrix, &output_matrix, filter_offset, input_offset, output_pipeline); @@ -1634,18 +1636,18 @@ inline void GEMVForLstmCellWithSymmetricRange( inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data_int32, const RuntimeShape& output_shape, - int16* output_data, gemmlowp::GemmContext* gemmlowp_context) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data_int32, const RuntimeShape& output_shape, + int16_t* output_data, gemmlowp::GemmContext* gemmlowp_context) { ruy::profiler::ScopeLabel label("FullyConnected/Uint8Int16"); - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; // This is a copy of the reference implementation. We do not currently have a // properly optimized version. (void)gemmlowp_context; // only used in properly optimized code. @@ -1690,13 +1692,13 @@ inline void FullyConnected( } } #endif - gemmlowp::MatrixMap weights_matrix( - filter_data, output_depth, accum_depth); - gemmlowp::MatrixMap input_matrix( + gemmlowp::MatrixMap + weights_matrix(filter_data, output_depth, accum_depth); + gemmlowp::MatrixMap input_matrix( input_data, accum_depth, batches); - gemmlowp::MatrixMap output_matrix( + gemmlowp::MatrixMap output_matrix( output_data, output_depth, batches); - typedef gemmlowp::VectorMap + typedef gemmlowp::VectorMap ColVectorMap; ColVectorMap bias_vector(bias_data_int32, output_depth); gemmlowp::OutputStageBiasAddition bias_addition_stage; @@ -1713,19 +1715,19 @@ inline void FullyConnected( auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage, clamp_stage, saturating_cast_int16_stage); - gemmlowp::GemmWithOutputPipeline( gemmlowp_context, weights_matrix, input_matrix, &output_matrix, filter_offset, input_offset, output_pipeline); } -inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void FullyConnected(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims, gemmlowp::GemmContext* gemmlowp_context) { tflite::FullyConnectedParams op_params; @@ -1744,13 +1746,16 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, gemmlowp_context); } -inline void FullyConnected( - const uint8* input_data, const Dims<4>& input_dims, int32 input_offset, - const uint8* filter_data, const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data_int32, const Dims<4>& bias_dims, int32 output_offset, - int32 output_multiplier, int output_shift, int32 output_activation_min, - int32 output_activation_max, int16* output_data, const Dims<4>& output_dims, - gemmlowp::GemmContext* gemmlowp_context) { +inline void FullyConnected(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data_int32, + const Dims<4>& bias_dims, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, + int32_t output_activation_max, int16_t* output_data, + const Dims<4>& output_dims, + gemmlowp::GemmContext* gemmlowp_context) { tflite::FullyConnectedParams op_params; op_params.input_offset = input_offset; op_params.weights_offset = filter_offset; @@ -1769,13 +1774,13 @@ inline void FullyConnected( // legacy, for compatibility with old checked-in code template -void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +void FullyConnected(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims, gemmlowp::GemmContext* gemmlowp_context) { static_assert(Ac == FusedActivationFunctionType::kNone || @@ -1793,12 +1798,13 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, #ifdef USE_NEON inline void LegacyInt8FullyConnectedAsGEMVWorkerImpl( const RuntimeShape& input_shape, const int8_t* input_data, - int32 input_offset, const RuntimeShape& filter_shape, - const int8_t* filter_data, int32 filter_offset, - const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, int32 output_activation_min, - int32 output_activation_max, const RuntimeShape& output_shape, - int8_t* output_data, int row_start, int row_end) { + int32_t input_offset, const RuntimeShape& filter_shape, + const int8_t* filter_data, int32_t filter_offset, + const RuntimeShape& bias_shape, const int32_t* bias_data, + int32_t output_offset, int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + const RuntimeShape& output_shape, int8_t* output_data, int row_start, + int row_end) { ruy::profiler::ScopeLabel label("FullyConnectedAsGEMVInt8/8bit"); TFLITE_DCHECK_GE(input_shape.DimensionsCount(), 1); TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); @@ -1931,16 +1937,16 @@ inline void LegacyInt8FullyConnectedAsGEMVWorkerImpl( vget_high_s16(input_val)); } if (in < input_size) { - int32 buf[16]; + int32_t buf[16]; vst1q_s32(buf + 0, acc0); vst1q_s32(buf + 4, acc1); vst1q_s32(buf + 8, acc2); vst1q_s32(buf + 12, acc3); for (; in < input_size; in++) { int lane = (in + 8 - input_size) % 4; - const int32 input_val = input_data[in] + input_offset; + const int32_t input_val = input_data[in] + input_offset; for (int k = 0; k < kPeel; k++) { - int32 filter_val = + int32_t filter_val = filter_data[in + (out + k) * input_size] + filter_offset; buf[lane + 4 * k] += filter_val * input_val; } @@ -1969,7 +1975,7 @@ inline void LegacyInt8FullyConnectedAsGEMVWorkerImpl( int32x4_t bias_vec = vld1q_s32(bias_data + out); reduced = vaddq_s32(reduced, bias_vec); if (shift_left) { - const int32 multiplier_power_of_two = 1 << output_shift; + const int32_t multiplier_power_of_two = 1 << output_shift; reduced = vmulq_n_s32(reduced, multiplier_power_of_two); reduced = vqrdmulhq_n_s32(reduced, output_multiplier); } else { @@ -2000,11 +2006,11 @@ inline void LegacyInt8FullyConnectedAsGEMVWorkerImpl( struct LegacyInt8FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task { LegacyInt8FullyConnectedAsGEMVWorkerTask( const RuntimeShape& input_shape, const int8_t* input_data, - int32 input_offset, const RuntimeShape& filter_shape, - const int8_t* filter_data, int32 filter_offset, - const RuntimeShape& bias_shape, const int32* bias_data, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, + int32_t input_offset, const RuntimeShape& filter_shape, + const int8_t* filter_data, int32_t filter_offset, + const RuntimeShape& bias_shape, const int32_t* bias_data, + int32_t output_offset, int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, const RuntimeShape& output_shape, int8_t* output_data, int row_start, int row_end) : input_shape_(input_shape), @@ -2036,17 +2042,17 @@ struct LegacyInt8FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task { const RuntimeShape& input_shape_; const int8_t* input_data_; - int32 input_offset_; + int32_t input_offset_; const RuntimeShape& filter_shape_; const int8_t* filter_data_; - int32 filter_offset_; + int32_t filter_offset_; const RuntimeShape& bias_shape_; - const int32* bias_data_; - int32 output_offset_; - int32 output_multiplier_; + const int32_t* bias_data_; + int32_t output_offset_; + int32_t output_multiplier_; int output_shift_; - int32 output_activation_min_; - int32 output_activation_max_; + int32_t output_activation_min_; + int32_t output_activation_max_; const RuntimeShape& output_shape_; int8_t* output_data_; int row_start_; @@ -2055,12 +2061,13 @@ struct LegacyInt8FullyConnectedAsGEMVWorkerTask : public gemmlowp::Task { inline void LegacyInt8FullyConnectedAsGEMV( const RuntimeShape& input_shape, const int8_t* input_data, - int32 input_offset, const RuntimeShape& filter_shape, - const int8_t* filter_data, int32 filter_offset, - const RuntimeShape& bias_shape, const int32* bias_data, int32 output_offset, - int32 output_multiplier, int output_shift, int32 output_activation_min, - int32 output_activation_max, const RuntimeShape& output_shape, - int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) { + int32_t input_offset, const RuntimeShape& filter_shape, + const int8_t* filter_data, int32_t filter_offset, + const RuntimeShape& bias_shape, const int32_t* bias_data, + int32_t output_offset, int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + const RuntimeShape& output_shape, int8_t* output_data, + gemmlowp::GemmContext* gemmlowp_context) { const int output_dim_count = output_shape.DimensionsCount(); const int batches = FlatSizeSkipDim(output_shape, output_dim_count - 1); const int output_rows = output_shape.Dims(output_dim_count - 1); @@ -2104,20 +2111,20 @@ inline void LegacyInt8FullyConnectedAsGEMV( inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const int8* input_data, const RuntimeShape& filter_shape, - const int8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, int8* output_data, - gemmlowp::GemmContext* gemmlowp_context) { + const int8_t* input_data, const RuntimeShape& filter_shape, + const int8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int8_t* output_data, gemmlowp::GemmContext* gemmlowp_context) { ruy::profiler::ScopeLabel label("FullyConnectedInt8/8bit"); #ifdef USE_NEON - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_GE(filter_shape.DimensionsCount(), 2); TFLITE_DCHECK_GE(output_shape.DimensionsCount(), 1); // TODO(b/62193649): This really should be: @@ -2174,13 +2181,13 @@ inline void FullyConnected( } struct LegacyShuffledFullyConnectedWorkerTask : gemmlowp::Task { - LegacyShuffledFullyConnectedWorkerTask(const uint8* input_data, - const int8* shuffled_weights_data, + LegacyShuffledFullyConnectedWorkerTask(const uint8_t* input_data, + const int8_t* shuffled_weights_data, int batches, int output_depth, int output_stride, int accum_depth, - const int32* bias_data, - int32 output_multiplier, - int output_shift, int16* output_data) + const int32_t* bias_data, + int32_t output_multiplier, + int output_shift, int16_t* output_data) : input_data_(input_data), shuffled_weights_data_(shuffled_weights_data), batches_(batches), @@ -2199,30 +2206,30 @@ struct LegacyShuffledFullyConnectedWorkerTask : gemmlowp::Task { output_shift_, output_data_); } - const uint8* input_data_; - const int8* shuffled_weights_data_; + const uint8_t* input_data_; + const int8_t* shuffled_weights_data_; int batches_; int output_depth_; int output_stride_; int accum_depth_; - const int32* bias_data_; - int32 output_multiplier_; + const int32_t* bias_data_; + int32_t output_multiplier_; int output_shift_; - int16* output_data_; + int16_t* output_data_; }; inline void ShuffledFullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& weights_shape, - const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - int16* output_data, uint8* shuffled_input_workspace_data, + const uint8_t* input_data, const RuntimeShape& weights_shape, + const uint8_t* shuffled_weights_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int16_t* output_data, uint8_t* shuffled_input_workspace_data, gemmlowp::GemmContext* gemmlowp_context) { ruy::profiler::ScopeLabel label("ShuffledFullyConnected/8bit"); - const int32 output_multiplier = params.output_multiplier; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; (void)gemmlowp_context; // only used in optimized code. TFLITE_DCHECK_EQ(output_activation_min, -32768); TFLITE_DCHECK_EQ(output_activation_max, 32767); @@ -2246,8 +2253,8 @@ inline void ShuffledFullyConnected( // so that just reinterpreting them as int8 values is equivalent to // subtracting 128 from them, thus implementing for free the subtraction of // the zero_point value 128. - const int8* int8_shuffled_weights_data = - reinterpret_cast(shuffled_weights_data); + const int8_t* int8_shuffled_weights_data = + reinterpret_cast(shuffled_weights_data); // Shuffling and xoring of input activations into the workspace buffer if (batches == 1) { @@ -2264,12 +2271,12 @@ inline void ShuffledFullyConnected( } #endif } else if (batches == 4) { - uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data; + uint8_t* shuffled_input_workspace_ptr = shuffled_input_workspace_data; int c = 0; #ifdef USE_NEON const uint8x16_t signbit = vdupq_n_u8(0x80); for (c = 0; c < accum_depth; c += 16) { - const uint8* src_data_ptr = input_data + c; + const uint8_t* src_data_ptr = input_data + c; uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth); uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth); uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth); @@ -2337,12 +2344,12 @@ inline void ShuffledFullyConnected( } inline void ShuffledFullyConnected( - const uint8* input_data, const Dims<4>& input_dims, - const uint8* shuffled_weights_data, const Dims<4>& weights_dims, - const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - int16* output_data, const Dims<4>& output_dims, - uint8* shuffled_input_workspace_data, + const uint8_t* input_data, const Dims<4>& input_dims, + const uint8_t* shuffled_weights_data, const Dims<4>& weights_dims, + const int32_t* bias_data, const Dims<4>& bias_dims, + int32_t output_multiplier, int output_shift, int32_t output_activation_min, + int32_t output_activation_max, int16_t* output_data, + const Dims<4>& output_dims, uint8_t* shuffled_input_workspace_data, gemmlowp::GemmContext* gemmlowp_context) { tflite::FullyConnectedParams op_params; op_params.output_multiplier = output_multiplier; @@ -2363,7 +2370,7 @@ inline void ExtractPatchIntoBufferColumn( const Dims<4>& input_dims, int w, int h, int b, int kheight, int kwidth, int stride_width, int stride_height, int pad_width, int pad_height, int in_width, int in_height, int in_depth, int single_buffer_length, - int buffer_id, const T* in_data, T* conv_buffer_data, uint8 zero_byte) { + int buffer_id, const T* in_data, T* conv_buffer_data, uint8_t zero_byte) { ExtractPatchIntoBufferColumn( DimsToShape(input_dims), w, h, b, kheight, kwidth, stride_width, stride_height, pad_width, pad_height, in_width, in_height, in_depth, @@ -2375,7 +2382,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, const Dims<4>& filter_dims, int stride_width, int stride_height, int dilation_width_factor, int dilation_height_factor, int pad_width, int pad_height, - const Dims<4>& output_dims, uint8 zero_byte, + const Dims<4>& output_dims, uint8_t zero_byte, T* im2col_data) { tflite::ConvParams op_params; // Padding type is ignored, but still set. @@ -2395,7 +2402,7 @@ void DilatedIm2col(const T* input_data, const Dims<4>& input_dims, template void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int kheight, - int kwidth, uint8 zero_byte, T* output_data, + int kwidth, uint8_t zero_byte, T* output_data, const Dims<4>& output_dims) { tflite::ConvParams op_params; // Padding type is ignored, but still set. @@ -2415,7 +2422,7 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride_width, template void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, int pad_width, int pad_height, int kheight, int kwidth, - uint8 zero_byte, T* output_data, const Dims<4>& output_dims) { + uint8_t zero_byte, T* output_data, const Dims<4>& output_dims) { Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, kwidth, zero_byte, output_data, output_dims); } @@ -2441,7 +2448,7 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, ruy::profiler::ScopeLabel label("Conv"); // NB: the float 0.0f value is represented by all zero bytes. - const uint8 float_zero_byte = 0x00; + const uint8_t float_zero_byte = 0x00; const float* gemm_input_data = nullptr; const RuntimeShape* gemm_input_shape = nullptr; const int filter_width = filter_shape.Dims(2); @@ -2622,28 +2629,29 @@ void Conv(const float* input_data, const Dims<4>& input_dims, } inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, const RuntimeShape& im2col_shape, - uint8* im2col_data, gemmlowp::GemmContext* gemmlowp_context) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, const RuntimeShape& im2col_shape, + uint8_t* im2col_data, + gemmlowp::GemmContext* gemmlowp_context) { ruy::profiler::ScopeLabel label("Conv/8bit"); const int stride_width = params.stride_width; const int stride_height = params.stride_height; const int dilation_width_factor = params.dilation_width_factor; const int dilation_height_factor = params.dilation_height_factor; - const int32 input_offset = params.input_offset; - const int32 filter_offset = params.weights_offset; - const int32 output_offset = params.output_offset; - const int32 output_multiplier = params.output_multiplier; + const int32_t input_offset = params.input_offset; + const int32_t filter_offset = params.weights_offset; + const int32_t output_offset = params.output_offset; + const int32_t output_multiplier = params.output_multiplier; const int output_shift = params.output_shift; - const int32 output_activation_min = params.quantized_activation_min; - const int32 output_activation_max = params.quantized_activation_max; + const int32_t output_activation_min = params.quantized_activation_min; + const int32_t output_activation_max = params.quantized_activation_max; TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); - const uint8* gemm_input_data = nullptr; + const uint8_t* gemm_input_data = nullptr; const RuntimeShape* gemm_input_shape = nullptr; const int filter_width = filter_shape.Dims(2); const int filter_height = filter_shape.Dims(1); @@ -2712,31 +2720,32 @@ inline void Conv(const ConvParams& params, const RuntimeShape& input_shape, } #endif - gemmlowp::MatrixMap filter_matrix( - filter_data, filter_rows, filter_cols); - gemmlowp::MatrixMap input_matrix( + gemmlowp::MatrixMap + filter_matrix(filter_data, filter_rows, filter_cols); + gemmlowp::MatrixMap input_matrix( gemm_input_data, gemm_input_rows, gemm_input_cols); - gemmlowp::MatrixMap output_matrix( + gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols); const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( bias_data, output_rows, output_offset, output_multiplier, output_shift, output_activation_min, output_activation_max); - gemmlowp::GemmWithOutputPipeline( gemmlowp_context, filter_matrix, input_matrix, &output_matrix, filter_offset, input_offset, output_pipeline); } -inline void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void Conv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int dilation_width_factor, int dilation_height_factor, int pad_width, int pad_height, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims, - uint8* im2col_data, const Dims<4>& im2col_dims, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, + const Dims<4>& output_dims, uint8_t* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) { tflite::ConvParams op_params; // Padding type is ignored, but still set. @@ -2761,16 +2770,16 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context); } -inline void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void Conv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int pad_width, - int pad_height, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims, uint8* im2col_data, - const Dims<4>& im2col_dims, + int pad_height, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims, + uint8_t* im2col_data, const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) { Conv(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1, @@ -2781,16 +2790,16 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -inline void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void Conv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int pad_width, - int pad_height, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims, uint8* im2col_data, - const Dims<4>& im2col_dims, + int pad_height, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims, + uint8_t* im2col_data, const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) { static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -2810,15 +2819,16 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, int stride, - int pad_width, int pad_height, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, - const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) { +void Conv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims, + uint8_t* im2col_data, const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemmlowp_context) { static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu6 || @@ -2835,7 +2845,7 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims, template void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, int pad_width, int pad_height, int kheight, int kwidth, - uint8 zero_byte, T* output_data, const Dims<4>& output_dims) { + uint8_t zero_byte, T* output_data, const Dims<4>& output_dims) { Im2col(input_data, input_dims, stride, stride, pad_width, pad_height, kheight, kwidth, zero_byte, output_data, output_dims); } @@ -2863,13 +2873,14 @@ void ConvAsGemm(const float* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims, +void ConvAsGemm(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, + const Dims<4>& output_dims, gemmlowp::GemmContext* gemmlowp_context) { ruy::profiler::ScopeLabel label("ConvAsGemm/8bit"); static_assert(Ac == FusedActivationFunctionType::kNone || @@ -2890,16 +2901,16 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims, TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1); TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1); TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1); - gemmlowp::MatrixMap filter_matrix( - filter_data, output_rows, filter_cols, filter_cols); - gemmlowp::MatrixMap input_matrix( + gemmlowp::MatrixMap + filter_matrix(filter_data, output_rows, filter_cols, filter_cols); + gemmlowp::MatrixMap input_matrix( input_data, filter_cols, output_cols, filter_cols); - gemmlowp::MatrixMap output_matrix( + gemmlowp::MatrixMap output_matrix( output_data, output_rows, output_cols, output_rows); const auto& output_pipeline = GemmlowpOutputPipeline::MakeExp( bias_data, output_rows, output_offset, output_multiplier, -output_shift, output_activation_min, output_activation_max); - gemmlowp::GemmWithOutputPipeline( gemmlowp_context, filter_matrix, input_matrix, &output_matrix, filter_offset, input_offset, output_pipeline); @@ -2962,7 +2973,7 @@ template void TransposeIm2col(const T* input_data, const Dims<4>& input_dims, const Dims<4>& filter_dims, int stride_width, int stride_height, int pad_width, int pad_height, - const Dims<4>& output_dims, uint8 zero_byte, + const Dims<4>& output_dims, uint8_t zero_byte, T* im2col_data) { tflite::ConvParams op_params; // Padding type is ignored, but still set. @@ -3120,25 +3131,25 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, template inline void LstmCell( const LstmCellParams& params, const RuntimeShape& unextended_input_shape, - const uint8* input_data_uint8, + const uint8_t* input_data_uint8, const RuntimeShape& unextended_prev_activ_shape, - const uint8* prev_activ_data_uint8, const RuntimeShape& weights_shape, - const uint8* weights_data_uint8, const RuntimeShape& unextended_bias_shape, - const int32* bias_data_int32, + const uint8_t* prev_activ_data_uint8, const RuntimeShape& weights_shape, + const uint8_t* weights_data_uint8, + const RuntimeShape& unextended_bias_shape, const int32_t* bias_data_int32, const RuntimeShape& unextended_prev_state_shape, - const int16* prev_state_data_int16, + const int16_t* prev_state_data_int16, const RuntimeShape& unextended_output_state_shape, - int16* output_state_data_int16, + int16_t* output_state_data_int16, const RuntimeShape& unextended_output_activ_shape, - uint8* output_activ_data_uint8, + uint8_t* output_activ_data_uint8, const RuntimeShape& unextended_concat_temp_shape, - uint8* concat_temp_data_uint8, + uint8_t* concat_temp_data_uint8, const RuntimeShape& unextended_activ_temp_shape, - int16* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) { + int16_t* activ_temp_data_int16, gemmlowp::GemmContext* gemmlowp_context) { ruy::profiler::ScopeLabel label( "LstmCell/quantized (8bit external, 16bit internal)"); - int32 weights_zero_point = params.weights_zero_point; - int32 accum_multiplier = params.accum_multiplier; + int32_t weights_zero_point = params.weights_zero_point; + int32_t accum_multiplier = params.accum_multiplier; int accum_shift = params.accum_shift; TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4); TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4); @@ -3193,8 +3204,8 @@ inline void LstmCell( TFLITE_DCHECK_EQ(fc_output_depth, 4 * output_depth); // Depth-concatenate prev_activ and input data together. - uint8 const* concat_input_arrays_data[2] = {input_data_uint8, - prev_activ_data_uint8}; + const uint8_t* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; const RuntimeShape* concat_input_arrays_shapes[2] = {&input_shape, &prev_activ_shape}; tflite::ConcatenationParams concat_params; @@ -3220,13 +3231,13 @@ inline void LstmCell( } #endif if (!gemm_already_performed) { - gemmlowp::MatrixMap + gemmlowp::MatrixMap weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth); - gemmlowp::MatrixMap input_matrix( - concat_temp_data_uint8, fc_accum_depth, fc_batches); - gemmlowp::MatrixMap output_matrix( + gemmlowp::MatrixMap + input_matrix(concat_temp_data_uint8, fc_accum_depth, fc_batches); + gemmlowp::MatrixMap output_matrix( activ_temp_data_int16, fc_output_depth, fc_batches); - typedef gemmlowp::VectorMap + typedef gemmlowp::VectorMap ColVectorMap; ColVectorMap bias_vector(bias_data_int32, fc_output_depth); gemmlowp::OutputStageBiasAddition bias_addition_stage; @@ -3239,21 +3250,23 @@ inline void LstmCell( auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage, saturating_cast_int16_stage); gemmlowp::GemmWithOutputPipeline< - uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + uint8_t, int16_t, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( gemmlowp_context, weights_matrix, input_matrix, &output_matrix, -weights_zero_point, -128, output_pipeline); } // Rest of the LSTM cell: tanh and logistic math functions, and some adds // and muls, all done in 16-bit fixed-point. - const int16* input_gate_input_ptr = activ_temp_data_int16; - const int16* input_modulation_gate_input_ptr = + const int16_t* input_gate_input_ptr = activ_temp_data_int16; + const int16_t* input_modulation_gate_input_ptr = activ_temp_data_int16 + output_depth; - const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth; - const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth; - const int16* prev_state_ptr = prev_state_data_int16; - int16* output_state_data_ptr = output_state_data_int16; - uint8* output_activ_data_ptr = output_activ_data_uint8; + const int16_t* forget_gate_input_ptr = + activ_temp_data_int16 + 2 * output_depth; + const int16_t* output_gate_input_ptr = + activ_temp_data_int16 + 3 * output_depth; + const int16_t* prev_state_ptr = prev_state_data_int16; + int16_t* output_state_data_ptr = output_state_data_int16; + uint8_t* output_activ_data_ptr = output_activ_data_uint8; for (int b = 0; b < outer_size; ++b) { int c = 0; @@ -3391,10 +3404,10 @@ inline void LstmCell( *output_state_data_ptr++ = new_state.raw(); // Down-scale the output activations to 8-bit integers, saturating, // and store back to memory. - int16 rescaled_output_activ = + int16_t rescaled_output_activ = gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); - int16 clamped_output_activ = - std::max(-128, std::min(127, rescaled_output_activ)); + int16_t clamped_output_activ = std::max( + -128, std::min(127, rescaled_output_activ)); *output_activ_data_ptr++ = 128 + clamped_output_activ; } input_gate_input_ptr += 3 * output_depth; @@ -3405,17 +3418,18 @@ inline void LstmCell( } template -void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, - const uint8* prev_activ_data_uint8, - const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8, - const Dims<4>& weights_dims, const int32* bias_data_int32, - const Dims<4>& bias_dims, const int16* prev_state_data_int16, - const Dims<4>& prev_state_dims, int16* output_state_data_int16, - const Dims<4>& output_state_dims, uint8* output_activ_data_uint8, - const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, - const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, - const Dims<4>& activ_temp_dims, int32 weights_zero_point, - int32 accum_multiplier, int accum_shift, +void LstmCell(const uint8_t* input_data_uint8, const Dims<4>& input_dims, + const uint8_t* prev_activ_data_uint8, + const Dims<4>& prev_activ_dims, const uint8_t* weights_data_uint8, + const Dims<4>& weights_dims, const int32_t* bias_data_int32, + const Dims<4>& bias_dims, const int16_t* prev_state_data_int16, + const Dims<4>& prev_state_dims, int16_t* output_state_data_int16, + const Dims<4>& output_state_dims, + uint8_t* output_activ_data_uint8, + const Dims<4>& output_activ_dims, uint8_t* concat_temp_data_uint8, + const Dims<4>& concat_temp_dims, int16_t* activ_temp_data_int16, + const Dims<4>& activ_temp_dims, int32_t weights_zero_point, + int32_t accum_multiplier, int accum_shift, gemmlowp::GemmContext* gemmlowp_context) { tflite::LstmCellParams op_params; op_params.weights_zero_point = weights_zero_point; @@ -3458,9 +3472,9 @@ void L2Normalization(const float* input_data, const RuntimeShape& input_shape, output_data); } -inline void L2Normalization(const uint8* input_data, +inline void L2Normalization(const uint8_t* input_data, const RuntimeShape& input_shape, - int32 input_zero_point, uint8* output_data, + int32_t input_zero_point, uint8_t* output_data, const RuntimeShape& output_shape) { tflite::L2NormalizationParams op_params; op_params.input_zero_point = input_zero_point; @@ -3476,9 +3490,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims)); } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, - int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { +inline void L2Normalization(const uint8_t* input_data, + const Dims<4>& input_dims, int32_t input_zero_point, + uint8_t* output_data, const Dims<4>& output_dims) { L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, output_data, DimsToShape(output_dims)); } @@ -3506,14 +3520,15 @@ void Add(const float* input1_data, const Dims<4>& input1_dims, } template -inline void Add(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, int input2_shift, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { +inline void Add(int left_shift, const uint8_t* input1_data, + const Dims<4>& input1_dims, int32_t input1_offset, + int32_t input1_multiplier, int input1_shift, + const uint8_t* input2_data, const Dims<4>& input2_dims, + int32_t input2_offset, int32_t input2_multiplier, + int input2_shift, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { constexpr int kReverseShift = -1; static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -3545,15 +3560,15 @@ inline void Add(int left_shift, const uint8* input1_data, } template -void Add(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32* output_data, const Dims<4>& output_dims) { +void Add(const int32_t* input1_data, const Dims<4>& input1_dims, + const int32_t* input2_data, const Dims<4>& input2_dims, + int32_t* output_data, const Dims<4>& output_dims) { ruy::profiler::ScopeLabel label("Add/int32"); TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); tflite::ArithmeticParams op_params; - op_params.quantized_activation_min = std::numeric_limits::min(); - op_params.quantized_activation_max = std::numeric_limits::max(); + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); Add(op_params, DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), output_data); @@ -3573,15 +3588,15 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, } template -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void BroadcastAdd(int left_shift, const uint8_t* input1_data, + const Dims<4>& input1_dims, int32_t input1_offset, + int32_t input1_multiplier, int input1_shift, + const uint8_t* input2_data, const Dims<4>& input2_dims, + int32_t input2_offset, int32_t input2_multiplier, + int input2_shift, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { constexpr int kReverseShift = -1; static_assert(Ac == FusedActivationFunctionType::kNone || @@ -3616,12 +3631,13 @@ inline void BroadcastAdd(int left_shift, const uint8* input1_data, template inline void BroadcastAddFivefold( int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + const uint8_t* input1_data, const Dims<4>& input1_dims, + int32_t input1_offset, int32_t input1_multiplier, int input1_shift, + const uint8_t* input2_data, const Dims<4>& input2_dims, + int32_t input2_offset, int32_t input2_multiplier, int input2_shift, + int32_t output_offset, int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { constexpr int kReverseShift = -1; static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -3672,11 +3688,11 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, } template -inline void Add(const int16* input1_data, const Dims<4>& input1_dims, - int input1_shift, const int16* input2_data, +inline void Add(const int16_t* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16_t* input2_data, const Dims<4>& input2_dims, int input2_shift, - int16 output_activation_min, int16 output_activation_max, - int16* output_data, const Dims<4>& output_dims) { + int16_t output_activation_min, int16_t output_activation_max, + int16_t* output_data, const Dims<4>& output_dims) { constexpr int kReverseShift = -1; static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -3728,12 +3744,12 @@ void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, output_data); } -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void BroadcastMul(const uint8_t* input1_data, const Dims<4>& input1_dims, + int32_t input1_offset, const uint8_t* input2_data, + const Dims<4>& input2_dims, int32_t input2_offset, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { tflite::ArithmeticParams op_params; SetActivationParams(output_activation_min, output_activation_max, &op_params); @@ -3750,12 +3766,12 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, // legacy, for compatibility with old checked-in code template -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void BroadcastMul(const uint8_t* input1_data, const Dims<4>& input1_dims, + int32_t input1_offset, const uint8_t* input2_data, + const Dims<4>& input2_dims, int32_t input2_offset, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, input2_dims, input2_offset, output_offset, output_multiplier, @@ -3808,11 +3824,11 @@ bool AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, output_dims); } -inline bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, +inline bool AveragePool(const uint8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, + int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { tflite::PoolParams params; params.stride_height = stride_height; @@ -3829,11 +3845,11 @@ inline bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, +bool AveragePool(const uint8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu6 || @@ -3851,10 +3867,10 @@ bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +bool AveragePool(const uint8_t* input_data, const Dims<4>& input_dims, + int stride, int pad_width, int pad_height, int filter_width, + int filter_height, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { return AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, filter_width, filter_height, @@ -3902,11 +3918,12 @@ void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, filter_width, filter_height, output_data, output_dims); } -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +inline void MaxPool(const uint8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, + const Dims<4>& output_dims) { PoolParams params; params.stride_height = stride_height; params.stride_width = stride_width; @@ -3922,10 +3939,10 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +void MaxPool(const uint8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, + int filter_width, int filter_height, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -3943,10 +3960,10 @@ void MaxPool(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, +void MaxPool(const uint8_t* input_data, const Dims<4>& input_dims, int stride, int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, filter_width, filter_height, output_activation_min, output_activation_max, output_data, output_dims); @@ -3993,10 +4010,10 @@ void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride, } inline void Softmax(const SoftmaxParams& params, - const RuntimeShape& input_shape, const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { - const int32 input_beta_multiplier = params.input_multiplier; - const int32 input_beta_left_shift = params.input_left_shift; + const RuntimeShape& input_shape, const uint8_t* input_data, + const RuntimeShape& output_shape, uint8_t* output_data) { + const int32_t input_beta_multiplier = params.input_multiplier; + const int32_t input_beta_left_shift = params.input_left_shift; const int diff_min = params.diff_min; // The representation chosen for the input to the exp() function is Q5.26. // We need to leave extra space since values that we skip might be as large as @@ -4006,9 +4023,10 @@ inline void Softmax(const SoftmaxParams& params, static const int kScaledDiffIntegerBits = 5; static const int kAccumulationIntegerBits = 12; using FixedPointScaledDiff = - gemmlowp::FixedPoint; - using FixedPointAccum = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; + gemmlowp::FixedPoint; + using FixedPointAccum = + gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; ruy::profiler::ScopeLabel label("Softmax/8bit"); const int trailing_dim = input_shape.DimensionsCount() - 1; @@ -4018,11 +4036,11 @@ inline void Softmax(const SoftmaxParams& params, MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); for (int b = 0; b < outer_size; ++b) { - const uint8* input_data_ptr = input_data + b * depth; - uint8* output_data_ptr = output_data + b * depth; + const uint8_t* input_data_ptr = input_data + b * depth; + uint8_t* output_data_ptr = output_data + b * depth; // Determine the largest entry in the current row - uint8 max_in_row = 0; + uint8_t max_in_row = 0; { int c = 0; #ifdef USE_NEON @@ -4114,9 +4132,10 @@ inline void Softmax(const SoftmaxParams& params, FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0)); #endif for (; c < depth; ++c) { - int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; + int32_t input_diff = + static_cast(input_data_ptr[c]) - max_in_row; if (input_diff >= diff_min) { - const int32 input_diff_rescaled = + const int32_t input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne( input_diff, input_beta_multiplier, input_beta_left_shift); const FixedPointScaledDiff scaled_diff_f8 = @@ -4172,16 +4191,17 @@ inline void Softmax(const SoftmaxParams& params, } #endif for (; c < depth; ++c) { - int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; + int32_t input_diff = + static_cast(input_data_ptr[c]) - max_in_row; if (input_diff >= diff_min) { - const int32 input_diff_rescaled = + const int32_t input_diff_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne( input_diff, input_beta_multiplier, input_beta_left_shift); const FixedPointScaledDiff scaled_diff_f8 = FixedPointScaledDiff::FromRaw(input_diff_rescaled); FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); - int32 unsat_output = gemmlowp::RoundingDivideByPOT( + int32_t unsat_output = gemmlowp::RoundingDivideByPOT( (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0); @@ -4209,20 +4229,20 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims)); } -inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, - int32 input_beta_multiplier, int32 input_beta_left_shift, - int diff_min, uint8* output_data, - const RuntimeShape& output_shape) { +inline void Softmax(const uint8_t* input_data, const RuntimeShape& input_shape, + int32_t input_beta_multiplier, + int32_t input_beta_left_shift, int diff_min, + uint8_t* output_data, const RuntimeShape& output_shape) { SoftmaxParams params; params.input_multiplier = input_beta_multiplier; params.input_left_shift = input_beta_left_shift; params.diff_min = diff_min; Softmax(params, input_shape, input_data, output_shape, output_data); } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, - int32 input_beta_multiplier, int32 input_beta_left_shift, - int diff_min, uint8* output_data, - const Dims<4>& output_dims) { +inline void Softmax(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_beta_multiplier, + int32_t input_beta_left_shift, int diff_min, + uint8_t* output_data, const Dims<4>& output_dims) { Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, input_beta_left_shift, diff_min, output_data, DimsToShape(output_dims)); @@ -4241,11 +4261,12 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims)); } -inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, - int32 input_multiplier, int32 input_left_shift, - int32 reverse_scaling_divisor, - int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const RuntimeShape& output_shape) { +inline void LogSoftmax(const uint8_t* input_data, + const RuntimeShape& input_shape, + int32_t input_multiplier, int32_t input_left_shift, + int32_t reverse_scaling_divisor, + int32_t reverse_scaling_right_shift, int diff_min, + uint8_t* output_data, const RuntimeShape& output_shape) { SoftmaxParams params; params.input_multiplier = input_multiplier; params.input_left_shift = input_left_shift; @@ -4256,11 +4277,11 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, output_data); } -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, - int32 input_multiplier, int32 input_left_shift, - int32 reverse_scaling_divisor, - int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { +inline void LogSoftmax(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_multiplier, int32_t input_left_shift, + int32_t reverse_scaling_divisor, + int32_t reverse_scaling_right_shift, int diff_min, + uint8_t* output_data, const Dims<4>& output_dims) { reference_ops::LogSoftmax( input_data, DimsToShape(input_dims), input_multiplier, input_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, @@ -4268,12 +4289,12 @@ inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, } inline void Logistic(const LogisticParams& params, - const RuntimeShape& input_shape, const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { + const RuntimeShape& input_shape, const uint8_t* input_data, + const RuntimeShape& output_shape, uint8_t* output_data) { ruy::profiler::ScopeLabel label("Logistic/Uint8"); - const int32 input_zero_point = params.input_zero_point; - const int32 input_range_radius = params.input_range_radius; - const int32 input_multiplier = params.input_multiplier; + const int32_t input_zero_point = params.input_zero_point; + const int32_t input_range_radius = params.input_range_radius; + const int32_t input_multiplier = params.input_multiplier; const int input_left_shift = params.input_left_shift; const int size = MatchingFlatSize(input_shape, output_shape); @@ -4378,39 +4399,39 @@ inline void Logistic(const LogisticParams& params, #endif // Leftover loop: handle one value at a time with scalar code. for (; c < size; ++c) { - const uint8 input_val_u8 = input_data[c]; - const int32 input_val_centered = - static_cast(input_val_u8) - input_zero_point; - uint8 output_val; + const uint8_t input_val_u8 = input_data[c]; + const int32_t input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8_t output_val; if (input_val_centered < -input_range_radius) { output_val = 0; } else if (input_val_centered > input_range_radius) { output_val = 255; } else { - const int32 input_val_rescaled = + const int32_t input_val_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne( input_val_centered, input_multiplier, input_left_shift); - using FixedPoint4 = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); using gemmlowp::RoundingDivideByPOT; - int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + int32_t output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); if (output_val_s32 == 256) { output_val_s32 = 255; } TFLITE_DCHECK_GE(output_val_s32, 0); TFLITE_DCHECK_LE(output_val_s32, 255); - output_val = static_cast(output_val_s32); + output_val = static_cast(output_val_s32); } output_data[c] = output_val; } } -inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, - int32 input_zero_point, int32 input_range_radius, - int32 input_multiplier, int input_left_shift, - uint8* output_data, const RuntimeShape& output_shape) { +inline void Logistic(const uint8_t* input_data, const RuntimeShape& input_shape, + int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int input_left_shift, + uint8_t* output_data, const RuntimeShape& output_shape) { LogisticParams params; params.input_zero_point = input_zero_point; params.input_range_radius = input_range_radius; @@ -4425,31 +4446,31 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims, output_data); } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, - int32 input_zero_point, int32 input_range_radius, - int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { +inline void Logistic(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int input_left_shift, + uint8_t* output_data, const Dims<4>& output_dims) { Logistic(input_data, DimsToShape(input_dims), input_zero_point, input_range_radius, input_multiplier, input_left_shift, output_data, DimsToShape(output_dims)); } -inline void Logistic(const RuntimeShape& input_shape, const int16* input_data, - const RuntimeShape& output_shape, int16* output_data) { +inline void Logistic(const RuntimeShape& input_shape, const int16_t* input_data, + const RuntimeShape& output_shape, int16_t* output_data) { LogisticParams params; // No params currently needed by int16 Logistic. Logistic(params, input_shape, input_data, output_shape, output_data); } -inline void Logistic(const int16* input_data, const RuntimeShape& input_shape, - int16* output_data, const RuntimeShape& output_shape) { +inline void Logistic(const int16_t* input_data, const RuntimeShape& input_shape, + int16_t* output_data, const RuntimeShape& output_shape) { LogisticParams params; // No params currently needed by int16 Logistic. Logistic(params, input_shape, input_data, output_shape, output_data); } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Logistic(const int16_t* input_data, const Dims<4>& input_dims, + int16_t* output_data, const Dims<4>& output_dims) { Logistic(input_data, DimsToShape(input_dims), output_data, DimsToShape(output_dims)); } @@ -4461,13 +4482,13 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, } inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& output_shape, - uint8* output_data) { + const uint8_t* input_data, const RuntimeShape& output_shape, + uint8_t* output_data) { // Note that this is almost the exact same code as in Logistic(). ruy::profiler::ScopeLabel label("Tanh"); - const int32 input_zero_point = params.input_zero_point; - const int32 input_range_radius = params.input_range_radius; - const int32 input_multiplier = params.input_multiplier; + const int32_t input_zero_point = params.input_zero_point; + const int32_t input_range_radius = params.input_range_radius; + const int32_t input_multiplier = params.input_multiplier; const int input_left_shift = params.input_left_shift; const int size = MatchingFlatSize(input_shape, output_shape); @@ -4580,40 +4601,40 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape, #endif // Leftover loop: handle one value at a time with scalar code. for (; c < size; ++c) { - const uint8 input_val_u8 = input_data[c]; - const int32 input_val_centered = - static_cast(input_val_u8) - input_zero_point; - uint8 output_val; + const uint8_t input_val_u8 = input_data[c]; + const int32_t input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8_t output_val; if (input_val_centered < -input_range_radius) { output_val = 0; } else if (input_val_centered > input_range_radius) { output_val = 255; } else { - const int32 input_val_rescaled = + const int32_t input_val_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne( input_val_centered, input_multiplier, input_left_shift); - using FixedPoint4 = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); using gemmlowp::RoundingDivideByPOT; - int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); + int32_t output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); output_val_s32 += output_zero_point; if (output_val_s32 == 256) { output_val_s32 = 255; } TFLITE_DCHECK_GE(output_val_s32, 0); TFLITE_DCHECK_LE(output_val_s32, 255); - output_val = static_cast(output_val_s32); + output_val = static_cast(output_val_s32); } output_data[c] = output_val; } } -inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, - int32 input_zero_point, int32 input_range_radius, - int32 input_multiplier, int input_left_shift, - uint8* output_data, const RuntimeShape& output_shape) { +inline void Tanh(const uint8_t* input_data, const RuntimeShape& input_shape, + int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int input_left_shift, + uint8_t* output_data, const RuntimeShape& output_shape) { TanhParams params; params.input_zero_point = input_zero_point; params.input_range_radius = input_range_radius; @@ -4622,25 +4643,25 @@ inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, Tanh(params, input_shape, input_data, output_shape, output_data); } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, - int32 input_zero_point, int32 input_range_radius, - int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { +inline void Tanh(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int input_left_shift, + uint8_t* output_data, const Dims<4>& output_dims) { Tanh(input_data, DimsToShape(input_dims), input_zero_point, input_range_radius, input_multiplier, input_left_shift, output_data, DimsToShape(output_dims)); } -inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, - int input_left_shift, int16* output_data, +inline void Tanh(const int16_t* input_data, const RuntimeShape& input_shape, + int input_left_shift, int16_t* output_data, const RuntimeShape& output_shape) { TanhParams params; params.input_left_shift = input_left_shift; Tanh(params, input_shape, input_data, output_shape, output_data); } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, - int input_left_shift, int16* output_data, +inline void Tanh(const int16_t* input_data, const Dims<4>& input_dims, + int input_left_shift, int16_t* output_data, const Dims<4>& output_dims) { Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, DimsToShape(output_dims)); @@ -4692,10 +4713,10 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims, output_activation_max, output_data, output_dims); } -inline void Mul(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32 output_activation_min, int32 output_activation_max, - int32* output_data, const Dims<4>& output_dims) { +inline void Mul(const int32_t* input1_data, const Dims<4>& input1_dims, + const int32_t* input2_data, const Dims<4>& input2_dims, + int32_t output_activation_min, int32_t output_activation_max, + int32_t* output_data, const Dims<4>& output_dims) { tflite::ArithmeticParams op_params; op_params.quantized_activation_min = output_activation_min; op_params.quantized_activation_max = output_activation_max; @@ -4706,9 +4727,9 @@ inline void Mul(const int32* input1_data, const Dims<4>& input1_dims, } template -void Mul(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32* output_data, const Dims<4>& output_dims) { +void Mul(const int32_t* input1_data, const Dims<4>& input1_dims, + const int32_t* input2_data, const Dims<4>& input2_dims, + int32_t* output_data, const Dims<4>& output_dims) { TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); tflite::ArithmeticParams op_params; // No parameters needed. @@ -4718,9 +4739,9 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims, DimsToShape(output_dims), output_data); } -inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, - const int16* input2_data, const Dims<4>& input2_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Mul(const int16_t* input1_data, const Dims<4>& input1_dims, + const int16_t* input2_data, const Dims<4>& input2_dims, + int16_t* output_data, const Dims<4>& output_dims) { tflite::ArithmeticParams op_params; // No parameters needed. @@ -4729,10 +4750,10 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, output_data); } -inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, - const int16* input2_data, const Dims<4>& input2_dims, - int32 output_offset, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void Mul(const int16_t* input1_data, const Dims<4>& input1_dims, + const int16_t* input2_data, const Dims<4>& input2_dims, + int32_t output_offset, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { tflite::ArithmeticParams op_params; op_params.output_offset = output_offset; @@ -4802,7 +4823,7 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims, } inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, - const int32* output_size_data, + const int32_t* output_size_data, const Dims<4>& output_size_dims, float* output_data, const Dims<4>& output_dims, bool align_corners) { tflite::ResizeBilinearParams op_params; @@ -4813,10 +4834,11 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims), output_data); } -inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, - const int32* output_size_data, - const Dims<4>& output_size_dims, uint8* output_data, - const Dims<4>& output_dims, bool align_corners) { +inline void ResizeBilinear(const uint8_t* input_data, const Dims<4>& input_dims, + const int32_t* output_size_data, + const Dims<4>& output_size_dims, + uint8_t* output_data, const Dims<4>& output_dims, + bool align_corners) { tflite::ResizeBilinearParams op_params; op_params.align_corners = align_corners; op_params.half_pixel_centers = false; @@ -4827,7 +4849,7 @@ inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, - const int32* output_size_data, + const int32_t* output_size_data, const Dims<4>& output_size_dims, float* output_data, const Dims<4>& output_dims) { ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, @@ -4835,19 +4857,19 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } // legacy, for compatibility with old checked-in code -inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, - const int32* output_size_data, - const Dims<4>& output_size_dims, uint8* output_data, - const Dims<4>& output_dims) { +inline void ResizeBilinear(const uint8_t* input_data, const Dims<4>& input_dims, + const int32_t* output_size_data, + const Dims<4>& output_size_dims, + uint8_t* output_data, const Dims<4>& output_dims) { ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, output_data, output_dims, /*align_corners=*/false); } template inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, + const int32_t* block_shape_data, const Dims<4>& block_shape_dims, - const int32* crops_data, const Dims<4>& crops_dims, + const int32_t* crops_data, const Dims<4>& crops_dims, T* output_data, const Dims<4>& output_dims) { BatchToSpaceND(DimsToShape(input_dims), input_data, DimsToShape(block_shape_dims), block_shape_data, @@ -4930,8 +4952,8 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, DimsToShape(output_dims), output_data); } -inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, - int32 zero_point, double scale, float* output_data, +inline void Dequantize(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t zero_point, double scale, float* output_data, const Dims<4>& output_dims) { tflite::DequantizationParams op_params; op_params.zero_point = zero_point; diff --git a/tensorflow/lite/kernels/internal/optimized/resize_bilinear.h b/tensorflow/lite/kernels/internal/optimized/resize_bilinear.h index 2e1abf7a59ac12..909965e136443e 100644 --- a/tensorflow/lite/kernels/internal/optimized/resize_bilinear.h +++ b/tensorflow/lite/kernels/internal/optimized/resize_bilinear.h @@ -45,7 +45,7 @@ namespace resize_bilinear { // (a) Optimizations can be tried experimentally. // (b) Optimizations can be specialized for architectures, eg Intel vs ARM. -inline int16x8_t Load8IntoLowerS16(const uint8* data_ptr) { +inline int16x8_t Load8IntoLowerS16(const uint8_t* data_ptr) { return vreinterpretq_s16_u16(vmovl_u8(vld1_u8(data_ptr))); } @@ -54,7 +54,7 @@ inline uint16x8_t Move8IntoUpperU16(const uint8x8_t vec_val) { return vshlq_n_u16(vmovl_u8(vec_val), 8); } -inline uint16x8_t Load8IntoUpperU16(const uint8* data_ptr) { +inline uint16x8_t Load8IntoUpperU16(const uint8_t* data_ptr) { return Move8IntoUpperU16(vld1_u8(data_ptr)); } @@ -107,7 +107,7 @@ struct op_int16x8_t { // This really selects vshlq_n_s16, but requires a longer implementation to // convert the shift argument back to a constant. In some compiles are macros // requiring constant args. - inline op_int16x8_t operator<<=(int32 left_shift) { + inline op_int16x8_t operator<<=(int32_t left_shift) { switch (left_shift) { case 1: val = vshlq_n_s16(val, 1); @@ -127,7 +127,7 @@ struct op_int16x8_t { // This really selects vshrq_n_u16, but requires a longer implementation to // convert the shift argument back to a constant. In some compiles are macros // requiring constant args. - inline op_int16x8_t operator>>=(int32 right_shift) { + inline op_int16x8_t operator>>=(int32_t right_shift) { switch (right_shift) { case 1: val = vshrq_n_s16(val, 1); @@ -154,11 +154,11 @@ struct op_int16x8_t { lhs -= rhs; return lhs; } - friend inline op_int16x8_t operator<<(op_int16x8_t lhs, int32 left_shift) { + friend inline op_int16x8_t operator<<(op_int16x8_t lhs, int32_t left_shift) { lhs <<= left_shift; return lhs; } - friend inline op_int16x8_t operator>>(op_int16x8_t lhs, int32 right_shift) { + friend inline op_int16x8_t operator>>(op_int16x8_t lhs, int32_t right_shift) { lhs >>= right_shift; return lhs; } @@ -191,7 +191,7 @@ struct op_uint16x8_t { // This really selects vshlq_n_s16, but requires a longer implementation to // convert the shift argument back to a constant. In some compiles are macros // requiring constant args. - inline op_uint16x8_t operator<<=(int32 left_shift) { + inline op_uint16x8_t operator<<=(int32_t left_shift) { switch (left_shift) { case 1: val = vshlq_n_u16(val, 1); @@ -211,7 +211,7 @@ struct op_uint16x8_t { // This really selects vshrq_n_u16, but requires a longer implementation to // convert the shift argument back to a constant. In some compiles are macros // requiring constant args. - inline op_uint16x8_t operator>>=(int32 right_shift) { + inline op_uint16x8_t operator>>=(int32_t right_shift) { switch (right_shift) { case 1: val = vshrq_n_u16(val, 1); @@ -238,11 +238,13 @@ struct op_uint16x8_t { lhs -= rhs; return lhs; } - friend inline op_uint16x8_t operator<<(op_uint16x8_t lhs, int32 left_shift) { + friend inline op_uint16x8_t operator<<(op_uint16x8_t lhs, + int32_t left_shift) { lhs <<= left_shift; return lhs; } - friend inline op_uint16x8_t operator>>(op_uint16x8_t lhs, int32 right_shift) { + friend inline op_uint16x8_t operator>>(op_uint16x8_t lhs, + int32_t right_shift) { lhs >>= right_shift; return lhs; } @@ -262,20 +264,20 @@ inline op_uint16x8_t VReinterpretQU16S16(const op_int16x8_t& other) { // // This optimization is for the half_pixel_centers == true version, for uint8. // There are versions for NEON and non-NEON compilation. -inline void ResizeBilinear888Uint8(int32 batches, int32 input_height, - int32 input_width, int32 depth, - const uint8* input_data, - uint8* output_data) { +inline void ResizeBilinear888Uint8(int32_t batches, int32_t input_height, + int32_t input_width, int32_t depth, + const uint8_t* input_data, + uint8_t* output_data) { TFLITE_DCHECK_GE(input_height, 1); TFLITE_DCHECK_GE(input_width, 1); TFLITE_DCHECK_EQ(depth % 8, 0); - const int32 input_row_stride = input_width * depth; - const int32 output_row_stride = input_row_stride * 8; + const int32_t input_row_stride = input_width * depth; + const int32_t output_row_stride = input_row_stride * 8; for (int b = 0; b < batches; ++b) { - const uint8* input_base_ptr = + const uint8_t* input_base_ptr = input_data + b * input_row_stride * input_height; - uint8* output_base_ptr = + uint8_t* output_base_ptr = output_data + b * output_row_stride * input_height * 8; #ifdef USE_NEON @@ -361,24 +363,24 @@ inline void ResizeBilinear888Uint8(int32 batches, int32 input_height, } // Fill out remainder of top margin. std::memcpy(output_base_ptr + output_row_stride, output_base_ptr, - output_row_stride * sizeof(uint8)); + output_row_stride * sizeof(uint8_t)); std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr, - output_row_stride * sizeof(uint8)); + output_row_stride * sizeof(uint8_t)); std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr, - output_row_stride * sizeof(uint8)); + output_row_stride * sizeof(uint8_t)); output_base_ptr += output_row_stride * 4; // Main rows. for (int k = 0; k < (input_height - 1); ++k) { for (int c_block = 0; c_block < depth; c_block += 8) { - uint8* output_base_ptr_0 = output_base_ptr; - uint8* output_base_ptr_1; - uint8* output_base_ptr_2; - uint8* output_base_ptr_3; - uint8* output_base_ptr_4; - uint8* output_base_ptr_5; - uint8* output_base_ptr_6; - uint8* output_base_ptr_7; + uint8_t* output_base_ptr_0 = output_base_ptr; + uint8_t* output_base_ptr_1; + uint8_t* output_base_ptr_2; + uint8_t* output_base_ptr_3; + uint8_t* output_base_ptr_4; + uint8_t* output_base_ptr_5; + uint8_t* output_base_ptr_6; + uint8_t* output_base_ptr_7; op_uint16x8_t accum_0_c_v; op_uint16x8_t accum_1_c_v; @@ -774,11 +776,11 @@ inline void ResizeBilinear888Uint8(int32 batches, int32 input_height, } // Fill out remainder of bottom margin. std::memcpy(output_base_ptr + output_row_stride, output_base_ptr, - output_row_stride * sizeof(uint8)); + output_row_stride * sizeof(uint8_t)); std::memcpy(output_base_ptr + output_row_stride * 2, output_base_ptr, - output_row_stride * sizeof(uint8)); + output_row_stride * sizeof(uint8_t)); std::memcpy(output_base_ptr + output_row_stride * 3, output_base_ptr, - output_row_stride * sizeof(uint8)); + output_row_stride * sizeof(uint8_t)); #else // USE_NEON for (int c_block = 0; c_block < depth; c_block += 8) { @@ -1227,7 +1229,7 @@ inline void ResizeBilinear888Uint8(int32 batches, int32 input_height, } // namespace resize_bilinear #ifdef USE_NEON -inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, +inline void ResizeBilinearKernel(const float* input_ptr, int32_t depth, float scale, float* output_ptr) { int ic = 0; // Handle 32 input channels at a time. @@ -1323,21 +1325,22 @@ inline void ResizeBilinearKernel(const float* input_ptr, int32 depth, } #endif -inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, - int32 x, int32 y, int32 depth, int32 batch, +inline void ResizeBilinearKernel2x2(int32_t x0, int32_t x1, int32_t y0, + int32_t y1, int32_t x, int32_t y, + int32_t depth, int32_t batch, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data) { TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4); TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4); - const int32 input_width = input_shape.Dims(2); - const int32 output_width = output_shape.Dims(2); + const int32_t input_width = input_shape.Dims(2); + const int32_t output_width = output_shape.Dims(2); - const int32 input_x_offset = (x1 - x0) * depth; - const int32 input_y_offset = (y1 - y0) * depth * input_width; - const int32 output_x_offset = depth; - const int32 output_y_offset = depth * output_width; + const int32_t input_x_offset = (x1 - x0) * depth; + const int32_t input_y_offset = (y1 - y0) * depth * input_width; + const int32_t output_x_offset = depth; + const int32_t output_y_offset = depth * output_width; #ifdef USE_NEON TFLITE_DCHECK(x1 >= x0); @@ -1440,7 +1443,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, } // Handle one input channel at a time. for (; ic < depth; ic++) { - const int32 input_offset = Offset(input_shape, batch, y0, x0, ic); + const int32_t input_offset = Offset(input_shape, batch, y0, x0, ic); float x0y0 = input_data[input_offset]; float x1y0 = input_data[input_offset + input_x_offset]; @@ -1448,7 +1451,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, float x1y1 = input_data[input_offset + input_x_offset + input_y_offset]; // Top left corner. - const int32 output_offset = Offset(output_shape, batch, y, x, ic); + const int32_t output_offset = Offset(output_shape, batch, y, x, ic); output_data[output_offset] = x0y0; // Top right corner. @@ -1489,9 +1492,9 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1, #endif } -inline void ResizeBilinear2x2(int32 batches, int32 input_height, - int32 input_width, int32 depth, - int32 output_height, int32 output_width, +inline void ResizeBilinear2x2(int32_t batches, int32_t input_height, + int32_t input_width, int32_t depth, + int32_t output_height, int32_t output_width, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, @@ -1499,8 +1502,8 @@ inline void ResizeBilinear2x2(int32 batches, int32 input_height, for (int b = 0; b < batches; b++) { for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) { for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) { - int32 x1 = std::min(x0 + 1, input_width - 1); - int32 y1 = std::min(y0 + 1, input_height - 1); + int32_t x1 = std::min(x0 + 1, input_width - 1); + int32_t y1 = std::min(y0 + 1, input_height - 1); ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape, input_data, output_shape, output_data); } @@ -1509,32 +1512,32 @@ inline void ResizeBilinear2x2(int32 batches, int32 input_height, } inline void ResizeBilinearGeneric( - int32 batches, int32 input_height, int32 input_width, int32 depth, - int32 output_height, int32 output_width, float height_scale, + int32_t batches, int32_t input_height, int32_t input_width, int32_t depth, + int32_t output_height, int32_t output_width, float height_scale, float width_scale, const RuntimeShape& input_shape, const float* input_data, const RuntimeShape& output_shape, float* output_data, const bool half_pixel_centers) { memset(output_data, 0, batches * output_height * output_width * depth * sizeof(float)); - int32 output_offset = 0; + int32_t output_offset = 0; for (int b = 0; b < batches; ++b) { for (int y = 0; y < output_height; ++y) { float input_y; - int32 y0, y1; + int32_t y0, y1; reference_ops::ComputeInterpolationValues( y, height_scale, half_pixel_centers, input_height, &input_y, &y0, &y1); for (int x = 0; x < output_width; ++x) { float input_x; - int32 x0, x1; + int32_t x0, x1; reference_ops::ComputeInterpolationValues( x, width_scale, half_pixel_centers, input_width, &input_x, &x0, &x1); float* output_ptr = &output_data[output_offset]; // Run kernel on the 4 corners of the bilinear resize algorithm. - int32 input_offset = Offset(input_shape, b, y0, x0, 0); + int32_t input_offset = Offset(input_shape, b, y0, x0, 0); float scale = (1 - (input_y - y0)) * (1 - (input_x - x0)); const float* input_ptr = &input_data[input_offset]; ResizeBilinearKernel(input_ptr, depth, scale, output_ptr); @@ -1562,8 +1565,8 @@ inline void ResizeBilinearGeneric( template inline void ResizeBilinearGenericSmallChannel( - int32 batches, int32 input_height, int32 input_width, int32 depth, - int32 output_height, int32 output_width, float height_scale, + int32_t batches, int32_t input_height, int32_t input_width, int32_t depth, + int32_t output_height, int32_t output_width, float height_scale, float width_scale, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, T* output_data, const bool half_pixel_centers) { @@ -1573,21 +1576,21 @@ inline void ResizeBilinearGenericSmallChannel( for (int b = 0; b < batches; ++b) { for (int y = 0; y < output_height; ++y) { float input_y; - int32 y0, y1; + int32_t y0, y1; reference_ops::ComputeInterpolationValues( y, height_scale, half_pixel_centers, input_height, &input_y, &y0, &y1); for (int x = 0; x < output_width; ++x) { float input_x; - int32 x0, x1; + int32_t x0, x1; reference_ops::ComputeInterpolationValues( x, width_scale, half_pixel_centers, input_width, &input_x, &x0, &x1); - int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0), - Offset(input_shape, b, y0, x1, 0), - Offset(input_shape, b, y1, x0, 0), - Offset(input_shape, b, y1, x1, 0)}; + int32_t input_offset[4] = {Offset(input_shape, b, y0, x0, 0), + Offset(input_shape, b, y0, x1, 0), + Offset(input_shape, b, y1, x0, 0), + Offset(input_shape, b, y1, x1, 0)}; float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)), (1 - (input_y - y0)) * (input_x - x0), (input_y - y0) * (1 - (input_x - x0)), @@ -1610,7 +1613,7 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, const RuntimeShape& unextended_input_shape, const float* input_data, const RuntimeShape& output_size_shape, - const int32* output_size_data, + const int32_t* output_size_data, const RuntimeShape& unextended_output_shape, float* output_data) { ruy::profiler::ScopeLabel label("ResizeBilinear"); @@ -1623,14 +1626,14 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, const RuntimeShape output_shape = RuntimeShape::ExtendedShape(4, unextended_output_shape); - int32 batches = MatchingDim(input_shape, 0, output_shape, 0); - int32 input_height = input_shape.Dims(1); - int32 input_width = input_shape.Dims(2); - int32 depth = MatchingDim(input_shape, 3, output_shape, 3); + int32_t batches = MatchingDim(input_shape, 0, output_shape, 0); + int32_t input_height = input_shape.Dims(1); + int32_t input_width = input_shape.Dims(2); + int32_t depth = MatchingDim(input_shape, 3, output_shape, 3); TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2); - int32 output_height = output_size_data[0]; - int32 output_width = output_size_data[1]; + int32_t output_height = output_size_data[0]; + int32_t output_width = output_size_data[1]; // Specialize for 2x2 upsample. if (!op_params.align_corners && !op_params.half_pixel_centers && @@ -1659,11 +1662,11 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, // or int16 arithmetic. inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, const RuntimeShape& unextended_input_shape, - const uint8* input_data, + const uint8_t* input_data, const RuntimeShape& output_size_shape, - const int32* output_size_data, + const int32_t* output_size_data, const RuntimeShape& unextended_output_shape, - uint8* output_data) { + uint8_t* output_data) { ruy::profiler::ScopeLabel label("ResizeBilinearUint8"); // If half_pixel_centers is True, align_corners must be False. TFLITE_DCHECK(!op_params.half_pixel_centers || !op_params.align_corners); @@ -1674,18 +1677,18 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, const RuntimeShape output_shape = RuntimeShape::ExtendedShape(4, unextended_output_shape); - int32 batches = MatchingDim(input_shape, 0, output_shape, 0); - int32 input_height = input_shape.Dims(1); - int32 input_width = input_shape.Dims(2); - int32 depth = MatchingDim(input_shape, 3, output_shape, 3); + int32_t batches = MatchingDim(input_shape, 0, output_shape, 0); + int32_t input_height = input_shape.Dims(1); + int32_t input_width = input_shape.Dims(2); + int32_t depth = MatchingDim(input_shape, 3, output_shape, 3); TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2); - int32 output_height = output_size_data[0]; - int32 output_width = output_size_data[1]; + int32_t output_height = output_size_data[0]; + int32_t output_width = output_size_data[1]; if (!op_params.align_corners && op_params.half_pixel_centers && ((depth % 8) == 0)) { - const int32 scale = output_height / input_height; + const int32_t scale = output_height / input_height; // Restricting the minimum output dimensions may not be necessary, but // ensures that kernels can use unrolling with minimal code size. if ((output_height >= 8) && (output_width >= 8) && @@ -1709,7 +1712,7 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, ? (static_cast(input_width - 1) / (output_width - 1)) : (static_cast(input_width) / output_width); - ResizeBilinearGenericSmallChannel( + ResizeBilinearGenericSmallChannel( batches, input_height, input_width, depth, output_height, output_width, height_scale, width_scale, input_shape, input_data, output_shape, output_data, op_params.half_pixel_centers); @@ -1718,11 +1721,11 @@ inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, // TODO(b/180609127) Create optimized int8 version from uint8. Call from here. inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params, const RuntimeShape& unextended_input_shape, - const int8* input_data, + const int8_t* input_data, const RuntimeShape& unextended_output_size_shape, - const int32* output_size_data, + const int32_t* output_size_data, const RuntimeShape& unextended_output_shape, - int8* output_data) { + int8_t* output_data) { reference_ops::ResizeBilinearInteger(op_params, unextended_input_shape, input_data, unextended_output_size_shape, output_size_data, diff --git a/tensorflow/lite/kernels/internal/portable_tensor_utils.h b/tensorflow/lite/kernels/internal/portable_tensor_utils.h index d37fe6e4c89836..ed59fd012ca637 100644 --- a/tensorflow/lite/kernels/internal/portable_tensor_utils.h +++ b/tensorflow/lite/kernels/internal/portable_tensor_utils.h @@ -317,7 +317,7 @@ void ApplyLayerNormFloat(const int16_t* input, void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input, int16_t* output); -// Same as above but the internal calcualtion is float. +// Same as above but the internal calculation is float. void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input, int16_t* output); @@ -333,7 +333,7 @@ void ApplySigmoidFloat(const int16_t* input, int32_t n_batch, int32_t n_input, void ApplyTanh(int32_t intger_bits, const int16_t* input, int32_t n_batch, int32_t n_input, int16_t* output); -// Apply Tanh to a quantized vector. Tbe internal calculation is in float. +// Apply Tanh to a quantized vector. The internal calculation is in float. // - Input has 2^(integer_bits) as scale. // - Output has Q0.15 as scale. void ApplyTanhFloat(const int16_t* input, int32_t n_batch, int32_t n_input, diff --git a/tensorflow/lite/kernels/internal/quantization_util.h b/tensorflow/lite/kernels/internal/quantization_util.h index 0ee914b0689ed1..eb4e84013e1144 100644 --- a/tensorflow/lite/kernels/internal/quantization_util.h +++ b/tensorflow/lite/kernels/internal/quantization_util.h @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/lite/kernels/internal/compatibility.h" -#include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/types.h" namespace tflite { @@ -103,6 +102,7 @@ QuantizationParams ChooseQuantizationParams(double rmin, double rmax) { return ChooseQuantizationParams(rmin, rmax, false); } +// LINT.IfChange // Converts a floating-point number to an integer. For all inputs x where // static_cast(x) is legal according to the C++ standard, the result // is identical to that cast (i.e. the result is x with its fractional part @@ -167,6 +167,7 @@ IntOut SafeCast(FloatIn x) { return x < 0 ? std::numeric_limits::min() : std::numeric_limits::max(); } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h) // Decompose a double multiplier into a Q0.31 int32 representation of its // significand, and shift representation of NEGATIVE its exponent --- diff --git a/tensorflow/lite/kernels/internal/quantization_util_test.cc b/tensorflow/lite/kernels/internal/quantization_util_test.cc index aec0b2ba54bfde..aa9c2741f3c679 100644 --- a/tensorflow/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/lite/kernels/internal/quantization_util_test.cc @@ -160,13 +160,13 @@ TEST(QuantizationUtilTest, SafeCast) { // 255 | 30.0 // 128 | 10.0 TEST(QuantizationUtilTest, ChooseQuantizationParams) { - QuantizationParams qp = ChooseQuantizationParams(-10.0, 30.0); + QuantizationParams qp = ChooseQuantizationParams(-10.0, 30.0); EXPECT_NEAR(qp.scale, 0.156863, 1e-5); EXPECT_EQ(qp.zero_point, 64); } TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) { - QuantizationParams qp = ChooseQuantizationParams(0.0, 30.0); + QuantizationParams qp = ChooseQuantizationParams(0.0, 30.0); EXPECT_NEAR(qp.scale, 0.117647, 1e-5); EXPECT_EQ(qp.zero_point, 0); } @@ -174,23 +174,23 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMinBoundary) { #if GTEST_HAS_DEATH_TEST TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroNotInRange) { // Assumption is that zero is within the range. - EXPECT_DEATH(ChooseQuantizationParams(10.0, 30.0), ""); + EXPECT_DEATH(ChooseQuantizationParams(10.0, 30.0), ""); } TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangePositive) { // Assumption is that zero is within the range. - EXPECT_DEATH(ChooseQuantizationParams(30.0, 30.0), ""); + EXPECT_DEATH(ChooseQuantizationParams(30.0, 30.0), ""); } #endif // GTEST_HAS_DEATH_TEST TEST(QuantizationUtilTest, ChooseQuantizationParamsEmptyRangeZero) { - QuantizationParams qp = ChooseQuantizationParams(0.0, 0.0); + QuantizationParams qp = ChooseQuantizationParams(0.0, 0.0); EXPECT_NEAR(qp.scale, 0.0, 1e-5); EXPECT_EQ(qp.zero_point, 0); } TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) { - QuantizationParams qp = ChooseQuantizationParams(-10.0, 0.0); + QuantizationParams qp = ChooseQuantizationParams(-10.0, 0.0); EXPECT_NEAR(qp.scale, 0.039216, 1e-5); EXPECT_EQ(qp.zero_point, 255); } @@ -330,7 +330,7 @@ TEST(QuantizationUtilTest, IntegerDoubleCompare) { #if GTEST_HAS_DEATH_TEST TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) { - EXPECT_DEATH(ChooseQuantizationParams(10.0, -30.0), ""); + EXPECT_DEATH(ChooseQuantizationParams(10.0, -30.0), ""); } TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOneExp) { @@ -533,12 +533,12 @@ TEST(QuantizationUtilTest, QuantizeMultiplierArray) { const std::vector weights = {-4, -2, -1, -0.5, -0.25, -0.125, 0, 0.125, 0.25, 0.5, 1, 2, 4}; const int size = weights.size(); - std::vector effective_scale_significand(size); + std::vector effective_scale_significand(size); std::vector effective_scale_shift(size); QuantizeMultiplierArray(weights.data(), size, effective_scale_significand.data(), effective_scale_shift.data()); - const std::vector expected_effective_scale_significand = { + const std::vector expected_effective_scale_significand = { -1073741824, // float scale = -4 -1073741824, // float scale = -2 -1073741824, // float scale = -1 diff --git a/tensorflow/lite/kernels/internal/reference/gather.h b/tensorflow/lite/kernels/internal/reference/gather.h index d95c072f9fe2bc..6a61a8d91cbae6 100644 --- a/tensorflow/lite/kernels/internal/reference/gather.h +++ b/tensorflow/lite/kernels/internal/reference/gather.h @@ -24,7 +24,7 @@ limitations under the License. namespace tflite { namespace reference_ops { -template +template inline TfLiteStatus Gather(const tflite::GatherParams& op_params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& coords_shape, diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h b/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h index ae846faf251f91..3c2b85a5956e78 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h @@ -26,12 +26,12 @@ template inline void Dequantize(const tflite::DequantizationParams& op_params, const RuntimeShape& input_shape, const T* input_data, const RuntimeShape& output_shape, float* output_data) { - const int32 zero_point = op_params.zero_point; + const int32_t zero_point = op_params.zero_point; const double scale = op_params.scale; const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { - const int32 val = static_cast(input_data[i]); + const int32_t val = static_cast(input_data[i]); const float result = static_cast(scale * (val - zero_point)); output_data[i] = result; } diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h b/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h index 2b56b4fc9194b1..6d0b278b997be9 100644 --- a/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h +++ b/tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h @@ -25,7 +25,7 @@ namespace reference_integer_ops { inline void LogSoftmax(int32_t input_multiplier, int32_t input_shift, int32_t reverse_multiplier, int32_t reverse_shift, int32_t diff_min, int32_t outer_size, int32_t depth, - const int8* input_data, int8* output_data) { + const int8_t* input_data, int8_t* output_data) { static constexpr int8_t kMinInt8 = std::numeric_limits::min(); static constexpr int8_t kMaxInt8 = std::numeric_limits::max(); static constexpr int32_t kMinInt32 = std::numeric_limits::min(); @@ -39,11 +39,11 @@ inline void LogSoftmax(int32_t input_multiplier, int32_t input_shift, static constexpr int kInputIntegerBits = 5; static constexpr int kAccumulationIntegerBits = 12; static constexpr int kOutputIntegerBits = 4; - using F5 = gemmlowp::FixedPoint; - using F12 = gemmlowp::FixedPoint; + using F5 = gemmlowp::FixedPoint; + using F12 = gemmlowp::FixedPoint; for (int outer_index = 0; outer_index < outer_size; ++outer_index) { - int8 max_in_row = kMinInt8; + int8_t max_in_row = kMinInt8; for (int inner_index = 0; inner_index < depth; ++inner_index) { max_in_row = std::max(max_in_row, input_data[outer_index * depth + inner_index]); diff --git a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h index 4ae6af7d077b99..c9d6ed965c7aca 100644 --- a/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h +++ b/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h @@ -110,16 +110,16 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, depth_multiplier, output_data, output_dims); } -inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void DepthwiseConv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int dilation_width_factor, int dilation_height_factor, int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { tflite::DepthwiseParams op_params; // Padding type is ignored, but still set. @@ -145,15 +145,15 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, bias_data, DimsToShape(output_dims), output_data); } -inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void DepthwiseConv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride_width, @@ -165,15 +165,15 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // Legacy, for compatibility with old checked-in code. template -void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +void DepthwiseConv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int pad_width, - int pad_height, int depth_multiplier, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + int pad_height, int depth_multiplier, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { if (Ac == FusedActivationFunctionType::kNone) { TFLITE_DCHECK_EQ(output_activation_min, 0); TFLITE_DCHECK_EQ(output_activation_max, 255); @@ -188,15 +188,15 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, // Legacy, for compatibility with old checked-in code. template -void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, int stride, - int pad_width, int pad_height, int depth_multiplier, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { +void DepthwiseConv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, + int stride, int pad_width, int pad_height, + int depth_multiplier, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride, stride, pad_width, pad_height, depth_multiplier, @@ -276,16 +276,17 @@ void Conv(const float* input_data, const Dims<4>& input_dims, output_dims, im2col_data, im2col_dims); } -inline void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void Conv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int dilation_width_factor, int dilation_height_factor, int pad_width, int pad_height, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims, - uint8* im2col_data, const Dims<4>& im2col_dims, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, + const Dims<4>& output_dims, uint8_t* im2col_data, + const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) { tflite::ConvParams op_params; // Padding type is ignored, but still set. @@ -310,16 +311,16 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context); } -inline void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void Conv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int pad_width, - int pad_height, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims, uint8* im2col_data, - const Dims<4>& im2col_dims, + int pad_height, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims, + uint8_t* im2col_data, const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) { Conv(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1, @@ -330,16 +331,16 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -inline void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, +inline void Conv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride_width, int stride_height, int pad_width, - int pad_height, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims, uint8* im2col_data, - const Dims<4>& im2col_dims, + int pad_height, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims, + uint8_t* im2col_data, const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) { static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -359,15 +360,16 @@ inline void Conv(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -void Conv(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, int stride, - int pad_width, int pad_height, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data, - const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) { +void Conv(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, int stride, + int pad_width, int pad_height, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims, + uint8_t* im2col_data, const Dims<4>& im2col_dims, + gemmlowp::GemmContext* gemmlowp_context) { Conv(input_data, input_dims, input_offset, filter_data, filter_dims, filter_offset, bias_data, bias_dims, stride, stride, pad_width, pad_height, output_offset, output_multiplier, output_shift, @@ -442,31 +444,31 @@ void FullyConnected(const float* input_data, const Dims<4>& input_dims, inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - uint8* output_data, gemmlowp::GemmContext*) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + uint8_t* output_data, gemmlowp::GemmContext*) { FullyConnected(params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data); } inline void FullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& filter_shape, - const uint8* filter_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - int16* output_data, gemmlowp::GemmContext*) { + const uint8_t* input_data, const RuntimeShape& filter_shape, + const uint8_t* filter_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int16_t* output_data, gemmlowp::GemmContext*) { FullyConnected(params, input_shape, input_data, filter_shape, filter_data, bias_shape, bias_data, output_shape, output_data); } -inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void FullyConnected(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims, gemmlowp::GemmContext* gemmlowp_context) { tflite::FullyConnectedParams op_params; @@ -485,13 +487,13 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, gemmlowp_context); } -inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, int16* output_data, +inline void FullyConnected(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, int16_t* output_data, const Dims<4>& output_dims, gemmlowp::GemmContext* gemmlowp_context) { tflite::FullyConnectedParams op_params; @@ -512,10 +514,10 @@ inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, inline void ShuffledFullyConnected( const FullyConnectedParams& params, const RuntimeShape& input_shape, - const uint8* input_data, const RuntimeShape& weights_shape, - const uint8* shuffled_weights_data, const RuntimeShape& bias_shape, - const int32* bias_data, const RuntimeShape& output_shape, - int16* output_data, uint8* shuffled_input_workspace_data, + const uint8_t* input_data, const RuntimeShape& weights_shape, + const uint8_t* shuffled_weights_data, const RuntimeShape& bias_shape, + const int32_t* bias_data, const RuntimeShape& output_shape, + int16_t* output_data, uint8_t* shuffled_input_workspace_data, gemmlowp::GemmContext*) { ShuffledFullyConnected(params, input_shape, input_data, weights_shape, shuffled_weights_data, bias_shape, bias_data, @@ -524,12 +526,12 @@ inline void ShuffledFullyConnected( } inline void ShuffledFullyConnected( - const uint8* input_data, const Dims<4>& input_dims, - const uint8* shuffled_weights_data, const Dims<4>& weights_dims, - const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - int16* output_data, const Dims<4>& output_dims, - uint8* shuffled_input_workspace_data, + const uint8_t* input_data, const Dims<4>& input_dims, + const uint8_t* shuffled_weights_data, const Dims<4>& weights_dims, + const int32_t* bias_data, const Dims<4>& bias_dims, + int32_t output_multiplier, int output_shift, int32_t output_activation_min, + int32_t output_activation_max, int16_t* output_data, + const Dims<4>& output_dims, uint8_t* shuffled_input_workspace_data, gemmlowp::GemmContext* gemmlowp_context) { tflite::FullyConnectedParams op_params; op_params.output_multiplier = output_multiplier; @@ -547,13 +549,13 @@ inline void ShuffledFullyConnected( // legacy, for compatibility with old checked-in code template -void FullyConnected(const uint8* input_data, const Dims<4>& input_dims, - int32 input_offset, const uint8* filter_data, - const Dims<4>& filter_dims, int32 filter_offset, - const int32* bias_data, const Dims<4>& bias_dims, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +void FullyConnected(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_offset, const uint8_t* filter_data, + const Dims<4>& filter_dims, int32_t filter_offset, + const int32_t* bias_data, const Dims<4>& bias_dims, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims, gemmlowp::GemmContext* gemmlowp_context) { static_assert(Ac == FusedActivationFunctionType::kNone || @@ -596,17 +598,18 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, } template -void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, - const uint8* prev_activ_data_uint8, - const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8, - const Dims<4>& weights_dims, const int32* bias_data_int32, - const Dims<4>& bias_dims, const int16* prev_state_data_int16, - const Dims<4>& prev_state_dims, int16* output_state_data_int16, - const Dims<4>& output_state_dims, uint8* output_activ_data_uint8, - const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, - const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, - const Dims<4>& activ_temp_dims, int32 weights_zero_point, - int32 accum_multiplier, int accum_shift, +void LstmCell(const uint8_t* input_data_uint8, const Dims<4>& input_dims, + const uint8_t* prev_activ_data_uint8, + const Dims<4>& prev_activ_dims, const uint8_t* weights_data_uint8, + const Dims<4>& weights_dims, const int32_t* bias_data_int32, + const Dims<4>& bias_dims, const int16_t* prev_state_data_int16, + const Dims<4>& prev_state_dims, int16_t* output_state_data_int16, + const Dims<4>& output_state_dims, + uint8_t* output_activ_data_uint8, + const Dims<4>& output_activ_dims, uint8_t* concat_temp_data_uint8, + const Dims<4>& concat_temp_dims, int16_t* activ_temp_data_int16, + const Dims<4>& activ_temp_dims, int32_t weights_zero_point, + int32_t accum_multiplier, int accum_shift, gemmlowp::GemmContext* gemmlowp_context) { tflite::LstmCellParams op_params; op_params.weights_zero_point = weights_zero_point; @@ -671,12 +674,12 @@ inline void Concatenation(int concat_dim, const Scalar* const* input_data, DimsToShape(output_dims), output_data); } -inline void Concatenation(int concat_dim, const uint8* const* input_data, +inline void Concatenation(int concat_dim, const uint8_t* const* input_data, const Dims<4>* const* input_dims, - const int32* input_zeropoint, + const int32_t* input_zeropoint, const float* input_scale, int inputs_count, - uint8* output_data, const Dims<4>& output_dims, - const int32 output_zeropoint, + uint8_t* output_data, const Dims<4>& output_dims, + const int32_t output_zeropoint, const float output_scale) { std::vector input_shapes(inputs_count); std::vector input_shapes_indirect(inputs_count); @@ -759,10 +762,10 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape, Softmax(params, input_shape, input_data, output_shape, output_data); } -inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape, - int32 input_beta_multiplier, int32 input_beta_left_shift, - int diff_min, uint8* output_data, - const RuntimeShape& output_shape) { +inline void Softmax(const uint8_t* input_data, const RuntimeShape& input_shape, + int32_t input_beta_multiplier, + int32_t input_beta_left_shift, int diff_min, + uint8_t* output_data, const RuntimeShape& output_shape) { SoftmaxParams params; params.input_multiplier = input_beta_multiplier; params.input_left_shift = input_beta_left_shift; @@ -777,11 +780,12 @@ inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape, LogSoftmax(params, input_shape, input_data, output_shape, output_data); } -inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, - int32 input_multiplier, int32 input_left_shift, - int32 reverse_scaling_divisor, - int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const RuntimeShape& output_shape) { +inline void LogSoftmax(const uint8_t* input_data, + const RuntimeShape& input_shape, + int32_t input_multiplier, int32_t input_left_shift, + int32_t reverse_scaling_divisor, + int32_t reverse_scaling_right_shift, int diff_min, + uint8_t* output_data, const RuntimeShape& output_shape) { SoftmaxParams params; params.input_multiplier = input_multiplier; params.input_left_shift = input_left_shift; @@ -792,50 +796,50 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape, } inline void Logistic(const LogisticParams& params, - const RuntimeShape& input_shape, const uint8* input_data, - const RuntimeShape& output_shape, uint8* output_data) { - const int32 input_zero_point = params.input_zero_point; - const int32 input_range_radius = params.input_range_radius; - const int32 input_multiplier = params.input_multiplier; + const RuntimeShape& input_shape, const uint8_t* input_data, + const RuntimeShape& output_shape, uint8_t* output_data) { + const int32_t input_zero_point = params.input_zero_point; + const int32_t input_range_radius = params.input_range_radius; + const int32_t input_multiplier = params.input_multiplier; const int input_left_shift = params.input_left_shift; const int flat_size = MatchingFlatSize(input_shape, output_shape); for (int i = 0; i < flat_size; i++) { - const uint8 input_val_u8 = input_data[i]; - const int32 input_val_centered = - static_cast(input_val_u8) - input_zero_point; - uint8 output_val; + const uint8_t input_val_u8 = input_data[i]; + const int32_t input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8_t output_val; if (input_val_centered <= -input_range_radius) { output_val = 0; } else if (input_val_centered >= input_range_radius) { output_val = 255; } else { - const int32 input_val_rescaled = + const int32_t input_val_rescaled = MultiplyByQuantizedMultiplierGreaterThanOne( input_val_centered, input_multiplier, input_left_shift); - using FixedPoint4 = gemmlowp::FixedPoint; - using FixedPoint0 = gemmlowp::FixedPoint; + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4); // Convert from Q0.31 to Q23.8. using gemmlowp::RoundingDivideByPOT; - int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); + int32_t output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23); if (output_val_s32 == 256) { output_val_s32 = 255; } // Reinterpret as U0.8. TFLITE_DCHECK_GE(output_val_s32, 0); TFLITE_DCHECK_LE(output_val_s32, 255); - output_val = static_cast(output_val_s32); + output_val = static_cast(output_val_s32); } output_data[i] = output_val; } } -inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, - int32 input_zero_point, int32 input_range_radius, - int32 input_multiplier, int input_left_shift, - uint8* output_data, const RuntimeShape& output_shape) { +inline void Logistic(const uint8_t* input_data, const RuntimeShape& input_shape, + int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int input_left_shift, + uint8_t* output_data, const RuntimeShape& output_shape) { LogisticParams params; params.input_zero_point = input_zero_point; params.input_range_radius = input_range_radius; @@ -844,17 +848,17 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape, Logistic(params, input_shape, input_data, output_shape, output_data); } -inline void Logistic(const RuntimeShape& input_shape, const int16* input_data, - const RuntimeShape& output_shape, int16* output_data) { +inline void Logistic(const RuntimeShape& input_shape, const int16_t* input_data, + const RuntimeShape& output_shape, int16_t* output_data) { LogisticParams params; // No params currently needed by int16 Logistic. Logistic(params, input_shape, input_data, output_shape, output_data); } -inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, - int32 input_zero_point, int32 input_range_radius, - int32 input_multiplier, int input_left_shift, - uint8* output_data, const RuntimeShape& output_shape) { +inline void Tanh(const uint8_t* input_data, const RuntimeShape& input_shape, + int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int input_left_shift, + uint8_t* output_data, const RuntimeShape& output_shape) { TanhParams params; params.input_zero_point = input_zero_point; params.input_range_radius = input_range_radius; @@ -863,16 +867,16 @@ inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape, Tanh(params, input_shape, input_data, output_shape, output_data); } -inline void Tanh(const int16* input_data, const RuntimeShape& input_shape, - int input_left_shift, int16* output_data, +inline void Tanh(const int16_t* input_data, const RuntimeShape& input_shape, + int input_left_shift, int16_t* output_data, const RuntimeShape& output_shape) { TanhParams params; params.input_left_shift = input_left_shift; Tanh(params, input_shape, input_data, output_shape, output_data); } -inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, - int32 zero_point, double scale, float* output_data, +inline void Dequantize(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t zero_point, double scale, float* output_data, const Dims<4>& output_dims) { tflite::DequantizationParams op_params; op_params.zero_point = zero_point; @@ -896,7 +900,7 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims, template inline void Gather(const T* input_data, const Dims<4>& input_dims, - int input_rank, const int32* coords_data, + int input_rank, const int32_t* coords_data, const Dims<4>& coords_dims, T* output_data, const Dims<4>& output_dims) { tflite::GatherParams op_params; @@ -908,7 +912,7 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, output_data); } -inline uint32 LegacyReverseBits32(uint32 n) { +inline uint32_t LegacyReverseBits32(uint32_t n) { n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1); n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2); n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4); @@ -924,18 +928,18 @@ inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) { std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count); std::reverse(p->strides, p->strides + p->strides_count); - p->begin_mask = LegacyReverseBits32(static_cast(p->begin_mask)) >> + p->begin_mask = LegacyReverseBits32(static_cast(p->begin_mask)) >> (32 - p->start_indices_count); p->ellipsis_mask = - LegacyReverseBits32(static_cast(p->ellipsis_mask)) >> + LegacyReverseBits32(static_cast(p->ellipsis_mask)) >> (32 - p->start_indices_count); - p->end_mask = LegacyReverseBits32(static_cast(p->end_mask)) >> + p->end_mask = LegacyReverseBits32(static_cast(p->end_mask)) >> (32 - p->start_indices_count); p->new_axis_mask = - LegacyReverseBits32(static_cast(p->new_axis_mask)) >> + LegacyReverseBits32(static_cast(p->new_axis_mask)) >> (32 - p->start_indices_count); p->shrink_axis_mask = - LegacyReverseBits32(static_cast(p->shrink_axis_mask)) >> + LegacyReverseBits32(static_cast(p->shrink_axis_mask)) >> (32 - p->start_indices_count); } @@ -993,12 +997,12 @@ inline void Comparison(const T* input1_data, const Dims<4>& input1_dims, DimsToShape(output_dims), output_data); } -template F> +template F> inline void Comparison(int left_shift, const T* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, + const Dims<4>& input1_dims, int32_t input1_offset, + int32_t input1_multiplier, int input1_shift, const T* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, + int32_t input2_offset, int32_t input2_multiplier, int input2_shift, bool* output_data, const Dims<4>& output_dims) { tflite::ComparisonParams op_params; @@ -1031,14 +1035,13 @@ inline void BroadcastComparison(const T* input1_data, output_data); } -template F> -inline void BroadcastComparison(int left_shift, const T* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const T* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 input2_multiplier, int input2_shift, - bool* output_data, const Dims<4>& output_dims) { +template F> +inline void BroadcastComparison( + int left_shift, const T* input1_data, const Dims<4>& input1_dims, + int32_t input1_offset, int32_t input1_multiplier, int input1_shift, + const T* input2_data, const Dims<4>& input2_dims, int32_t input2_offset, + int32_t input2_multiplier, int input2_shift, bool* output_data, + const Dims<4>& output_dims) { ComparisonParams op_params; op_params.left_shift = left_shift; @@ -1174,9 +1177,9 @@ void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims, template void Pack(int dim, const Scalar* const* input_data, - const Dims<4>* const* input_dims, const int32* input_zeropoint, + const Dims<4>* const* input_dims, const int32_t* input_zeropoint, const float* input_scale, int inputs_count, Scalar* output_data, - const Dims<4>& output_dims, const int32 output_zeropoint, + const Dims<4>& output_dims, const int32_t output_zeropoint, const float output_scale) { std::vector input_shapes(inputs_count); std::vector input_shapes_indirect(inputs_count); @@ -1207,9 +1210,9 @@ void L2Normalization(const float* input_data, const RuntimeShape& input_shape, output_data); } -inline void L2Normalization(const uint8* input_data, +inline void L2Normalization(const uint8_t* input_data, const RuntimeShape& input_shape, - int32 input_zero_point, uint8* output_data, + int32_t input_zero_point, uint8_t* output_data, const RuntimeShape& output_shape) { tflite::L2NormalizationParams op_params; op_params.input_zero_point = input_zero_point; @@ -1225,9 +1228,9 @@ void L2Normalization(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims)); } -inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims, - int32 input_zero_point, uint8* output_data, - const Dims<4>& output_dims) { +inline void L2Normalization(const uint8_t* input_data, + const Dims<4>& input_dims, int32_t input_zero_point, + uint8_t* output_data, const Dims<4>& output_dims) { L2Normalization(input_data, DimsToShape(input_dims), input_zero_point, output_data, DimsToShape(output_dims)); } @@ -1250,9 +1253,9 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims, output_data); } -inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, - const RuntimeShape& input_shape, uint8* output_data, - const RuntimeShape& output_shape) { +inline void ReluX(uint8_t min_value, uint8_t max_value, + const uint8_t* input_data, const RuntimeShape& input_shape, + uint8_t* output_data, const RuntimeShape& output_shape) { tflite::ActivationParams params; params.quantized_activation_max = max_value; params.quantized_activation_min = min_value; @@ -1260,14 +1263,15 @@ inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data, } template -inline void Add(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, int input2_shift, - int32 output_offset, int32 output_multiplier, int output_shift, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { +inline void Add(int left_shift, const uint8_t* input1_data, + const Dims<4>& input1_dims, int32_t input1_offset, + int32_t input1_multiplier, int input1_shift, + const uint8_t* input2_data, const Dims<4>& input2_dims, + int32_t input2_offset, int32_t input2_multiplier, + int input2_shift, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { constexpr int kReverseShift = -1; static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -1299,30 +1303,30 @@ inline void Add(int left_shift, const uint8* input1_data, } template -void Add(const int32* input1_data, const Dims<4>& input1_dims, - const int32* input2_data, const Dims<4>& input2_dims, - int32* output_data, const Dims<4>& output_dims) { +void Add(const int32_t* input1_data, const Dims<4>& input1_dims, + const int32_t* input2_data, const Dims<4>& input2_dims, + int32_t* output_data, const Dims<4>& output_dims) { ruy::profiler::ScopeLabel label("Add/int32"); TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); tflite::ArithmeticParams op_params; - op_params.quantized_activation_min = std::numeric_limits::min(); - op_params.quantized_activation_max = std::numeric_limits::max(); + op_params.quantized_activation_min = std::numeric_limits::min(); + op_params.quantized_activation_max = std::numeric_limits::max(); Add(op_params, DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims), input2_data, DimsToShape(output_dims), output_data); } template -inline void BroadcastAdd(int left_shift, const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, - const uint8* input2_data, const Dims<4>& input2_dims, - int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, - int32 output_multiplier, int output_shift, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void BroadcastAdd(int left_shift, const uint8_t* input1_data, + const Dims<4>& input1_dims, int32_t input1_offset, + int32_t input1_multiplier, int input1_shift, + const uint8_t* input2_data, const Dims<4>& input2_dims, + int32_t input2_offset, int32_t input2_multiplier, + int input2_shift, int32_t output_offset, + int32_t output_multiplier, int output_shift, + int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { constexpr int kReverseShift = -1; static_assert(Ac == FusedActivationFunctionType::kNone || @@ -1385,12 +1389,13 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, template inline void BroadcastAddFivefold( int y0, int y1, int y2, int y3, int y4, int left_shift, - const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, - int32 input1_multiplier, int input1_shift, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier, - int input2_shift, int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + const uint8_t* input1_data, const Dims<4>& input1_dims, + int32_t input1_offset, int32_t input1_multiplier, int input1_shift, + const uint8_t* input2_data, const Dims<4>& input2_dims, + int32_t input2_offset, int32_t input2_multiplier, int input2_shift, + int32_t output_offset, int32_t output_multiplier, int output_shift, + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { constexpr int kReverseShift = -1; static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -1441,11 +1446,11 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, } template -inline void Add(const int16* input1_data, const Dims<4>& input1_dims, - int input1_shift, const int16* input2_data, +inline void Add(const int16_t* input1_data, const Dims<4>& input1_dims, + int input1_shift, const int16_t* input2_data, const Dims<4>& input2_dims, int input2_shift, - int16 output_activation_min, int16 output_activation_max, - int16* output_data, const Dims<4>& output_dims) { + int16_t output_activation_min, int16_t output_activation_max, + int16_t* output_data, const Dims<4>& output_dims) { static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu6 || @@ -1514,14 +1519,13 @@ inline bool AveragePool(const float* input_data, const Dims<4>& input_dims, // Transitional version that will be moved shortly to legacy_reference_ops, as // part of RuntimeShape revisions. -inline void BroadcastMul4DSlow(const uint8* input1_data, - const Dims<4>& input1_dims, int32 input1_offset, - const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, - const Dims<4>& output_dims) { +inline void BroadcastMul4DSlow( + const uint8_t* input1_data, const Dims<4>& input1_dims, + int32_t input1_offset, const uint8_t* input2_data, + const Dims<4>& input2_dims, int32_t input2_offset, int32_t output_offset, + int32_t output_multiplier, int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, + const Dims<4>& output_dims) { tflite::ArithmeticParams op_params; SetActivationParams(output_activation_min, output_activation_max, &op_params); op_params.input1_offset = input1_offset; @@ -1535,12 +1539,12 @@ inline void BroadcastMul4DSlow(const uint8* input1_data, DimsToShape(output_dims), output_data); } -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void BroadcastMul(const uint8_t* input1_data, const Dims<4>& input1_dims, + int32_t input1_offset, const uint8_t* input2_data, + const Dims<4>& input2_dims, int32_t input2_offset, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { BroadcastMul4DSlow( input1_data, input1_dims, input1_offset, input2_data, input2_dims, @@ -1553,12 +1557,12 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, // legacy, for compatibility with old checked-in code template -inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, - int32 input1_offset, const uint8* input2_data, - const Dims<4>& input2_dims, int32 input2_offset, - int32 output_offset, int32 output_multiplier, - int output_shift, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void BroadcastMul(const uint8_t* input1_data, const Dims<4>& input1_dims, + int32_t input1_offset, const uint8_t* input2_data, + const Dims<4>& input2_dims, int32_t input2_offset, + int32_t output_offset, int32_t output_multiplier, + int output_shift, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { BroadcastMul(input1_data, input1_dims, input1_offset, input2_data, input2_dims, input2_offset, output_offset, output_multiplier, @@ -1592,11 +1596,11 @@ bool AveragePool(const float* input_data, const Dims<4>& input_dims, int stride, output_dims); } -inline bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, +inline bool AveragePool(const uint8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, - int32 output_activation_max, uint8* output_data, + int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { tflite::PoolParams params; params.stride_height = stride_height; @@ -1613,11 +1617,11 @@ inline bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, +bool AveragePool(const uint8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || Ac == FusedActivationFunctionType::kRelu6 || @@ -1635,10 +1639,10 @@ bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride, - int pad_width, int pad_height, int filter_width, - int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +bool AveragePool(const uint8_t* input_data, const Dims<4>& input_dims, + int stride, int pad_width, int pad_height, int filter_width, + int filter_height, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { return AveragePool(input_data, input_dims, stride, stride, pad_width, pad_height, filter_width, filter_height, @@ -1686,11 +1690,12 @@ void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride, filter_width, filter_height, output_data, output_dims); } -inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +inline void MaxPool(const uint8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, + const Dims<4>& output_dims) { PoolParams params; params.stride_height = stride_height; params.stride_width = stride_width; @@ -1706,10 +1711,10 @@ inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, +void MaxPool(const uint8_t* input_data, const Dims<4>& input_dims, int stride_width, int stride_height, int pad_width, int pad_height, - int filter_width, int filter_height, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, + int filter_width, int filter_height, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { static_assert(Ac == FusedActivationFunctionType::kNone || Ac == FusedActivationFunctionType::kRelu || @@ -1727,10 +1732,10 @@ void MaxPool(const uint8* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code template -void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride, +void MaxPool(const uint8_t* input_data, const Dims<4>& input_dims, int stride, int pad_width, int pad_height, int filter_width, int filter_height, - int32 output_activation_min, int32 output_activation_max, - uint8* output_data, const Dims<4>& output_dims) { + int32_t output_activation_min, int32_t output_activation_max, + uint8_t* output_data, const Dims<4>& output_dims) { MaxPool(input_data, input_dims, stride, stride, pad_width, pad_height, filter_width, filter_height, output_activation_min, output_activation_max, output_data, output_dims); @@ -1783,10 +1788,10 @@ inline void Softmax(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims)); } -inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, - int32 input_beta_multiplier, int32 input_beta_left_shift, - int diff_min, uint8* output_data, - const Dims<4>& output_dims) { +inline void Softmax(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_beta_multiplier, + int32_t input_beta_left_shift, int diff_min, + uint8_t* output_data, const Dims<4>& output_dims) { Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier, input_beta_left_shift, diff_min, output_data, DimsToShape(output_dims)); @@ -1798,11 +1803,11 @@ inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, DimsToShape(output_dims)); } -inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims, - int32 input_multiplier, int32 input_left_shift, - int32 reverse_scaling_divisor, - int32 reverse_scaling_right_shift, int diff_min, - uint8* output_data, const Dims<4>& output_dims) { +inline void LogSoftmax(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_multiplier, int32_t input_left_shift, + int32_t reverse_scaling_divisor, + int32_t reverse_scaling_right_shift, int diff_min, + uint8_t* output_data, const Dims<4>& output_dims) { LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier, input_left_shift, reverse_scaling_divisor, reverse_scaling_right_shift, diff_min, output_data, @@ -1815,17 +1820,17 @@ inline void Logistic(const float* input_data, const Dims<4>& input_dims, output_data); } -inline void Logistic(const uint8* input_data, const Dims<4>& input_dims, - int32 input_zero_point, int32 input_range_radius, - int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { +inline void Logistic(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int input_left_shift, + uint8_t* output_data, const Dims<4>& output_dims) { Logistic(input_data, DimsToShape(input_dims), input_zero_point, input_range_radius, input_multiplier, input_left_shift, output_data, DimsToShape(output_dims)); } -inline void Logistic(const int16* input_data, const Dims<4>& input_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Logistic(const int16_t* input_data, const Dims<4>& input_dims, + int16_t* output_data, const Dims<4>& output_dims) { Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims), output_data); } @@ -1836,17 +1841,17 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, output_data); } -inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, - int32 input_zero_point, int32 input_range_radius, - int32 input_multiplier, int input_left_shift, - uint8* output_data, const Dims<4>& output_dims) { +inline void Tanh(const uint8_t* input_data, const Dims<4>& input_dims, + int32_t input_zero_point, int32_t input_range_radius, + int32_t input_multiplier, int input_left_shift, + uint8_t* output_data, const Dims<4>& output_dims) { Tanh(input_data, DimsToShape(input_dims), input_zero_point, input_range_radius, input_multiplier, input_left_shift, output_data, DimsToShape(output_dims)); } -inline void Tanh(const int16* input_data, const Dims<4>& input_dims, - int input_left_shift, int16* output_data, +inline void Tanh(const int16_t* input_data, const Dims<4>& input_dims, + int input_left_shift, int16_t* output_data, const Dims<4>& output_dims) { Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data, DimsToShape(output_dims)); @@ -1932,9 +1937,9 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, DimsToShape(output_dims), output_data); } -inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, - const int16* input2_data, const Dims<4>& input2_dims, - int16* output_data, const Dims<4>& output_dims) { +inline void Mul(const int16_t* input1_data, const Dims<4>& input1_dims, + const int16_t* input2_data, const Dims<4>& input2_dims, + int16_t* output_data, const Dims<4>& output_dims) { tflite::ArithmeticParams op_params; // No params in this version. @@ -1943,10 +1948,10 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, output_data); } -inline void Mul(const int16* input1_data, const Dims<4>& input1_dims, - const int16* input2_data, const Dims<4>& input2_dims, - int32 output_offset, int32 output_activation_min, - int32 output_activation_max, uint8* output_data, +inline void Mul(const int16_t* input1_data, const Dims<4>& input1_dims, + const int16_t* input2_data, const Dims<4>& input2_dims, + int32_t output_offset, int32_t output_activation_min, + int32_t output_activation_max, uint8_t* output_data, const Dims<4>& output_dims) { tflite::ArithmeticParams op_params; op_params.quantized_activation_min = output_activation_min; @@ -1988,7 +1993,7 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims, template inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, - const int32* output_size_data, + const int32_t* output_size_data, const Dims<4>& output_size_dims, T* output_data, const Dims<4>& output_dims, bool align_corners) { tflite::ResizeBilinearParams op_params; @@ -2001,7 +2006,7 @@ inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims, // legacy, for compatibility with old checked-in code inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, - const int32* output_size_data, + const int32_t* output_size_data, const Dims<4>& output_size_dims, float* output_data, const Dims<4>& output_dims) { ResizeBilinear(input_data, input_dims, output_size_data, @@ -2009,20 +2014,20 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, /*align_corners=*/false); } -inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims, - const int32* output_size_data, - const Dims<4>& output_size_dims, uint8* output_data, - const Dims<4>& output_dims) { - ResizeBilinear(input_data, input_dims, output_size_data, - output_size_dims, output_data, output_dims, - /*align_corners=*/false); +inline void ResizeBilinear(const uint8_t* input_data, const Dims<4>& input_dims, + const int32_t* output_size_data, + const Dims<4>& output_size_dims, + uint8_t* output_data, const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, + output_size_dims, output_data, output_dims, + /*align_corners=*/false); } template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, + const int32_t* block_shape_data, const Dims<4>& block_shape_dims, - const int32* paddings_data, + const int32_t* paddings_data, const Dims<4>& paddings_dims, T* output_data, const Dims<4>& output_dims, const int32_t pad_value) { @@ -2037,9 +2042,9 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, + const int32_t* block_shape_data, const Dims<4>& block_shape_dims, - const int32* paddings_data, + const int32_t* paddings_data, const Dims<4>& paddings_dims, T* output_data, const Dims<4>& output_dims) { tflite::SpaceToBatchParams op_params; @@ -2053,9 +2058,9 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, template inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims, - const int32* block_shape_data, + const int32_t* block_shape_data, const Dims<4>& block_shape_dims, - const int32* crops_data, const Dims<4>& crops_dims, + const int32_t* crops_data, const Dims<4>& crops_dims, T* output_data, const Dims<4>& output_dims) { BatchToSpaceND(DimsToShape(input_dims), input_data, DimsToShape(block_shape_dims), block_shape_data, diff --git a/tensorflow/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/lite/kernels/internal/resize_bilinear_test.cc index f65127d029fce3..ee60b084edfbd2 100644 --- a/tensorflow/lite/kernels/internal/resize_bilinear_test.cc +++ b/tensorflow/lite/kernels/internal/resize_bilinear_test.cc @@ -55,7 +55,7 @@ void TestOneResizeBilinear(const tflite::ResizeBilinearParams& op_params, FillRandom(&input_data, min_amplitude, max_amplitude); RuntimeShape output_size_dims({1, 1, 1, 2}); - std::vector output_size_data = {output_height, output_width}; + std::vector output_size_data = {output_height, output_width}; reference_ops::ResizeBilinear(op_params, input_dims_inference, input_data.data(), output_size_dims, @@ -66,7 +66,7 @@ void TestOneResizeBilinear(const tflite::ResizeBilinearParams& op_params, output_size_data.data(), output_dims_inference, output_data.data()); bool strict_match = false; - if (std::is_same::value && ((depth % 8) == 0) && + if (std::is_same::value && ((depth % 8) == 0) && ((input_width * 8) == output_width) && ((input_height * 8) == output_height)) { strict_match = true; @@ -111,9 +111,9 @@ TEST_P(ResizeBilinearImplTest, TestResizeBilinearUint8) { const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200); const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200); - TestOneResizeBilinear(op_params, batch, depth, input_width, - input_height, output_width, output_height, - 0.025); + TestOneResizeBilinear(op_params, batch, depth, input_width, + input_height, output_width, output_height, + 0.025); } } @@ -136,9 +136,9 @@ TEST_P(ResizeBilinearImplTest, TestResizeBilinearUint8_2x2) { // versions. error_threshold = 1e-3; } - TestOneResizeBilinear(op_params, batch, depth, input_width, - input_height, output_width, output_height, - error_threshold); + TestOneResizeBilinear(op_params, batch, depth, input_width, + input_height, output_width, output_height, + error_threshold); } } @@ -217,7 +217,7 @@ TEST(ResizeBilinear, TestResizeBilinearHalfPixelCentersFloat_3x3to2x2) { std::vector output_data(output_buffer_size, 3); RuntimeShape output_size_dims({1, 1, 1, 2}); - std::vector output_size_data = {2, 2}; + std::vector output_size_data = {2, 2}; tflite::ResizeBilinearParams op_params; op_params.align_corners = false; @@ -261,7 +261,7 @@ TEST(ResizeBilinear, TestResizeBilinearHalfPixelCentersFloat_2x2to4x4) { std::vector output_data(output_buffer_size, 3); RuntimeShape output_size_dims({1, 1, 1, 2}); - std::vector output_size_data = {4, 4}; + std::vector output_size_data = {4, 4}; tflite::ResizeBilinearParams op_params; op_params.align_corners = false; @@ -312,7 +312,7 @@ void TestResizeBilinearHalfPixelCenters_2x2to4x6() { std::vector output_data(output_buffer_size, 3); RuntimeShape output_size_dims({1, 1, 1, 2}); - std::vector output_size_data = {4, 6}; + std::vector output_size_data = {4, 6}; tflite::ResizeBilinearParams op_params; op_params.align_corners = false; @@ -394,9 +394,9 @@ TEST_P(ResizeBilinearImplX8ChannelTest, TestResizeBilinearX8ChannelUint8) { const int output_width = input_width * scale_factor; const int output_height = input_height * scale_factor; - TestOneResizeBilinear(op_params, batch, depth, input_width, - input_height, output_width, output_height, - 0.025); + TestOneResizeBilinear(op_params, batch, depth, input_width, + input_height, output_width, output_height, + 0.025); } } @@ -418,9 +418,9 @@ TEST_P(ResizeBilinearImplX8ChannelTest, TestResizeBilinearX8ChannelInt8) { const int output_width = input_width * scale_factor; const int output_height = input_height * scale_factor; - TestOneResizeBilinear(op_params, batch, depth, input_width, - input_height, output_width, output_height, - 0.025); + TestOneResizeBilinear(op_params, batch, depth, input_width, + input_height, output_width, output_height, + 0.025); } } diff --git a/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc b/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc index debeb36e48fb9e..31ff68cc3ec3c8 100644 --- a/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc +++ b/tensorflow/lite/kernels/internal/resize_nearest_neighbor_test.cc @@ -28,7 +28,7 @@ namespace { template void TestReferenceResizeNearestNeighbor( const RuntimeShape& input_shape, const std::vector& input_data, - const std::vector& output_size_data, + const std::vector& output_size_data, const RuntimeShape& output_shape, const std::vector& expected_output_data, bool align_corners = false, bool half_pixel_centers = false) { @@ -48,7 +48,7 @@ void TestReferenceResizeNearestNeighbor( TEST(ResizeNearestNeighborReference, Test2x2To1x1) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {1, 1}; + std::vector output_size_data = {1, 1}; RuntimeShape output_shape = {1, 1, 1, 1}; std::vector output_data = {1}; @@ -59,7 +59,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To1x1) { TEST(ResizeNearestNeighborReference, Test2x2To1x1_AlignCorners) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {1, 1}; + std::vector output_size_data = {1, 1}; RuntimeShape output_shape = {1, 1, 1, 1}; std::vector output_data = {1}; @@ -71,7 +71,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To1x1_AlignCorners) { TEST(ResizeNearestNeighborReference, Test2x2To1x1_HalfPixelCenters) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {1, 1}; + std::vector output_size_data = {1, 1}; RuntimeShape output_shape = {1, 1, 1, 1}; std::vector output_data = {4}; @@ -82,10 +82,10 @@ TEST(ResizeNearestNeighborReference, Test2x2To1x1_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test2x2To3x3) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 1, 2, 1, 1, 2, 3, 3, 4}; + std::vector output_data = {1, 1, 2, 1, 1, 2, 3, 3, 4}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data); @@ -94,7 +94,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3) { TEST(ResizeNearestNeighborReference, Test2x2To3x3Int16) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; std::vector output_data = {1, 1, 2, 1, 1, 2, 3, 3, 4}; @@ -104,10 +104,10 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3Int16) { TEST(ResizeNearestNeighborReference, Test2x2To3x3_AlignCorners) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 2, 2, 3, 4, 4, 3, 4, 4}; + std::vector output_data = {1, 2, 2, 3, 4, 4, 3, 4, 4}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data, @@ -116,10 +116,10 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3_AlignCorners) { TEST(ResizeNearestNeighborReference, Test2x2To3x3_HalfPixelCenters) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 2, 2, 3, 4, 4, 3, 4, 4}; + std::vector output_data = {1, 2, 2, 3, 4, 4, 3, 4, 4}; TestReferenceResizeNearestNeighbor( input_shape, input_data, output_size_data, output_shape, output_data, @@ -129,7 +129,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To3x3_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test3x3To2x2) { RuntimeShape input_shape = {1, 3, 3, 1}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - std::vector output_size_data = {2, 2}; + std::vector output_size_data = {2, 2}; RuntimeShape output_shape = {1, 2, 2, 1}; std::vector output_data = {1, 2, 4, 5}; @@ -140,7 +140,7 @@ TEST(ResizeNearestNeighborReference, Test3x3To2x2) { TEST(ResizeNearestNeighborReference, Test3x3To2x2_AlignCorners) { RuntimeShape input_shape = {1, 3, 3, 1}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - std::vector output_size_data = {2, 2}; + std::vector output_size_data = {2, 2}; RuntimeShape output_shape = {1, 2, 2, 1}; std::vector output_data = {1, 3, 7, 9}; @@ -152,7 +152,7 @@ TEST(ResizeNearestNeighborReference, Test3x3To2x2_AlignCorners) { TEST(ResizeNearestNeighborReference, Test3x3To2x2_HalfPixelCenters) { RuntimeShape input_shape = {1, 3, 3, 1}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 9}; - std::vector output_size_data = {2, 2}; + std::vector output_size_data = {2, 2}; RuntimeShape output_shape = {1, 2, 2, 1}; std::vector output_data = {1, 3, 7, 9}; @@ -163,10 +163,10 @@ TEST(ResizeNearestNeighborReference, Test3x3To2x2_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test2x2To2x5) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {2, 5}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {2, 5}; RuntimeShape output_shape = {1, 2, 5, 1}; - std::vector output_data = {1, 1, 1, 2, 2, 3, 3, 3, 4, 4}; + std::vector output_data = {1, 1, 1, 2, 2, 3, 3, 3, 4, 4}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data); @@ -174,10 +174,10 @@ TEST(ResizeNearestNeighborReference, Test2x2To2x5) { TEST(ResizeNearestNeighborReference, Test2x2To2x5_HalfPixelCenters) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {2, 5}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {2, 5}; RuntimeShape output_shape = {1, 2, 5, 1}; - std::vector output_data = {1, 1, 2, 2, 2, 3, 3, 4, 4, 4}; + std::vector output_data = {1, 1, 2, 2, 2, 3, 3, 4, 4, 4}; TestReferenceResizeNearestNeighbor( input_shape, input_data, output_size_data, output_shape, output_data, @@ -186,11 +186,11 @@ TEST(ResizeNearestNeighborReference, Test2x2To2x5_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test4x4To3x3) { RuntimeShape input_shape = {1, 4, 4, 1}; - std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 2, 3, 5, 6, 7, 9, 10, 11}; + std::vector output_data = {1, 2, 3, 5, 6, 7, 9, 10, 11}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data); @@ -198,11 +198,11 @@ TEST(ResizeNearestNeighborReference, Test4x4To3x3) { TEST(ResizeNearestNeighborReference, Test4x4To3x3_AlignCorners) { RuntimeShape input_shape = {1, 4, 4, 1}; - std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 3, 4, 9, 11, 12, 13, 15, 16}; + std::vector output_data = {1, 3, 4, 9, 11, 12, 13, 15, 16}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data, @@ -211,11 +211,11 @@ TEST(ResizeNearestNeighborReference, Test4x4To3x3_AlignCorners) { TEST(ResizeNearestNeighborReference, Test4x4To3x3_HalfPixelCenters) { RuntimeShape input_shape = {1, 4, 4, 1}; - std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16}; - std::vector output_size_data = {3, 3}; + std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {1, 3, 3, 1}; - std::vector output_data = {1, 3, 4, 9, 11, 12, 13, 15, 16}; + std::vector output_data = {1, 3, 4, 9, 11, 12, 13, 15, 16}; TestReferenceResizeNearestNeighbor( input_shape, input_data, output_size_data, output_shape, output_data, @@ -225,7 +225,7 @@ TEST(ResizeNearestNeighborReference, Test4x4To3x3_HalfPixelCenters) { TEST(ResizeNearestNeighborReference, Test2x2To5x2) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {5, 2}; + std::vector output_size_data = {5, 2}; RuntimeShape output_shape = {1, 5, 2, 1}; std::vector output_data = {1, 2, 1, 2, 1, 2, 3, 4, 3, 4}; @@ -236,7 +236,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To5x2) { TEST(ResizeNearestNeighborReference, Test2x2To5x2_HalfPixelCenters) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {5, 2}; + std::vector output_size_data = {5, 2}; RuntimeShape output_shape = {1, 5, 2, 1}; std::vector output_data = {1, 2, 1, 2, 3, 4, 3, 4, 3, 4}; @@ -249,7 +249,7 @@ TEST(ResizeNearestNeighborReference, Test2x2To5x2_HalfPixelCenters_AlignCorners) { RuntimeShape input_shape = {1, 2, 2, 1}; std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {5, 2}; + std::vector output_size_data = {5, 2}; RuntimeShape output_shape = {1, 5, 2, 1}; std::vector output_data = {2, 2, 2, 2, 4, 4, 4, 4, 4, 4}; @@ -260,11 +260,11 @@ TEST(ResizeNearestNeighborReference, TEST(ResizeNearestNeighborReference, Test2x2To4x4) { RuntimeShape input_shape = {1, 2, 2, 1}; - std::vector input_data = {1, 2, 3, 4}; - std::vector output_size_data = {4, 4}; + std::vector input_data = {1, 2, 3, 4}; + std::vector output_size_data = {4, 4}; RuntimeShape output_shape = {1, 4, 4, 1}; - std::vector output_data = {1, 1, 2, 2, 1, 1, 2, 2, - 3, 3, 4, 4, 3, 3, 4, 4}; + std::vector output_data = {1, 1, 2, 2, 1, 1, 2, 2, + 3, 3, 4, 4, 3, 3, 4, 4}; TestReferenceResizeNearestNeighbor(input_shape, input_data, output_size_data, output_shape, output_data); @@ -279,7 +279,7 @@ TEST(ResizeNearestNeighborReference, Test2x2x2x2To2x3x3x2) { RuntimeShape input_shape = {2, 2, 2, 2}; std::vector input_data = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {2, 3, 3, 2}; // Output: // [ [ 1, 1 ], [ 1, 1 ], [ 2, 2 ], @@ -300,7 +300,7 @@ TEST(ResizeNearestNeighborReference, Test2x2x2x2To2x3x3x2_AlignCorners) { RuntimeShape input_shape = {2, 2, 2, 2}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {2, 3, 3, 2}; std::vector output_data = { 1, 2, 3, 4, 3, 4, 5, 6, 7, 8, 7, 8, 5, 6, 7, 8, 7, 8, @@ -316,7 +316,7 @@ TEST(ResizeNearestNeighborReference, Test2x2x2x2To2x3x3x2_HalfPixelCenters) { RuntimeShape input_shape = {2, 2, 2, 2}; std::vector input_data = {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {2, 3, 3, 2}; std::vector output_data = {1, 1, 2, 2, 2, 2, 3, 3, 4, 4, 4, 4, 3, 3, 4, 4, 4, 4, 5, 5, 6, 6, 6, 6, @@ -332,7 +332,7 @@ TEST(ResizeNearestNeighborReference, RuntimeShape input_shape = {2, 2, 2, 2}; std::vector input_data = {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}; - std::vector output_size_data = {3, 3}; + std::vector output_size_data = {3, 3}; RuntimeShape output_shape = {2, 3, 3, 2}; std::vector output_data = {1, 2, 3, 4, 3, 4, 5, 6, 7, 8, 7, 8, 5, 6, 7, 8, 7, 8, 1, 2, 3, 4, 3, 4, @@ -351,14 +351,14 @@ void TestOptimizedResizeNearestNeighbor(int batch, int depth, int input_width, RuntimeShape input_shape({batch, input_height, input_width, depth}); RuntimeShape output_shape({batch, output_height, output_width, depth}); - std::vector input_data(input_shape.FlatSize(), 0); - FillRandom(&input_data, static_cast(0), static_cast(255)); + std::vector input_data(input_shape.FlatSize(), 0); + FillRandom(&input_data, static_cast(0), static_cast(255)); - std::vector reference_output_data(output_shape.FlatSize(), 0); + std::vector reference_output_data(output_shape.FlatSize(), 0); // Initialize the output data with something other than zero, so we can catch // issue with kernels failing to initialize the output. - std::vector output_data(output_shape.FlatSize(), 3); - std::vector output_size_data = {output_height, output_width}; + std::vector output_data(output_shape.FlatSize(), 3); + std::vector output_size_data = {output_height, output_width}; ResizeNearestNeighborParams op_params{/*align_corners=*/false, /*half_pixel_centers=*/false}; @@ -412,22 +412,22 @@ bool is_valid_scale(int input_width, int input_height, int output_width, const float width_scale_float = static_cast(input_width) / output_width; - int32 height_scale_int = (input_height << 16) / output_height + 1; - int32 width_scale_int = (input_width << 16) / output_width + 1; + int32_t height_scale_int = (input_height << 16) / output_height + 1; + int32_t width_scale_int = (input_width << 16) / output_width + 1; for (int y = 0; y < output_height; ++y) { - int32 in_y_float = - std::min(static_cast(std::floor(y * height_scale_float)), + int32_t in_y_float = + std::min(static_cast(std::floor(y * height_scale_float)), input_height - 1); - int32 in_y_int = std::min((y * height_scale_int) >> 16, input_height - 1); + int32_t in_y_int = std::min((y * height_scale_int) >> 16, input_height - 1); if (in_y_int != in_y_float) { return false; } for (int x = 0; x < output_width; ++x) { - int32 in_x_float = - std::min(static_cast(std::floor(x * width_scale_float)), + int32_t in_x_float = + std::min(static_cast(std::floor(x * width_scale_float)), input_width - 1); - int32 in_x_int = std::min((x * width_scale_int) >> 16, input_width - 1); + int32_t in_x_int = std::min((x * width_scale_int) >> 16, input_width - 1); if (in_x_int != in_x_float) { return false; } diff --git a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc index 9b5ef171eaf9b5..4f736225d3508a 100644 --- a/tensorflow/lite/kernels/internal/softmax_quantized_test.cc +++ b/tensorflow/lite/kernels/internal/softmax_quantized_test.cc @@ -32,11 +32,11 @@ limitations under the License. namespace tflite { namespace { -void RunSoftmaxFloatReference(const uint8* input_data, +void RunSoftmaxFloatReference(const uint8_t* input_data, const RuntimeShape& shape_common, - int32 input_offset, const double input_scale, + int32_t input_offset, const double input_scale, int stride, float beta, - uint8* reference_output_data) { + uint8_t* reference_output_data) { const int ref_buffer_size = shape_common.FlatSize(); std::vector reference_dequant_data(ref_buffer_size); std::vector reference_output_float_data(ref_buffer_size); @@ -103,18 +103,18 @@ void CheckOutputData(const T* test_output, const T* reference_output, // Runs the Softmax and compares against the float reference implementation and // the quantized reference implementation. -void RunOneSoftmaxTest(const uint8* input_data, - const RuntimeShape& shape_common, int32 input_offset, +void RunOneSoftmaxTest(const uint8_t* input_data, + const RuntimeShape& shape_common, int32_t input_offset, const double input_scale, int stride, float beta) { const int buffer_size = shape_common.FlatSize(); - std::vector optimized_softmax_output(buffer_size); - std::vector reference_float_softmax_output(buffer_size); - std::vector reference_quant_softmax_output(buffer_size); + std::vector optimized_softmax_output(buffer_size); + std::vector reference_float_softmax_output(buffer_size); + std::vector reference_quant_softmax_output(buffer_size); RunSoftmaxFloatReference(input_data, shape_common, input_offset, input_scale, stride, beta, reference_float_softmax_output.data()); - int32 input_beta_multiplier; + int32_t input_beta_multiplier; int input_beta_left_shift; static const int kScaledDiffIntegerBits = 5; tflite::PreprocessSoftmaxScaling(beta, input_scale, kScaledDiffIntegerBits, @@ -180,14 +180,14 @@ bool TryOneUniformSoftmax() { const int input_height = ExponentialRandomPositiveInt(0.8f, 20, 200); const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); - const int32 input_offset = UniformRandomInt(-256, 0); + const int32_t input_offset = UniformRandomInt(-256, 0); const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); auto shape_common = RuntimeShape({batch, input_height, input_width, input_depth}); const int buffer_size = shape_common.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandom(&input_data); RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, stride, beta); @@ -213,7 +213,7 @@ bool TryOneSkyscraperSoftmax(bool small_depth) { const int input_height = ExponentialRandomPositiveInt(0.7f, 20, 200); const int stride = ExponentialRandomPositiveInt(0.9f, 3, 8); const double input_scale = std::pow(10.0, UniformRandomFloat(-2.0, 1.0)); - const int32 input_offset = UniformRandomInt(-256, 0); + const int32_t input_offset = UniformRandomInt(-256, 0); const float beta = 1.0f + ExponentialRandomPositiveFloat(0.9f, 2, 10); // Extra parameters for skyscraper input patterns. const double middle_proportion = @@ -225,7 +225,7 @@ bool TryOneSkyscraperSoftmax(bool small_depth) { RuntimeShape({batch, input_height, input_width, input_depth}); const int buffer_size = shape_common.FlatSize(); - std::vector input_data(buffer_size); + std::vector input_data(buffer_size); FillRandomSkyscraper(&input_data, input_depth, middle_proportion, middle_min, sides_max); RunOneSoftmaxTest(input_data.data(), shape_common, input_offset, input_scale, diff --git a/tensorflow/lite/kernels/internal/tensor_test.cc b/tensorflow/lite/kernels/internal/tensor_test.cc index d746d66dc94359..0006f385d7b863 100644 --- a/tensorflow/lite/kernels/internal/tensor_test.cc +++ b/tensorflow/lite/kernels/internal/tensor_test.cc @@ -24,28 +24,28 @@ using ::testing::ElementsAre; TEST(TensorTest, GetTensorShape4D) { RuntimeShape d = GetTensorShape({2, 3, 4, 5}); EXPECT_THAT( - std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), + std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), ElementsAre(2, 3, 4, 5)); } TEST(TensorTest, GetTensorShape3D) { RuntimeShape d = GetTensorShape({3, 4, 5}); EXPECT_THAT( - std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), + std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), ElementsAre(3, 4, 5)); } TEST(TensorTest, GetTensorShape2D) { RuntimeShape d = GetTensorShape({4, 5}); EXPECT_THAT( - std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), + std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), ElementsAre(4, 5)); } TEST(TensorTest, GetTensorShape1D) { RuntimeShape d = GetTensorShape({5}); EXPECT_THAT( - std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), + std::vector(d.DimsData(), d.DimsData() + d.DimensionsCount()), ElementsAre(5)); } diff --git a/tensorflow/lite/kernels/internal/test_util.h b/tensorflow/lite/kernels/internal/test_util.h index ec64590d0d3508..7e17170cfa57e5 100644 --- a/tensorflow/lite/kernels/internal/test_util.h +++ b/tensorflow/lite/kernels/internal/test_util.h @@ -93,8 +93,8 @@ void FillRandom(std::vector* vec) { // the depth) with higher values than the surround. template void FillRandomSkyscraper(std::vector* vec, int depth, - double middle_proportion, uint8 middle_min, - uint8 sides_max) { + double middle_proportion, uint8_t middle_min, + uint8_t sides_max) { for (auto base_it = std::begin(*vec); base_it != std::end(*vec); base_it += depth) { auto left_it = base_it + std::ceil(0.5 * depth * (1.0 - middle_proportion)); diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index 39f7bc7da53a49..2333caebce546a 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -456,6 +456,12 @@ std::string GetShapeDebugString(const TfLiteIntArray* shape) { return str; } +std::string GetTensorDebugString(const TfLiteTensor* tensor) { + return std::string("{\n type: ") + TfLiteTypeGetName(tensor->type) + + "\n data: {...}\n dims: " + GetShapeDebugString(tensor->dims) + + "\n}"; +} + TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, const TfLiteTensor* input1, const TfLiteTensor* input2, diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index e318118fb649f3..070f363b5a6412 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -310,6 +310,8 @@ TfLiteStatus GetOutputShapeFromInput(TfLiteContext* context, std::string GetShapeDebugString(const TfLiteIntArray* shape); +std::string GetTensorDebugString(const TfLiteTensor* tensor); + #endif // !defined(TF_LITE_STATIC_MEMORY) // Calculates the output_shape that is necessary for element-wise operations diff --git a/tensorflow/lite/kernels/parse_example/parse_example.cc b/tensorflow/lite/kernels/parse_example/parse_example.cc index acec0331b414f7..ec87aabfc86c95 100644 --- a/tensorflow/lite/kernels/parse_example/parse_example.cc +++ b/tensorflow/lite/kernels/parse_example/parse_example.cc @@ -111,7 +111,7 @@ void FillAndCopyVarLen(const int d, const size_t num_elements, bool ParseExample(StringRef serialized, Example* example) { DCHECK(example != nullptr); tf::protobuf::io::CodedInputStream stream( - reinterpret_cast(serialized.str), serialized.len); + reinterpret_cast(serialized.str), serialized.len); tensorflow::example::EnableAliasing(&stream); return ParseExample(&stream, example); } diff --git a/tensorflow/lite/kernels/shim/BUILD b/tensorflow/lite/kernels/shim/BUILD index 3244635f60cf36..9aa10f1463769e 100644 --- a/tensorflow/lite/kernels/shim/BUILD +++ b/tensorflow/lite/kernels/shim/BUILD @@ -163,6 +163,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:macros", ] + if_mobile([ "//tensorflow/core:portable_tensorflow_lib_lite", ]) + if_not_mobile([ diff --git a/tensorflow/lite/kernels/shim/test_op/BUILD b/tensorflow/lite/kernels/shim/test_op/BUILD index e5703e9fd4c5eb..af4cd02d50c67d 100644 --- a/tensorflow/lite/kernels/shim/test_op/BUILD +++ b/tensorflow/lite/kernels/shim/test_op/BUILD @@ -48,7 +48,7 @@ tf_cc_test( "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/platform:tstring", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -132,7 +132,7 @@ tf_cc_test( "//tensorflow/core/framework:tensor_testutil", "//tensorflow/core/kernels:ops_testutil", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc b/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc index db537b723bfadc..a37483f1b55ab8 100644 --- a/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc +++ b/tensorflow/lite/kernels/shim/test_op/simple_tf_op_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -22,7 +23,6 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/platform/tstring.h" -#include "tsl/lib/core/status_test_util.h" namespace tflite { namespace shim { diff --git a/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc b/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc index c457bcc012da3a..8e661d82b7bc00 100644 --- a/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc +++ b/tensorflow/lite/kernels/shim/test_op/tmpl_tf_op_test.cc @@ -15,13 +15,13 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/kernels/ops_testutil.h" -#include "tsl/lib/core/status_test_util.h" namespace tflite { namespace shim { diff --git a/tensorflow/lite/kernels/shim/tf_op_shim.cc b/tensorflow/lite/kernels/shim/tf_op_shim.cc index 7d12bc88417769..d71cfa74b9c23f 100644 --- a/tensorflow/lite/kernels/shim/tf_op_shim.cc +++ b/tensorflow/lite/kernels/shim/tf_op_shim.cc @@ -21,10 +21,17 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/lite/kernels/shim/op_kernel.h" +#include "tensorflow/lite/kernels/shim/shape.h" #include "tensorflow/lite/kernels/shim/status_macros.h" #include "tensorflow/lite/kernels/shim/tensor_view.h" #include "tensorflow/lite/kernels/shim/tf_tensor_view.h" diff --git a/tensorflow/lite/kernels/shim/tf_op_shim.h b/tensorflow/lite/kernels/shim/tf_op_shim.h index 834a394b39b2e7..8f6442bc7db42c 100644 --- a/tensorflow/lite/kernels/shim/tf_op_shim.h +++ b/tensorflow/lite/kernels/shim/tf_op_shim.h @@ -21,13 +21,15 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/registration/registration.h" #include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/lite/kernels/shim/op_kernel.h" #include "tensorflow/lite/kernels/shim/shape.h" +#include "tsl/platform/macros.h" // This file contains the TF adapter. That is, it takes a `OpKernelShim` // class and provides a TF kernel out of it. @@ -51,9 +53,9 @@ class TfInvokeContext : public InvokeContext { public: explicit TfInvokeContext(::tensorflow::OpKernelContext* context); // Read an input tensor - ConstTensorViewOr GetInput(const int idx) const; + ConstTensorViewOr GetInput(int idx) const; // Get a mutable output tensor - TensorViewOr GetOutput(const int idx, const Shape& shape) const; + TensorViewOr GetOutput(int idx, const Shape& shape) const; // Number of input tensors int NumInputs() const; // Number of output tensors @@ -70,11 +72,11 @@ class TfShapeInferenceContext explicit TfShapeInferenceContext( ::tensorflow::shape_inference::InferenceContext* context); // Read an input tensor shape - ShapeOr GetInputShape(const int idx) const; + ShapeOr GetInputShape(int idx) const; // Set an output tensor shape - absl::Status SetOutputShape(const int idx, const Shape& shape); + absl::Status SetOutputShape(int idx, const Shape& shape); // Read an input tensor during shape inference - ConstTensorViewOr GetInputTensor(const int idx) const; + ConstTensorViewOr GetInputTensor(int idx) const; // Read a given attribute absl::StatusOr GetAttr(const std::string& attr_name) const; // Number of input tensors diff --git a/tensorflow/lite/kernels/variants/BUILD b/tensorflow/lite/kernels/variants/BUILD index 46c7755cef5248..531fc8bfe0f6eb 100644 --- a/tensorflow/lite/kernels/variants/BUILD +++ b/tensorflow/lite/kernels/variants/BUILD @@ -308,7 +308,7 @@ cc_library( srcs = ["tensor_array.cc"], hdrs = ["tensor_array.h"], compatible_with = get_compatible_with_portable(), - visibility = ["//tensorflow/lite:__subpackages__"], + visibility = ["//visibility:private"], deps = [ "//tensorflow/lite:array", "//tensorflow/lite:util", diff --git a/tensorflow/lite/profiling/profile_summary_formatter_test.cc b/tensorflow/lite/profiling/profile_summary_formatter_test.cc index 48b069756579b3..a6ba3808f829f4 100644 --- a/tensorflow/lite/profiling/profile_summary_formatter_test.cc +++ b/tensorflow/lite/profiling/profile_summary_formatter_test.cc @@ -14,13 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/profiling/profile_summary_formatter.h" +#include #include #include #include #include #include +#include -#include #include #include "absl/strings/match.h" #include "tensorflow/core/util/stat_summarizer_options.h" @@ -32,6 +33,127 @@ namespace profiling { namespace { +// LINT.IfChange(OpProfilingStatComparator) +bool AreOpProfilingStatEqual(const OpProfilingStat& op_profiling_stat_1, + const OpProfilingStat& op_profiling_stat_2) { + auto proto_to_tuple = [](const OpProfilingStat& op_profiling_stat) { + return std::make_tuple(op_profiling_stat.first(), op_profiling_stat.last(), + op_profiling_stat.avg(), op_profiling_stat.stddev(), + op_profiling_stat.variance(), + op_profiling_stat.min(), op_profiling_stat.max(), + op_profiling_stat.sum(), op_profiling_stat.count()); + }; + return proto_to_tuple(op_profiling_stat_1) == + proto_to_tuple(op_profiling_stat_2); +} +// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:OpProfilingStat) + +// LINT.IfChange(OpProfileDataComparator) +bool AreOpProfileDataEqual(const OpProfileData& op_profile_data_1, + const OpProfileData& op_profile_data_2) { + auto proto_to_tuple = [](const OpProfileData& op_profile_data) { + return std::make_tuple(op_profile_data.node_type(), + op_profile_data.times_called(), + op_profile_data.name(), op_profile_data.run_order()); + }; + + return (proto_to_tuple(op_profile_data_1) == + proto_to_tuple(op_profile_data_2)) && + AreOpProfilingStatEqual(op_profile_data_1.inference_microseconds(), + op_profile_data_2.inference_microseconds()) && + (AreOpProfilingStatEqual(op_profile_data_1.mem_kb(), + op_profile_data_2.mem_kb())); +} +// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:OpProfileData) + +// LINT.IfChange(SubGraphProfilingDataComparator) +bool AreSubGraphProfilingDataEqual( + const SubGraphProfilingData& subgraph_profiling_data_1, + const SubGraphProfilingData& subgraph_profiling_data_2) { + auto proto_to_tuple = + [](const SubGraphProfilingData& subgraph_profiling_data) { + return std::make_tuple( + subgraph_profiling_data.subgraph_name(), + subgraph_profiling_data.per_op_profiles().size()); + }; + + if (proto_to_tuple(subgraph_profiling_data_1) == + proto_to_tuple(subgraph_profiling_data_2)) { + for (size_t i = 0; i < subgraph_profiling_data_1.per_op_profiles().size(); + ++i) { + auto op_profile_data_1 = subgraph_profiling_data_1.per_op_profiles(i); + auto op_profile_data_2 = subgraph_profiling_data_2.per_op_profiles(i); + if (!AreOpProfileDataEqual(op_profile_data_1, op_profile_data_2)) { + return false; + } + } + return true; + } + return false; +} +// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:SubGraphProfilingData) + +// LINT.IfChange(DelegateProfilingDataComparator) +bool AreDelegateProfilingDataEqual( + const DelegateProfilingData& delegate_profiling_data_1, + const DelegateProfilingData& delegate_profiling_data_2) { + auto proto_to_tuple = + [](const DelegateProfilingData& delegate_profiling_data) { + return std::make_tuple( + delegate_profiling_data.delegate_name(), + delegate_profiling_data.per_op_profiles().size()); + }; + + if (proto_to_tuple(delegate_profiling_data_1) == + proto_to_tuple(delegate_profiling_data_2)) { + for (size_t i = 0; i < delegate_profiling_data_1.per_op_profiles().size(); + ++i) { + auto op_profile_data_1 = delegate_profiling_data_1.per_op_profiles(i); + auto op_profile_data_2 = delegate_profiling_data_2.per_op_profiles(i); + if (!AreOpProfileDataEqual(op_profile_data_1, op_profile_data_2)) { + return false; + } + } + return true; + } + return false; +} +// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:DelegateProfilingData) + +// LINT.IfChange(ModelProfilingDataComparator) +bool AreModelProfilingDataEqual( + const ModelProfilingData& model_profiling_data_1, + const ModelProfilingData& model_profiling_data_2) { + if (model_profiling_data_1.subgraph_profiles().size() != + model_profiling_data_2.subgraph_profiles().size()) { + return false; + } + for (size_t i = 0; i < model_profiling_data_1.subgraph_profiles().size(); + ++i) { + auto subgraph_profile_1 = model_profiling_data_1.subgraph_profiles(i); + auto subgraph_profile_2 = model_profiling_data_2.subgraph_profiles(i); + if (!AreSubGraphProfilingDataEqual(subgraph_profile_1, + subgraph_profile_2)) { + return false; + } + } + if (model_profiling_data_1.delegate_profiles().size() != + model_profiling_data_2.delegate_profiles().size()) { + return false; + } + for (size_t i = 0; i < model_profiling_data_1.delegate_profiles().size(); + ++i) { + auto delegate_profile_1 = model_profiling_data_1.delegate_profiles(i); + auto delegate_profile_2 = model_profiling_data_2.delegate_profiles(i); + if (!AreDelegateProfilingDataEqual(delegate_profile_1, + delegate_profile_2)) { + return false; + } + } + return true; +} +// LINT.ThenChange(//tensorflow/lite/profiling/proto/profiling_info.proto:ModelProfilingData) + TEST(SummaryWriterTest, SummaryOptionStdOut) { ProfileSummaryDefaultFormatter writer; tensorflow::StatSummarizerOptions options = writer.GetStatSummarizerOptions(); @@ -182,8 +304,9 @@ TEST(SummaryWriterTest, MultiSubgraphOutputStringForProto) { op_profile_data_1.set_name(kernel_name_1); op_profile_data_1.set_run_order(1); op_profile_data_1.set_times_called(2); - EXPECT_THAT(model_profiling_data.subgraph_profiles(0).per_op_profiles(0), - testing::EqualsProto(op_profile_data_1)); + EXPECT_TRUE(AreOpProfileDataEqual( + model_profiling_data.subgraph_profiles(0).per_op_profiles(0), + op_profile_data_1)); OpProfileData op_profile_data_2; op_profile_data_2.set_node_type(op_name_2); @@ -212,8 +335,9 @@ TEST(SummaryWriterTest, MultiSubgraphOutputStringForProto) { op_profile_data_2.set_name(kernel_name_2); op_profile_data_2.set_run_order(2); - EXPECT_THAT(model_profiling_data.subgraph_profiles(0).per_op_profiles(1), - testing::EqualsProto(op_profile_data_2)); + EXPECT_TRUE(AreOpProfileDataEqual( + model_profiling_data.subgraph_profiles(0).per_op_profiles(1), + op_profile_data_2)); ASSERT_EQ(model_profiling_data.subgraph_profiles(1).subgraph_name(), "Subgraph 1"); @@ -246,8 +370,9 @@ TEST(SummaryWriterTest, MultiSubgraphOutputStringForProto) { op_profile_data_3.set_times_called(1); op_profile_data_3.set_name(kernel_name_3); op_profile_data_3.set_run_order(3); - EXPECT_THAT(model_profiling_data.subgraph_profiles(1).per_op_profiles(0), - testing::EqualsProto(op_profile_data_3)); + EXPECT_TRUE(AreOpProfileDataEqual( + model_profiling_data.subgraph_profiles(1).per_op_profiles(0), + op_profile_data_3)); } TEST(SummaryWriterTest, MultiSubgraphHandleOutputForProto) { @@ -351,10 +476,10 @@ TEST(SummaryWriterTest, MultiSubgraphHandleOutputForProto) { file.close(); ASSERT_TRUE(benchmark_profiling_data.model_name().empty()); - EXPECT_THAT(benchmark_profiling_data.init_profile(), - testing::EqualsProto(model_profiling_data_init)); - EXPECT_THAT(benchmark_profiling_data.runtime_profile(), - testing::EqualsProto(model_profiling_data_run)); + EXPECT_TRUE(AreModelProfilingDataEqual( + benchmark_profiling_data.init_profile(), model_profiling_data_init)); + EXPECT_TRUE(AreModelProfilingDataEqual( + benchmark_profiling_data.runtime_profile(), model_profiling_data_run)); } TEST(SummaryWriterTest, MultiSubgraphShortSummary) { diff --git a/tensorflow/lite/profiling/proto/profiling_info.proto b/tensorflow/lite/profiling/proto/profiling_info.proto index 8116524405dc11..5d33571efcab88 100644 --- a/tensorflow/lite/profiling/proto/profiling_info.proto +++ b/tensorflow/lite/profiling/proto/profiling_info.proto @@ -25,22 +25,27 @@ message BenchmarkProfilingData { optional ModelProfilingData runtime_profile = 3; } +// LINT.IfChange(ModelProfilingData) message ModelProfilingData { repeated SubGraphProfilingData subgraph_profiles = 1; repeated DelegateProfilingData delegate_profiles = 2; } +// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:ModelProfilingDataComparator) +// LINT.IfChange(SubGraphProfilingData) message SubGraphProfilingData { optional string subgraph_name = 1; optional int32 subgraph_index = 2; repeated OpProfileData per_op_profiles = 3; } +// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:SubGraphProfilingDataComparator) message DelegateProfilingData { optional string delegate_name = 1; repeated OpProfileData per_op_profiles = 2; } +// LINT.IfChange(OpProfilingStat) message OpProfilingStat { optional int64 first = 1; optional int64 last = 2; @@ -52,7 +57,9 @@ message OpProfilingStat { optional int64 sum = 8; optional int64 count = 9; } +// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:OpProfilingStatComparator) +// LINT.IfChange(OpProfileData) message OpProfileData { optional string node_type = 1; optional OpProfilingStat inference_microseconds = 2; @@ -61,3 +68,4 @@ message OpProfileData { optional string name = 5; optional int64 run_order = 6; } +// LINT.ThenChange(//tensorflow/lite/profiling/profile_summary_formatter_test.cc:OpProfileDataComparator) diff --git a/tensorflow/lite/python/convert.py b/tensorflow/lite/python/convert.py index 2389b3b8d393e3..403eb9549369a2 100644 --- a/tensorflow/lite/python/convert.py +++ b/tensorflow/lite/python/convert.py @@ -600,6 +600,7 @@ def build_conversion_flags( qdq_conversion_mode=None, disable_per_channel_quantization_for_dense_layers=False, enable_composite_direct_lowering=False, + model_origin_framework=lite_constants.UNSET, **_, ): """Builds protocol buffer describing a conversion of a model. @@ -731,6 +732,8 @@ def build_conversion_flags( layers. The flag works only for integer quantized model. enable_composite_direct_lowering: If set, attempts to lower composite ops directly to tflite ops. + model_origin_framework: A str specifying the framework of the original + model. Can be {TENSORFLOW, KERAS, JAX, PYTORCH} Returns: conversion_flags: protocol buffer describing the conversion process. @@ -854,6 +857,11 @@ def build_conversion_flags( conversion_flags.enable_composite_direct_lowering = ( enable_composite_direct_lowering ) + conversion_flags.model_origin_framework = ( + _conversion_flags_pb2.TocoFlags.ModelOriginFramework.Value( + model_origin_framework + ) + ) return conversion_flags diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py index 670340e8dba7fa..e49c63763c222c 100644 --- a/tensorflow/lite/python/interpreter_test.py +++ b/tensorflow/lite/python/interpreter_test.py @@ -310,15 +310,13 @@ class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase): # Model must have at least 7 bytes to hold model identifier def testTooShortModelContent(self): - with self.assertRaisesRegex( - ValueError, - 'Model provided must have at least 7 bytes to hold identifier.', - ): + with self.assertRaisesRegex(ValueError, + 'The model is not a valid Flatbuffer buffer'): interpreter_wrapper.Interpreter(model_content=b'short') def testInvalidModelContent(self): with self.assertRaisesRegex(ValueError, - 'Model provided has model identifier \''): + 'The model is not a valid Flatbuffer buffer'): interpreter_wrapper.Interpreter(model_content=b'wrong_identifier') def testInvalidModelFile(self): diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc index 7ab81eec5d58fd..a14d6dcc9e2121 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include #include #include #include @@ -745,12 +746,32 @@ PyObject* InterpreterWrapper::GetTensor(int tensor_index, tensor->type != kTfLiteVariant) { // Make a buffer copy but we must tell Numpy It owns that data or else // it will leak. - void* data = malloc(tensor->bytes); + size_t numpy_bytes = tensor->bytes; + if (tensor->type == kTfLiteInt4) { + // Numpy doesn't have int4 type, so we double the size of the buffer + // to hold int8 type for each (4-bit packed) element. + numpy_bytes *= 2; + } + void* data = malloc(numpy_bytes); if (!data) { PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed."); return nullptr; } - memcpy(data, tensor->data.raw, tensor->bytes); + if (tensor->type == kTfLiteInt4) { + int8_t* tensor_data = reinterpret_cast(tensor->data.raw); + int8_t* numpy_data = static_cast(data); + // Unpack each 4-bit value to an 8-bit container. + for (size_t i = 0; i < tensor->bytes; i++) { + int8_t byte = tensor_data[i]; + int8_t lower = static_cast(byte << 4) >> 4; + int8_t upper = static_cast(byte >> 4); + numpy_data[2 * i] = lower; + numpy_data[2 * i + 1] = upper; + } + } else { + memcpy(data, tensor->data.raw, tensor->bytes); + } + PyObject* np_array; if (tensor->sparsity == nullptr) { np_array = @@ -866,7 +887,8 @@ InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer( return nullptr; } std::unique_ptr model = - Model::BuildFromBuffer(buf, length, error_reporter.get()); + Model::VerifyAndBuildFromBuffer(buf, length, /*extra_verifier=*/nullptr, + error_reporter.get()); return CreateInterpreterWrapper( std::move(model), op_resolver_id, std::move(error_reporter), registerers_by_name, registerers_by_func, error_msg, preserve_all_tensors, diff --git a/tensorflow/lite/python/lite.py b/tensorflow/lite/python/lite.py index 2005f80d03bc12..117abe593de4bf 100644 --- a/tensorflow/lite/python/lite.py +++ b/tensorflow/lite/python/lite.py @@ -674,6 +674,7 @@ def __init__(self): self._experimental_qdq_conversion_mode = None self._experimental_disable_per_channel_quantization_for_dense_layers = False self._experimental_enable_composite_direct_lowering = False + self.model_origin_framework = constants.UNSET # Debug parameters self.ir_dump_dir = None @@ -836,6 +837,7 @@ def _get_base_converter_args(self): "enable_composite_direct_lowering": ( self._experimental_enable_composite_direct_lowering ), + "model_origin_framework": self.model_origin_framework, } if self.saved_model_dir: diff --git a/tensorflow/lite/python/lite_constants.py b/tensorflow/lite/python/lite_constants.py index 4700a5920b57c0..843c2225eb6f2f 100644 --- a/tensorflow/lite/python/lite_constants.py +++ b/tensorflow/lite/python/lite_constants.py @@ -31,6 +31,21 @@ TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF TFLITE = _toco_flags_pb2.TFLITE GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT +UNSET = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.UNSET +) +TENSORFLOW = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.TENSORFLOW +) +KERAS = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.KERAS +) +JAX = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.JAX +) +PYTORCH = _toco_flags_pb2.TocoFlags.ModelOriginFramework.Name( + _toco_flags_pb2.TocoFlags.PYTORCH +) _tf_export(v1=["lite.constants.FLOAT"]).export_constant(__name__, "FLOAT") _tf_export(v1=["lite.constants.FLOAT16"]).export_constant(__name__, "FLOAT16") @@ -65,6 +80,11 @@ "TENSORFLOW_GRAPHDEF", "TFLITE", "GRAPHVIZ_DOT", + "UNSET", + "TENSORFLOW", + "KERAS", + "JAX", + "PYTORCH", "EXPERIMENTAL_USE_TOCO_API_DIRECTLY", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/lite/schema/BUILD b/tensorflow/lite/schema/BUILD index 7bf0f18d68fc24..e064789cbe77c6 100644 --- a/tensorflow/lite/schema/BUILD +++ b/tensorflow/lite/schema/BUILD @@ -48,10 +48,10 @@ py_strict_library( "upgrade_schema.py", ], data = [ - "schema_v0.fbs", - "schema_v1.fbs", - "schema_v2.fbs", - "schema_v3.fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_v0.fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_v1.fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_v2.fbs", + "//tensorflow/compiler/mlir/lite/schema:schema_v3.fbs", "@flatbuffers//:flatc", ], srcs_version = "PY3", @@ -103,13 +103,6 @@ py_strict_test( exports_files([ "conversion_metadata.fbs", - "schema.fbs", - "schema_v0.fbs", - "schema_v1.fbs", - "schema_v2.fbs", - "schema_v3.fbs", - "schema_v3a.fbs", - "schema_v3b.fbs", ]) flatbuffer_cc_library( diff --git a/tensorflow/lite/schema/schema_v3b.fbs b/tensorflow/lite/schema/schema_v3b.fbs deleted file mode 100644 index 917786050f7e8b..00000000000000 --- a/tensorflow/lite/schema/schema_v3b.fbs +++ /dev/null @@ -1,1242 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Revision History -// Version 0: Initial version. -// Version 1: Add subgraphs to schema. -// Version 2: Rename operators to conform to NN API. -// Version 3: Move buffer data from Model.Subgraph.Tensors to Model.Buffers. -// Version 3a: Add new builtin op code field. Has backward compatibility with -// version 3. -// Version 3b: Rename fields in SignatureDef. Has backward compatibility with -// version 3 and 3a. - -namespace tflite; - -// This corresponds to the version. -file_identifier "TFL3"; -// File extension of any written files. -file_extension "tflite"; - -// IMPORTANT: All new members of tables, enums and unions must be added at the -// end to ensure backwards compatibility. - -// The type of data stored in a tensor. -enum TensorType : byte { - FLOAT32 = 0, - FLOAT16 = 1, - INT32 = 2, - UINT8 = 3, - INT64 = 4, - STRING = 5, - BOOL = 6, - INT16 = 7, - COMPLEX64 = 8, - INT8 = 9, - FLOAT64 = 10, - COMPLEX128 = 11, - UINT64 = 12, - // Experimental: Resource and variant types are experimental, that are subject - // to change. Do not implement custom kernels using resource & variant types - // now. - RESOURCE = 13, - VARIANT = 14, - UINT32 = 15, -} - -// Custom quantization parameters for experimenting with new quantization -// techniques. -table CustomQuantization { - custom:[ubyte] (force_align: 16); -} - -// Represents a specific quantization technique's parameters. -union QuantizationDetails { - CustomQuantization, -} - -// Parameters for converting a quantized tensor back to float. -table QuantizationParameters { - // These four parameters are the asymmetric linear quantization parameters. - // Given a quantized value q, the corresponding float value f should be: - // f = scale * (q - zero_point) - // For other quantization types, the QuantizationDetails below is used. - min:[float]; // For importing back into tensorflow. - max:[float]; // For importing back into tensorflow. - scale:[float]; // For dequantizing the tensor's values. - zero_point:[long]; - - // If this is not none, the other quantization parameters (i.e. min, max, - // scale, zero_point fields above) are ignored and the value of the - // QuantizationDetails union should be used. - details:QuantizationDetails; - - // Specifies the dimension of the Tensor's shape that the scales and - // zero_points correspond to. For example, a tensor t, with dims=[4, 3, 2, 1] - // with quantization params: - // scale=[1.0, 2.0, 3.0], zero_point=[1, 2, 3], quantization_dimension=1 - // will be quantized across the second dimension of t. - // t[:, 0, :, :] will have scale[0]=1.0, zero_point[0]=1 - // t[:, 1, :, :] will have scale[1]=2.0, zero_point[0]=2 - // t[:, 2, :, :] will have scale[2]=3.0, zero_point[0]=3 - quantized_dimension:int; -} - -// Sparse tensors. -// We use a modification of the TACO format. -// Reference: http://tensor-compiler.org/kjolstad-oopsla17-tensor-compiler.pdf -// -// To encode a conceptual n-dimensional dense tensor with dims (d0, ..., dn-1), -// potentially with a k-dimensional block (0 <= k <= n) with dims -// (dn, ..., dn+k-1), the format needs to specify: -// 1. In what order to traverse these dimensions. For example, to store a 2-D -// matrix in row major order, the traversal order would be (d0, d1), -// whereas to store it in column major order, the traversal order would be -// (d1, d0). If the 2-D matrix has a 2-D inner block, the traversal order -// could be (d0, d1, d2, d3). -// 2. How each block dimension in (dn, ..., dn+k-1) maps to the original -// tensor dimension in (d0, ..., dn-1). -// 3. In the traversal order defined above, the format (dense vs. sparse) and -// index metadata for each dimension. For a dense dimension, this is just -// the size of that dimension. For a sparse dimension, it's the same as -// the compressed index defined in the Compressed Sparse Row (CSR) format. -// (http://scipy-lectures.org/advanced/scipy_sparse/csr_matrix.html) - -// The storage type for a dimension. Currently we support: -// 1. DENSE: each coordinate in this dimension is stored implicitly. -// 2. SPARSE_CSR: only the coordinates with non-zero elements are stored. The -// compression technique is the same what CSR uses. -// More types like a sparse dimension with a different compression technique -// could be added to the list in the future. -enum DimensionType : byte { - DENSE = 0, - SPARSE_CSR = 1, -} - -table Int32Vector { - values:[int]; -} - -table Uint16Vector { - values:[ushort] (force_align: 4); -} - -table Uint8Vector { - values:[ubyte] (force_align: 4); -} - -// Variable-typed buffer to store the index metadata for a sparse dimension. -// The widest type is Int32 instead of UInt32 because tensor's shape is a int32 -// vector. We don't want the per-dimensional index to overflow that range. -union SparseIndexVector { - Int32Vector, - Uint16Vector, - Uint8Vector -} - -table DimensionMetadata { - // Whether a dimension is dense or sparse. - format:DimensionType; - // Index metadata used for a dimension. - // - If format is DimensionType.DENSE then we use the dense_size field to - // store the size of that dimension. Each index in that dimension is - // stored implicitly. - // - If format is DimensionType.SPARSE_CSR then we use array_segments and - // array_indices to encode that dimension. array_segments represents how - // to segment the indices array, each segment corresponds to one element - // in the previous dimension. array_indices represents the index of the - // non-zero elements within this dimension (as those in the CSR matrix - // format, where the first array is row pointers and the second array is - // column indices). - dense_size:int; - array_segments:SparseIndexVector; - array_indices:SparseIndexVector; -} - -// Parameters to encode a sparse TfLite tensor. -table SparsityParameters { - // The traversal order of the dimensions defined in the `shape` field of the - // conceptual dense tensor. For a n-dimensional tensors with dims (d0, d1, - // ..., dn-1), - // - if not block sparse, the traversal_order is just a permutation of (d0, - // ..., dn-1). For example, a 2-D matrix stored in row-major order would - // have traversal_order = (d0, d1). - // - if block sparse with a k-dimensional block (0 <= k <= n), the - // traversal_order has n + k elements. The first n elements are still a - // permutation of (d0, ..., dn-1). The lask k elements are a permutation - // of (dn, ..., dn+k-1), defining how to traverse a block internally. For - // example, a 2-D matrix with 2-D blocks, both stored in row-major order - // would have traversal_order = (d0, d1, d2, d3). - traversal_order:[int]; - // For an n-dimensional tensor with a k-dimensional block (0 <= k <= n), - // stores how a block dimension in (dn, ..., dn+k-1) maps to the original - // tensor dimension in (d0, ..., dn). - // It's stored in the order of (dn, ..., dn+k-1). - // If not block-sparse, this field is NULL. - block_map:[int]; - // In the traversal order defined above, the metadata needed for - // each dimension to locate the non-zero values in the original dense tensor. - // The size of the dim_metadata array = the size of the traversal_order array - // = n + k. - dim_metadata:[DimensionMetadata]; -} - -table Tensor { - // The tensor shape. The meaning of each entry is operator-specific but - // builtin ops use: [batch size, height, width, number of channels] (That's - // Tensorflow's NHWC). - shape:[int]; - type:TensorType; - // An index that refers to the buffers table at the root of the model. Or, - // if there is no data buffer associated (i.e. intermediate results), then - // this is 0 (which refers to an always existent empty buffer). - // - // The data_buffer itself is an opaque container, with the assumption that the - // target device is little-endian. In addition, all builtin operators assume - // the memory is ordered such that if `shape` is [4, 3, 2], then index - // [i, j, k] maps to data_buffer[i*3*2 + j*2 + k]. - buffer:uint; - name:string; // For debugging and importing back into tensorflow. - quantization:QuantizationParameters; // Optional. - - is_variable:bool = false; - - // Parameters to encode a sparse tensor. See the example in - // tensorflow/lite/testdata/sparse_tensor.json. - sparsity:SparsityParameters; // Optional. - - // Encodes `shape` with unknown dimensions. Unknown dimensions are - // represented with -1. - shape_signature:[int]; // Optional. -} - -// A list of builtin operators. Builtin operators are slightly faster than custom -// ones, but not by much. Moreover, while custom operators accept an opaque -// object containing configuration parameters, builtins have a predetermined -// set of acceptable options. -// LINT.IfChange -enum BuiltinOperator : int32 { - ADD = 0, - AVERAGE_POOL_2D = 1, - CONCATENATION = 2, - CONV_2D = 3, - DEPTHWISE_CONV_2D = 4, - DEPTH_TO_SPACE = 5, - DEQUANTIZE = 6, - EMBEDDING_LOOKUP = 7, - FLOOR = 8, - FULLY_CONNECTED = 9, - HASHTABLE_LOOKUP = 10, - L2_NORMALIZATION = 11, - L2_POOL_2D = 12, - LOCAL_RESPONSE_NORMALIZATION = 13, - LOGISTIC = 14, - LSH_PROJECTION = 15, - LSTM = 16, - MAX_POOL_2D = 17, - MUL = 18, - RELU = 19, - // NOTE(aselle): RELU_N1_TO_1 used to be called RELU1, but it was renamed - // since different model developers use RELU1 in different ways. Never - // create another op called RELU1. - RELU_N1_TO_1 = 20, - RELU6 = 21, - RESHAPE = 22, - RESIZE_BILINEAR = 23, - RNN = 24, - SOFTMAX = 25, - SPACE_TO_DEPTH = 26, - SVDF = 27, - TANH = 28, - CONCAT_EMBEDDINGS = 29, - SKIP_GRAM = 30, - CALL = 31, - CUSTOM = 32, - EMBEDDING_LOOKUP_SPARSE = 33, - PAD = 34, - UNIDIRECTIONAL_SEQUENCE_RNN = 35, - GATHER = 36, - BATCH_TO_SPACE_ND = 37, - SPACE_TO_BATCH_ND = 38, - TRANSPOSE = 39, - MEAN = 40, - SUB = 41, - DIV = 42, - SQUEEZE = 43, - UNIDIRECTIONAL_SEQUENCE_LSTM = 44, - STRIDED_SLICE = 45, - BIDIRECTIONAL_SEQUENCE_RNN = 46, - EXP = 47, - TOPK_V2 = 48, - SPLIT = 49, - LOG_SOFTMAX = 50, - // DELEGATE is a special op type for the operations which are delegated to - // other backends. - // WARNING: Experimental interface, subject to change - DELEGATE = 51, - BIDIRECTIONAL_SEQUENCE_LSTM = 52, - CAST = 53, - PRELU = 54, - MAXIMUM = 55, - ARG_MAX = 56, - MINIMUM = 57, - LESS = 58, - NEG = 59, - PADV2 = 60, - GREATER = 61, - GREATER_EQUAL = 62, - LESS_EQUAL = 63, - SELECT = 64, - SLICE = 65, - SIN = 66, - TRANSPOSE_CONV = 67, - SPARSE_TO_DENSE = 68, - TILE = 69, - EXPAND_DIMS = 70, - EQUAL = 71, - NOT_EQUAL = 72, - LOG = 73, - SUM = 74, - SQRT = 75, - RSQRT = 76, - SHAPE = 77, - POW = 78, - ARG_MIN = 79, - FAKE_QUANT = 80, - REDUCE_PROD = 81, - REDUCE_MAX = 82, - PACK = 83, - LOGICAL_OR = 84, - ONE_HOT = 85, - LOGICAL_AND = 86, - LOGICAL_NOT = 87, - UNPACK = 88, - REDUCE_MIN = 89, - FLOOR_DIV = 90, - REDUCE_ANY = 91, - SQUARE = 92, - ZEROS_LIKE = 93, - FILL = 94, - FLOOR_MOD = 95, - RANGE = 96, - RESIZE_NEAREST_NEIGHBOR = 97, - LEAKY_RELU = 98, - SQUARED_DIFFERENCE = 99, - MIRROR_PAD = 100, - ABS = 101, - SPLIT_V = 102, - UNIQUE = 103, - CEIL = 104, - REVERSE_V2 = 105, - ADD_N = 106, - GATHER_ND = 107, - COS = 108, - WHERE = 109, - RANK = 110, - ELU = 111, - REVERSE_SEQUENCE = 112, - MATRIX_DIAG = 113, - QUANTIZE = 114, - MATRIX_SET_DIAG = 115, - ROUND = 116, - HARD_SWISH = 117, - IF = 118, - WHILE = 119, - NON_MAX_SUPPRESSION_V4 = 120, - NON_MAX_SUPPRESSION_V5 = 121, - SCATTER_ND = 122, - SELECT_V2 = 123, - DENSIFY = 124, - SEGMENT_SUM = 125, - BATCH_MATMUL = 126, - PLACEHOLDER_FOR_GREATER_OP_CODES = 127, - CUMSUM = 128, - CALL_ONCE = 129, - BROADCAST_TO = 130, - RFFT2D = 131, - CONV_3D = 132, - IMAG=133, - REAL=134, - COMPLEX_ABS=135, - HASHTABLE = 136, - HASHTABLE_FIND = 137, - HASHTABLE_IMPORT = 138, - HASHTABLE_SIZE = 139, - REDUCE_ALL = 140, - CONV_3D_TRANSPOSE = 141, - VAR_HANDLE = 142, - READ_VARIABLE = 143, - ASSIGN_VARIABLE = 144, -} -// LINT.ThenChange(nnapi_linter/linter.proto) - -// Options for the builtin operators. -union BuiltinOptions { - Conv2DOptions, - DepthwiseConv2DOptions, - ConcatEmbeddingsOptions, - LSHProjectionOptions, - Pool2DOptions, - SVDFOptions, - RNNOptions, - FullyConnectedOptions, - SoftmaxOptions, - ConcatenationOptions, - AddOptions, - L2NormOptions, - LocalResponseNormalizationOptions, - LSTMOptions, - ResizeBilinearOptions, - CallOptions, - ReshapeOptions, - SkipGramOptions, - SpaceToDepthOptions, - EmbeddingLookupSparseOptions, - MulOptions, - PadOptions, - GatherOptions, - BatchToSpaceNDOptions, - SpaceToBatchNDOptions, - TransposeOptions, - ReducerOptions, - SubOptions, - DivOptions, - SqueezeOptions, - SequenceRNNOptions, - StridedSliceOptions, - ExpOptions, - TopKV2Options, - SplitOptions, - LogSoftmaxOptions, - CastOptions, - DequantizeOptions, - MaximumMinimumOptions, - ArgMaxOptions, - LessOptions, - NegOptions, - PadV2Options, - GreaterOptions, - GreaterEqualOptions, - LessEqualOptions, - SelectOptions, - SliceOptions, - TransposeConvOptions, - SparseToDenseOptions, - TileOptions, - ExpandDimsOptions, - EqualOptions, - NotEqualOptions, - ShapeOptions, - PowOptions, - ArgMinOptions, - FakeQuantOptions, - PackOptions, - LogicalOrOptions, - OneHotOptions, - LogicalAndOptions, - LogicalNotOptions, - UnpackOptions, - FloorDivOptions, - SquareOptions, - ZerosLikeOptions, - FillOptions, - BidirectionalSequenceLSTMOptions, - BidirectionalSequenceRNNOptions, - UnidirectionalSequenceLSTMOptions, - FloorModOptions, - RangeOptions, - ResizeNearestNeighborOptions, - LeakyReluOptions, - SquaredDifferenceOptions, - MirrorPadOptions, - AbsOptions, - SplitVOptions, - UniqueOptions, - ReverseV2Options, - AddNOptions, - GatherNdOptions, - CosOptions, - WhereOptions, - RankOptions, - ReverseSequenceOptions, - MatrixDiagOptions, - QuantizeOptions, - MatrixSetDiagOptions, - HardSwishOptions, - IfOptions, - WhileOptions, - DepthToSpaceOptions, - NonMaxSuppressionV4Options, - NonMaxSuppressionV5Options, - ScatterNdOptions, - SelectV2Options, - DensifyOptions, - SegmentSumOptions, - BatchMatMulOptions, - CumsumOptions, - CallOnceOptions, - BroadcastToOptions, - Rfft2dOptions, - Conv3DOptions, - HashtableOptions, - HashtableFindOptions, - HashtableImportOptions, - HashtableSizeOptions, - VarHandleOptions, - ReadVariableOptions, - AssignVariableOptions, -} - -enum Padding : byte { SAME, VALID } - -enum ActivationFunctionType : byte { - NONE = 0, - RELU = 1, - RELU_N1_TO_1 = 2, - RELU6 = 3, - TANH = 4, - SIGN_BIT = 5, -} - -table Conv2DOptions { - padding:Padding; - stride_w:int; - stride_h:int; - fused_activation_function:ActivationFunctionType; - dilation_w_factor:int = 1; - dilation_h_factor:int = 1; -} - -// Options for both Conv3D and Conv3DTranspose. -table Conv3DOptions { - padding:Padding; - stride_d:int; - stride_w:int; - stride_h:int; - fused_activation_function:ActivationFunctionType; - dilation_d_factor:int = 1; - dilation_w_factor:int = 1; - dilation_h_factor:int = 1; -} - -table Pool2DOptions { - padding:Padding; - stride_w:int; - stride_h:int; - filter_width:int; - filter_height:int; - fused_activation_function:ActivationFunctionType; -} - -table DepthwiseConv2DOptions { - // Parameters for DepthwiseConv version 1 or above. - padding:Padding; - stride_w:int; - stride_h:int; - // `depth_multiplier` is redundant. It's used by CPU kernels in - // TensorFlow 2.0 or below, but ignored in versions above. - // See comments in lite/c/builtin_op_data.h for more details. - depth_multiplier:int; - fused_activation_function:ActivationFunctionType; - // Parameters for DepthwiseConv version 2 or above. - dilation_w_factor:int = 1; - dilation_h_factor:int = 1; -} - -table ConcatEmbeddingsOptions { - num_channels:int; - num_columns_per_channel:[int]; - embedding_dim_per_channel:[int]; // This could be inferred from parameters. -} - -enum LSHProjectionType: byte { - UNKNOWN = 0, - SPARSE = 1, - DENSE = 2, -} - -table LSHProjectionOptions { - type: LSHProjectionType; -} - -table SVDFOptions { - rank:int; - fused_activation_function:ActivationFunctionType; - // For weights-only quantization, use asymmetric quantization for non - // constant inputs at evaluation time. - asymmetric_quantize_inputs:bool; -} - -// An implementation of TensorFlow RNNCell. -table RNNOptions { - fused_activation_function:ActivationFunctionType; - asymmetric_quantize_inputs:bool; -} - -// An implementation of TensorFlow dynamic_rnn with RNNCell. -table SequenceRNNOptions { - time_major:bool; - fused_activation_function:ActivationFunctionType; - asymmetric_quantize_inputs:bool; -} - -// An implementation of TensorFlow bidrectional_dynamic_rnn with RNNCell. -table BidirectionalSequenceRNNOptions { - time_major:bool; - fused_activation_function:ActivationFunctionType; - merge_outputs: bool; - asymmetric_quantize_inputs:bool; -} - -enum FullyConnectedOptionsWeightsFormat: byte { - DEFAULT = 0, - SHUFFLED4x16INT8 = 1, -} - -// An implementation of TensorFlow fully_connected (a.k.a Dense) layer. -table FullyConnectedOptions { - // Parameters for FullyConnected version 1 or above. - fused_activation_function:ActivationFunctionType; - - // Parameters for FullyConnected version 2 or above. - weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT; - - // Parameters for FullyConnected version 5 or above. - // If set to true, then the number of dimension is preserved. Furthermore, - // all but the last dimension of the input and output shapes will be equal. - keep_num_dims: bool; - - // Parameters for FullyConnected version 7 or above. - // If set to true, then weights-only op will use asymmetric quantization for - // inputs. - asymmetric_quantize_inputs: bool; -} - -table SoftmaxOptions { - beta: float; -} - -// An implementation of TensorFlow concat. -table ConcatenationOptions { - axis:int; - fused_activation_function:ActivationFunctionType; -} - -table AddOptions { - fused_activation_function:ActivationFunctionType; - // Parameters supported by version 3. - pot_scale_int16:bool = true; -} - -table MulOptions { - fused_activation_function:ActivationFunctionType; -} - -table L2NormOptions { - // This field is currently ignored in the L2 Norm Op. - fused_activation_function:ActivationFunctionType; -} - -table LocalResponseNormalizationOptions { - radius:int; - bias:float; - alpha:float; - beta:float; -} - -enum LSTMKernelType : byte { - // Full LSTM kernel which supports peephole and projection. - FULL = 0, - // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell. - BASIC = 1, -} - -// An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell -table LSTMOptions { - // Parameters for LSTM version 1 or above. - fused_activation_function:ActivationFunctionType; - cell_clip: float; // Optional, 0.0 means no clipping - proj_clip: float; // Optional, 0.0 means no clipping - - // Parameters for LSTM version 2 or above. - // Basic kernel is only supported in version 2 or above. - kernel_type: LSTMKernelType = FULL; - - // Parameters for LSTM version 4 or above. - asymmetric_quantize_inputs: bool; -} - -// An implementation of TensorFlow dynamic_rnn with LSTMCell. -table UnidirectionalSequenceLSTMOptions { - fused_activation_function:ActivationFunctionType; - cell_clip: float; // Optional, 0.0 means no clipping - proj_clip: float; // Optional, 0.0 means no clipping - - // If true then first dimension is sequence, otherwise batch. - time_major:bool; - - // Parameter for Unidirectional Sequence LSTM version 4. - asymmetric_quantize_inputs:bool; -} - -table BidirectionalSequenceLSTMOptions { - // Parameters supported by version 1: - fused_activation_function:ActivationFunctionType; - cell_clip: float; // Optional, 0.0 means no clipping - proj_clip: float; // Optional, 0.0 means no clipping - - // If true, store the outputs of both directions into the first output. - merge_outputs: bool; - - // Parameters supported by version 2: - // If true then first dimension is sequence, otherwise batch. - // Version 1 implementations assumed time_major to be true, so this default - // value should never change. - time_major: bool = true; - - // Parameters for version 3 or above. - asymmetric_quantize_inputs:bool; -} - -table ResizeBilinearOptions { - new_height: int (deprecated); - new_width: int (deprecated); - align_corners: bool; - half_pixel_centers: bool; -} - -table ResizeNearestNeighborOptions { - align_corners: bool; - half_pixel_centers: bool; -} - -// A call operation options -table CallOptions { - // The subgraph index that needs to be called. - subgraph:uint; -} - -table PadOptions { -} - -table PadV2Options { -} - -table ReshapeOptions { - new_shape:[int]; -} - -table SpaceToBatchNDOptions { -} - -table BatchToSpaceNDOptions { -} - -table SkipGramOptions { - ngram_size: int; - max_skip_size: int; - include_all_ngrams: bool; -} - -table SpaceToDepthOptions { - block_size: int; -} - -table DepthToSpaceOptions { - block_size: int; -} - -table SubOptions { - fused_activation_function:ActivationFunctionType; - // Parameters supported by version 5 - pot_scale_int16:bool = true; -} - -table DivOptions { - fused_activation_function:ActivationFunctionType; -} - -table TopKV2Options { -} - -enum CombinerType : byte { - SUM = 0, - MEAN = 1, - SQRTN = 2, -} - -table EmbeddingLookupSparseOptions { - combiner:CombinerType; -} - -table GatherOptions { - axis: int; - // Parameters for Gather version 5 or above. - batch_dims: int = 0; -} - -table TransposeOptions { -} - -table ExpOptions { -} - -table CosOptions { -} - -table ReducerOptions { - keep_dims: bool; -} - -table SqueezeOptions { - squeeze_dims:[int]; -} - -table SplitOptions { - num_splits: int; -} - -table SplitVOptions { - num_splits: int; -} - -table StridedSliceOptions { - begin_mask: int; - end_mask: int; - ellipsis_mask: int; - new_axis_mask: int; - shrink_axis_mask: int; -} - -table LogSoftmaxOptions { -} - -table CastOptions { - in_data_type: TensorType; - out_data_type: TensorType; -} - -table DequantizeOptions { -} - -table MaximumMinimumOptions { -} - -table TileOptions { -} - -table ArgMaxOptions { - output_type : TensorType; -} - -table ArgMinOptions { - output_type : TensorType; -} - -table GreaterOptions { -} - -table GreaterEqualOptions { -} - -table LessOptions { -} - -table LessEqualOptions { -} - -table NegOptions { -} - -table SelectOptions { -} - -table SliceOptions { -} - -table TransposeConvOptions { - padding:Padding; - stride_w:int; - stride_h:int; -} - -table ExpandDimsOptions { -} - -table SparseToDenseOptions { - validate_indices:bool; -} - -table EqualOptions { -} - -table NotEqualOptions { -} - -table ShapeOptions { - // Optional output type of the operation (int32 or int64). Defaults to int32. - out_type : TensorType; -} - -table RankOptions { -} - -table PowOptions { -} - -table FakeQuantOptions { - // Parameters supported by version 1: - min:float; - max:float; - num_bits:int; - - // Parameters supported by version 2: - narrow_range:bool; -} - -table PackOptions { - values_count:int; - axis:int; -} - -table LogicalOrOptions { -} - -table OneHotOptions { - axis:int; -} - -table AbsOptions { -} - - -table HardSwishOptions { -} - -table LogicalAndOptions { -} - -table LogicalNotOptions { -} - -table UnpackOptions { - num:int; - axis:int; -} - -table FloorDivOptions { -} - -table SquareOptions { -} - -table ZerosLikeOptions { -} - -table FillOptions { -} - -table FloorModOptions { -} - -table RangeOptions { -} - -table LeakyReluOptions { - alpha:float; -} - -table SquaredDifferenceOptions { -} - -enum MirrorPadMode : byte { - // Doesn't include borders. - REFLECT = 0, - // Includes borders. - SYMMETRIC = 1, -} - -table MirrorPadOptions { - mode:MirrorPadMode; -} - -table UniqueOptions { - idx_out_type:TensorType = INT32; -} - -table ReverseV2Options { -} - -table AddNOptions { -} - -table GatherNdOptions { -} - -table WhereOptions { -} - -table ReverseSequenceOptions { - seq_dim:int; - batch_dim:int = 0; -} - -table MatrixDiagOptions { -} - -table QuantizeOptions { -} - -table MatrixSetDiagOptions { -} - -table IfOptions { - then_subgraph_index:int; - else_subgraph_index:int; -} - -table CallOnceOptions { - init_subgraph_index:int; -} - -table WhileOptions { - cond_subgraph_index:int; - body_subgraph_index:int; -} - -table NonMaxSuppressionV4Options { -} - -table NonMaxSuppressionV5Options { -} - -table ScatterNdOptions { -} - -table SelectV2Options { -} - -table DensifyOptions { -} - -table SegmentSumOptions { -} - -table BatchMatMulOptions { - adj_x:bool; - adj_y:bool; - // Parameters for BatchMatMul version 4 or above. - // If set to true, then weights-only op will use asymmetric quantization for - // inputs. - asymmetric_quantize_inputs: bool; -} - -table CumsumOptions { - exclusive:bool; - reverse:bool; -} - -table BroadcastToOptions { -} - -table Rfft2dOptions { -} - -table HashtableOptions { - // The identity of hash tables. This identity will be used across different - // subgraphs in the same interpreter instance. - table_id:int; - key_dtype:TensorType; - value_dtype:TensorType; -} - -table HashtableFindOptions { -} - -table HashtableImportOptions { -} - -table HashtableSizeOptions { -} - -table VarHandleOptions { - container:string; - shared_name:string; -} - -table ReadVariableOptions { -} - -table AssignVariableOptions { -} - -// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a -// builtin, or a string if the operator is custom. -table OperatorCode { - // This field is for backward compatibility. This field will be used when - // the value of the extended builtin_code field has less than - // BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. - deprecated_builtin_code:byte; - custom_code:string; - - // The version of the operator. The version need to be bumped whenever new - // parameters are introduced into an op. - version:int = 1; - - // This field is introduced for resolving op builtin code shortage problem - // (the original BuiltinOperator enum field was represented as a byte). - // This field will be used when the value of the extended builtin_code field - // has greater than BulitinOperator_PLACEHOLDER_FOR_GREATER_OP_CODES. - builtin_code:BuiltinOperator; -} - -enum CustomOptionsFormat : byte { - FLEXBUFFERS = 0, -} - -// An operator takes tensors as inputs and outputs. The type of operation being -// performed is determined by an index into the list of valid OperatorCodes, -// while the specifics of each operations is configured using builtin_options -// or custom_options. -table Operator { - // Index into the operator_codes array. Using an integer here avoids - // complicate map lookups. - opcode_index:uint; - - // Optional input are indicated by -1. - inputs:[int]; - outputs:[int]; - - builtin_options:BuiltinOptions; - custom_options:[ubyte]; - custom_options_format:CustomOptionsFormat; - - // A list of booleans indicating the input tensors which are being mutated by - // this operator.(e.g. used by RNN and LSTM). - // For example, if the "inputs" array refers to 5 tensors and the second and - // fifth are mutable variables, then this list will contain - // [false, true, false, false, true]. - // - // If the list is empty, no variable is mutated in this operator. - // The list either has the same length as `inputs`, or is empty. - mutating_variable_inputs:[bool]; - - // A list of indices to the subgraph's "tensors" that are internal to an Op. - // Internal tensors are those that do not flow in or out of the operation, - // but instead are part of internal computation. As such, the operation's - // implementation may manage its memory more efficiently. They are needed - // however (i.e. not just an implementation detail) since they are part of the - // computation, which may require relevant metadata such as quantization - // parameters. - intermediates:[int]; -} - -// The root type, defining a subgraph, which typically represents an entire -// model. -table SubGraph { - // A list of all tensors used in this subgraph. - tensors:[Tensor]; - - // Indices of the tensors that are inputs into this subgraph. Note this is - // the list of non-static tensors that feed into the subgraph for inference. - inputs:[int]; - - // Indices of the tensors that are outputs out of this subgraph. Note this is - // the list of output tensors that are considered the product of the - // subgraph's inference. - outputs:[int]; - - // All operators, in execution order. - operators:[Operator]; - - // Name of this subgraph (used for debugging). - name:string; -} - -// Table of raw data buffers (used for constant tensors). Referenced by tensors -// by index. The generous alignment accommodates mmap-friendly data structures. -table Buffer { - data:[ubyte] (force_align: 16); -} - -table Metadata { - // A human readable string to uniquely identify a Metadata. - name:string; - // An index to the buffers table. - buffer:uint; -} - -// Map from an alias name of tensor to tensor index in the graph. -// This is used in Signature def. -table TensorMap { - // Represents the alias to use for this tensor. - name:string; - - // The actual tensor index in the primary graph, that 'name' corresponds to. - tensor_index:uint; -} - -// This corresponds to SignatureDef in Tensorflow SavedModel. -// The SignatureDef will be part of the SavedModel provided for conversion. -table SignatureDef { - // Named inputs for this signature. - inputs:[TensorMap]; - - // Named outputs for this signature. - outputs:[TensorMap]; - - // Key value which was in the Tensorflow SavedModel SignatureDef map. - signature_key:string; - - // Model tag, deprecated. - deprecated_tag:string (deprecated); - - // Index of subgraphs that corresponds to the exported method. - subgraph_index:uint; -} - -table Model { - // Version of the schema. - version:uint; - - // A list of all operator codes used in this model. This is - // kept in order because operators carry an index into this - // vector. - operator_codes:[OperatorCode]; - - // All the subgraphs of the model. The 0th is assumed to be the main - // model. - subgraphs:[SubGraph]; - - // A description of the model. - description:string; - - // Buffers of the model. - // Note the 0th entry of this array must be an empty buffer (sentinel). - // This is a convention so that tensors without a buffer can provide 0 as - // their buffer. - buffers:[Buffer]; - - // Metadata about the model. Indirects into the existings buffers list. - // Deprecated, prefer to use metadata field. - metadata_buffer:[int]; - - // Metadata about the model. - metadata:[Metadata]; - - // Optional SignatureDefs for the model. - signature_defs:[SignatureDef]; -} - -root_type Model; diff --git a/tensorflow/lite/stateful_error_reporter.h b/tensorflow/lite/stateful_error_reporter.h index cf6693431f9118..10dc09646cb273 100644 --- a/tensorflow/lite/stateful_error_reporter.h +++ b/tensorflow/lite/stateful_error_reporter.h @@ -15,9 +15,10 @@ limitations under the License. #ifndef TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ #define TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ +// LINT.IfChange #include -#include "tensorflow/lite/core/api/error_reporter.h" +#include "tensorflow/compiler/mlir/lite/core/api/error_reporter.h" namespace tflite { @@ -30,5 +31,6 @@ class StatefulErrorReporter : public ErrorReporter { }; } // namespace tflite +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/stateful_error_reporter.h) #endif // TENSORFLOW_LITE_STATEFUL_ERROR_REPORTER_H_ diff --git a/tensorflow/lite/stderr_reporter.h b/tensorflow/lite/stderr_reporter.h index 2eacb9eca244d2..fdac5d4062cab3 100644 --- a/tensorflow/lite/stderr_reporter.h +++ b/tensorflow/lite/stderr_reporter.h @@ -18,7 +18,6 @@ limitations under the License. #include #include "tensorflow/lite/core/api/error_reporter.h" -#include "tensorflow/lite/core/c/common.h" namespace tflite { diff --git a/tensorflow/lite/testdata/no_signatures.bin b/tensorflow/lite/testdata/no_signatures.bin new file mode 100644 index 00000000000000..1a6f71b7936722 Binary files /dev/null and b/tensorflow/lite/testdata/no_signatures.bin differ diff --git a/tensorflow/lite/testing/BUILD b/tensorflow/lite/testing/BUILD index e7f4fdafd7ad1e..5d3976fb9ee5af 100644 --- a/tensorflow/lite/testing/BUILD +++ b/tensorflow/lite/testing/BUILD @@ -217,6 +217,33 @@ cc_library( ], ) +cc_library( + name = "matchers", + testonly = True, + srcs = ["matchers.h"], + hdrs = ["matchers.h"], + deps = [ + "//tensorflow/lite/core/c:common", + "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "matchers_test", + srcs = ["matchers_test.cc"], + deps = [ + ":matchers", + "//tensorflow/lite/core/c:c_api_types", + "//tensorflow/lite/core/c:common", + "@com_google_absl//absl/base", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "message", srcs = ["message.cc"], diff --git a/tensorflow/lite/testing/matchers.h b/tensorflow/lite/testing/matchers.h new file mode 100644 index 00000000000000..604b3dd9ff6cfe --- /dev/null +++ b/tensorflow/lite/testing/matchers.h @@ -0,0 +1,272 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_LITE_TESTING_MATCHERS_H_ +#define TENSORFLOW_LITE_TESTING_MATCHERS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include "absl/log/absl_check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/lite/core/c/common.h" +#include "tensorflow/lite/kernels/kernel_util.h" + +// gMock matchers for TfLiteTensors. +// +// EXPECT_THAT(a, EqualsTensor(b)); +// EXPECT_THAT(a, Approximately(EqualsTensor(b))); +// EXPECT_THAT(a, Approximately(EqualsTensor(b), /*margin*/)); +// EXPECT_THAT(a, Approximately(EqualsTensor(b), /*margin=*/0, /*fraction*/)); +// +// TODO: who/impjdi - Expand to more dtypes than just float. +// TODO: who/impjdi - Add cross-dtype matchers. + +inline void PrintTo(const TfLiteTensor& tensor, std::ostream* os) { + *os << "\n" << ::tflite::GetTensorDebugString(&tensor); +} + +namespace testing { +namespace tflite { +namespace internal { + +enum class FloatComparison { kExact, kApproximate }; + +struct TensorComparison { + FloatComparison float_comp = FloatComparison::kExact; + bool custom_margin = false; + bool custom_fraction = false; + double margin = 0.0; // only used if custom_margin == true + double fraction = 0.0; // only used if custom_fraction == true +}; + +class TensorMatcher { + public: + TensorMatcher(const TensorComparison& comp, const TfLiteTensor& expected) + : comp_(comp), expected_(expected) {} + + bool MatchAndExplain(const TfLiteTensor& actual, + MatchResultListener* listener) const { + const bool match = Match(actual); + if (listener->IsInterested() && !match) *listener << DescribeDiff(actual); + return match; + } + + void DescribeTo(std::ostream* os) const { Describe(os, "is "); } + void DescribeNegationTo(std::ostream* os) const { Describe(os, "is not "); } + + void SetCompareApproximately() { + comp_.float_comp = FloatComparison::kApproximate; + } + + void SetMargin(double margin) { + ABSL_QCHECK_GE(margin, 0.0) // Crash OK + << "Using a negative margin for Approximately"; + comp_.custom_margin = true; + comp_.margin = margin; + } + + void SetFraction(double fraction) { + ABSL_QCHECK(0.0 <= fraction && fraction < 1.0) // Crash OK + << "Fraction for Approximately must be >= 0.0 and < 1.0"; + comp_.custom_fraction = true; + comp_.fraction = fraction; + } + + private: + static std::string TensorIndex(int index, const TfLiteIntArray* dims) { + if (!dims->size) return ""; + std::vector index_nd(dims->size); + for (int i = dims->size - 1; i >= 0; --i) { + index_nd[i] = index % dims->data[i]; + index /= dims->data[i]; + } + return absl::StrCat("[", absl::StrJoin(index_nd, "]["), "]"); + } + + bool CompareFloat(float x, float y) const { + switch (comp_.float_comp) { + case FloatComparison::kExact: + return x == y; + case FloatComparison::kApproximate: + if (x == y) return true; + float fraction, margin; + if (comp_.custom_margin || comp_.custom_fraction) { + fraction = comp_.fraction; + margin = comp_.margin; + } else { + constexpr float kEpsilon = 32 * FLT_EPSILON; + if (std::fabs(x) <= kEpsilon && std::fabs(y) <= kEpsilon) return true; + fraction = kEpsilon; + margin = kEpsilon; + } + if (!std::isfinite(x) || !std::isfinite(y)) return false; + float relative_margin = fraction * std::max(std::fabs(x), std::fabs(y)); + return std::fabs(x - y) <= std::max(margin, relative_margin); + } + return false; + } + + void Describe(std::ostream* os, std::string_view prefix) const { + *os << prefix; + if (comp_.float_comp == FloatComparison::kApproximate) { + *os << "approximately "; + if (comp_.custom_margin || comp_.custom_fraction) { + *os << "("; + if (comp_.custom_margin) { + std::stringstream ss; + ss << std::setprecision(std::numeric_limits::digits10 + 2) + << comp_.margin; + *os << "absolute error of float values <= " << ss.str(); + } + if (comp_.custom_margin && comp_.custom_fraction) { + *os << " or "; + } + if (comp_.custom_fraction) { + std::stringstream ss; + ss << std::setprecision(std::numeric_limits::digits10 + 2) + << comp_.fraction; + *os << "relative error of float values <= " << ss.str(); + } + *os << ") "; + } + } + *os << "equal to "; + PrintTo(expected_, os); + } + + std::string DescribeDiff(const TfLiteTensor& actual) const { + if (actual.type != expected_.type) { + return absl::StrCat( + "dtypes don't match: ", TfLiteTypeGetName(actual.type), " vs ", + TfLiteTypeGetName(expected_.type)); + } + if (!actual.dims) return "actual.dims is null."; + if (!expected_.dims) return "expected.dims is null."; + if (actual.dims->size != expected_.dims->size) { + return absl::StrCat("dims don't match: ", actual.dims->size, "D vs ", + expected_.dims->size, "D"); + } + if (int n = actual.dims->size; + std::memcmp(actual.dims->data, expected_.dims->data, n * sizeof(int))) { + return absl::StrCat( + "shapes don't match: ", ::tflite::GetShapeDebugString(actual.dims), + " vs ", ::tflite::GetShapeDebugString(expected_.dims)); + } + if (!actual.data.raw) return "actual.data is null."; + if (!expected_.data.raw) return "expected.data is null."; + if (actual.bytes != expected_.bytes) { + return absl::StrCat("bytes don't match: ", actual.bytes, " vs ", + expected_.bytes); + } + std::string error = "\n"; + TfLiteIntArray* dims = actual.dims; + int n = ::tflite::NumElements(dims); + constexpr int kMaxMismatches = 20; + for (int i = 0, j = 0; i < n; ++i) { + if (!CompareFloat(actual.data.f[i], expected_.data.f[i])) { + absl::StrAppend(&error, "data", TensorIndex(i, dims), + " don't match: ", actual.data.f[i], " vs ", + expected_.data.f[i], "\n"); + ++j; + } + if (j == kMaxMismatches) { + absl::StrAppend(&error, "Too many mismatches; stopping after ", j, + ".\n"); + break; + } + } + return error; + } + + bool Match(const TfLiteTensor& actual) const { + if (actual.type != expected_.type) return false; + if (!actual.dims) return false; + if (!expected_.dims) return false; + if (actual.dims->size != expected_.dims->size) return false; + if (int n = actual.dims->size; + std::memcmp(actual.dims->data, expected_.dims->data, n * sizeof(int))) { + return false; + } + if (!actual.data.raw) return false; + if (!expected_.data.raw) return false; + if (actual.bytes != expected_.bytes) return false; + switch (comp_.float_comp) { + case FloatComparison::kExact: + if (int n = actual.bytes; + std::memcmp(actual.data.raw, expected_.data.raw, n)) { + return false; + } + break; + case FloatComparison::kApproximate: + for (int i = 0, n = ::tflite::NumElements(actual.dims); i < n; ++i) { + if (!CompareFloat(actual.data.f[i], expected_.data.f[i])) { + return false; + } + } + break; + }; + return true; + } + + TensorComparison comp_; + TfLiteTensor expected_; +}; + +} // namespace internal + +inline PolymorphicMatcher EqualsTensor( + const TfLiteTensor& expected) { + internal::TensorComparison comp; + return MakePolymorphicMatcher(internal::TensorMatcher(comp, expected)); +} + +template +inline InnerTensorMatcherT Approximately(InnerTensorMatcherT m) { + m.mutable_impl().SetCompareApproximately(); + return m; +} + +template +inline InnerTensorMatcherT Approximately(InnerTensorMatcherT m, double margin) { + m.mutable_impl().SetCompareApproximately(); + m.mutable_impl().SetMargin(margin); + return m; +} + +template +inline InnerTensorMatcherT Approximately(InnerTensorMatcherT m, double margin, + double fraction) { + m.mutable_impl().SetCompareApproximately(); + m.mutable_impl().SetMargin(margin); + m.mutable_impl().SetFraction(fraction); + return m; +} + +} // namespace tflite +} // namespace testing + +#endif // TENSORFLOW_LITE_TESTING_MATCHERS_H_ diff --git a/tensorflow/lite/testing/matchers_test.cc b/tensorflow/lite/testing/matchers_test.cc new file mode 100644 index 00000000000000..bae6cff1af3a08 --- /dev/null +++ b/tensorflow/lite/testing/matchers_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/testing/matchers.h" + +#include +#include +#include + +#include +#include +#include "absl/base/casts.h" +#include "absl/types/span.h" +#include "tensorflow/lite/core/c/c_api_types.h" +#include "tensorflow/lite/core/c/common.h" + +namespace tflite { +namespace { + +// A wrapper of TfLiteTensor that frees dims at destruction. +struct Tensor : public TfLiteTensor { + template + Tensor(TfLiteType dtype, const std::vector& shape, absl::Span buf) { + type = dtype; + dims = TfLiteIntArrayCreate(shape.size()); + std::memcpy(dims->data, shape.data(), shape.size() * sizeof(int)); + data = {.data = buf.data()}; + bytes = buf.size() * sizeof(T); + } + ~Tensor() { TfLiteIntArrayFree(dims); } +}; + +// Delegate pretty print to PrintTo(TfLiteTensor&). +void PrintTo(const Tensor& tensor, std::ostream* os) { // NOLINT + PrintTo(absl::implicit_cast(tensor), os); +} + +using ::testing::tflite::Approximately; +using ::testing::tflite::EqualsTensor; + +TEST(TensorMatcherTest, ExactlyEqualsSelf) { + float data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(data)); + EXPECT_THAT(a, EqualsTensor(a)); +} + +TEST(TensorMatcherTest, ExactlyEqualsSame) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.71828f, 3.14159f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, EqualsTensor(b)); +} + +TEST(TensorMatcherTest, DoesNotExactlyEqualDifferentType) { + float data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(data)); + Tensor b(TfLiteType::kTfLiteInt32, {1, 2}, absl::MakeSpan(data)); + EXPECT_THAT(a, Not(EqualsTensor(b))); +} + +TEST(TensorMatcherTest, DoesNotExactlyEqualDifferentDims) { + float data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(data)); + Tensor b(TfLiteType::kTfLiteFloat32, {2, 1}, absl::MakeSpan(data)); + EXPECT_THAT(a, Not(EqualsTensor(b))); +} + +TEST(TensorMatcherTest, DoesNotExactlyEqualDifferentData) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {3.14159f, 2.71828f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Not(EqualsTensor(b))); +} + +TEST(TensorMatcherTest, ApproximatelyEqualsDefaultMargin) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.718277f, 3.141593f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Approximately(EqualsTensor(b))); +} + +TEST(TensorMatcherTest, ApproximatelyEqualsWithLooseMargin) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.72f, 3.14f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Approximately(EqualsTensor(b), /*margin=*/0.01)); +} + +TEST(TensorMatcherTest, DoesNotApproximatelyEqualWithTightMargin) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.72f, 3.14f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Not(Approximately(EqualsTensor(b), /*margin=*/0.001))); +} + +TEST(TensorMatcherTest, ApproximatelyEqualsWithLooseFraction) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.72f, 3.14f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT( + a, Approximately(EqualsTensor(b), /*margin=*/0.0, /*fraction=*/0.999)); +} + +TEST(TensorMatcherTest, DoesNotApproximatelyEqualWithTightFraction) { + float a_data[] = {2.71828f, 3.14159f}; + Tensor a(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(a_data)); + float b_data[] = {2.72f, 3.14f}; + Tensor b(TfLiteType::kTfLiteFloat32, {1, 2}, absl::MakeSpan(b_data)); + EXPECT_THAT(a, Not(Approximately(EqualsTensor(b), /*margin=*/0.0, + /*fraction=*/0.0001))); +} + +} // namespace +} // namespace tflite diff --git a/tensorflow/lite/testing/op_tests/is_finite.py b/tensorflow/lite/testing/op_tests/is_finite.py index 2425fa9686d475..493ea059b2a024 100644 --- a/tensorflow/lite/testing/op_tests/is_finite.py +++ b/tensorflow/lite/testing/op_tests/is_finite.py @@ -52,7 +52,7 @@ def random_index(shape): input_values[random_index(input_values.shape)] = np.inf input_values[random_index(input_values.shape)] = -np.inf - input_values[random_index(input_values.shape)] = np.NAN + input_values[random_index(input_values.shape)] = np.nan input_values[random_index(input_values.shape)] = tf.float32.max input_values[random_index(input_values.shape)] = tf.float32.min diff --git a/tensorflow/lite/toco/BUILD b/tensorflow/lite/toco/BUILD index 4f370bb9d1cb5e..8c9ccf7c225e7d 100644 --- a/tensorflow/lite/toco/BUILD +++ b/tensorflow/lite/toco/BUILD @@ -302,7 +302,6 @@ cc_library( "//tensorflow/lite/kernels/internal:strided_slice_logic", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc index 89da8a6888b9eb..d6932b73138c94 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_expanddims_to_reshape.cc @@ -17,11 +17,14 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc index cc519e4f559647..6d2b5ca4c4a582 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_diag_v2_or_v3_to_v1.cc @@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc index 66d7f64d7b7483..84e84aabce74d3 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_matrix_set_diag_v2_or_v3_to_v1.cc @@ -17,9 +17,9 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" +#include "absl/status/status.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc index 8f56bfa36794f7..b7763e1ff98fe3 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc @@ -17,10 +17,11 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc index 49c380b3564d27..60dcf00f8d5693 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_reorder_axes.cc @@ -17,11 +17,12 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc index c3bfbf5369ef0c..c98d64d389aacb 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_squeeze_to_reshape.cc @@ -17,11 +17,12 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc index 547e0d805757b6..c60ddff8a9284f 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_addn_to_add.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc index f493d4ec9c6ae7..c945615c1fb319 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_pack_to_reshape.cc @@ -17,11 +17,12 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc index 4781f4ef88e780..71a7d92d2e2b0e 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_tile_to_concat.cc @@ -14,10 +14,12 @@ limitations under the License. ==============================================================================*/ #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc index 183cb536eefe1d..8a33ad575bcf12 100644 --- a/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/convert_trivial_transpose_to_reshape.cc @@ -15,7 +15,9 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc index f69afe4b237730..380cdf216efb70 100644 --- a/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc +++ b/tensorflow/lite/toco/graph_transformations/create_im2col_arrays.cc @@ -17,11 +17,12 @@ limitations under the License. #include #include -#include "absl/strings/str_cat.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/dequantize.cc b/tensorflow/lite/toco/graph_transformations/dequantize.cc index 1aa50692e6748b..5dd4d2e8750377 100644 --- a/tensorflow/lite/toco/graph_transformations/dequantize.cc +++ b/tensorflow/lite/toco/graph_transformations/dequantize.cc @@ -17,11 +17,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc index 0a7af2f78e72ff..cdd748ac371075 100644 --- a/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/drop_fake_quant.cc @@ -17,11 +17,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc index a0768141def7df..d3cfae07faebbd 100644 --- a/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc +++ b/tensorflow/lite/toco/graph_transformations/drop_im2col_arrays.cc @@ -12,10 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc index 22d6d94cde2f95..f8d639cc396e25 100644 --- a/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc +++ b/tensorflow/lite/toco/graph_transformations/ensure_bias_vectors.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc index b496f5111ae239..ed3a89a70123ad 100644 --- a/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc +++ b/tensorflow/lite/toco/graph_transformations/ensure_uint8_weights_safe_for_fast_int8_kernels.cc @@ -17,10 +17,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc index 7d342702594e2c..64b91ccf62878a 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -17,11 +17,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc index 926d41ca6bbe1e..3afa9c44a59e5c 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -18,11 +18,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc index 130cb0b3dd4bb4..fa0baf97dbd9c5 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/runtime/types.h" diff --git a/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc index df1d6daf213f81..ba57090e2eff6a 100644 --- a/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc +++ b/tensorflow/lite/toco/graph_transformations/fuse_broadcast_into_following_binary.cc @@ -17,10 +17,11 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc index bee666531a7573..125e5597a49f35 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc @@ -21,9 +21,12 @@ limitations under the License. #include #include -#include "tensorflow/lite/toco/toco_port.h" -#include "tensorflow/lite/toco/tooling_util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/lite/toco/format_port.h" +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h index 9f93ee1b36eb28..c7e2c9de186f97 100644 --- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/toco_port.h" diff --git a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc index f5a8d161e926a2..2da6fbe6cfe76f 100644 --- a/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc +++ b/tensorflow/lite/toco/graph_transformations/group_bidirectional_sequence_ops.cc @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" diff --git a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc index 53c12b476a8c0e..6f142a447f60d8 100644 --- a/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/lite/toco/graph_transformations/hardcode_min_max.cc @@ -17,7 +17,9 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc index 026f51ab144a9b..985e588072136e 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -15,10 +15,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc b/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc index 4a6dea0c487c42..437147f8b55d81 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_hardswish.cc @@ -18,7 +18,8 @@ limitations under the License. #include #include -#include "tensorflow/core/platform/logging.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/identify_util.h" #include "tensorflow/lite/toco/model.h" diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc index b66f0b02b2ab12..e8a5d209d64a6f 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc index 91bda7ef825cf0..a980995a870280 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_l2_pool.cc @@ -17,10 +17,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc index 18e74ae6270bc6..df0aa9ff3ddba7 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_lstm.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm.cc @@ -16,8 +16,12 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc index 4b2c49757aef65..24299d557551c8 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_merge_inputs.cc @@ -18,8 +18,9 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" -#include "absl/strings/string_view.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" #include "tensorflow/lite/toco/model.h" diff --git a/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc index 3de0a7198ee49d..aea6d93d00a04a 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_lstm_split_inputs.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include "absl/memory/memory.h" -#include "absl/strings/string_view.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" #include "tensorflow/lite/toco/model.h" diff --git a/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc b/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc index 580b680fa6800a..1d1d67bd253a75 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_nearest_upsample.cc @@ -17,8 +17,9 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc index 31edcb47a6d516..0f28cb1cd26ef6 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_prelu.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_prelu.cc @@ -17,10 +17,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" // This transformation rule tries to identify the PRelu structure generated by // Keras, and convert it to a single op. diff --git a/tensorflow/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc index dad425c5809526..6f2e22439f7e44 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_relu1.cc @@ -17,7 +17,8 @@ limitations under the License. #include #include -#include "tensorflow/core/platform/logging.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/identify_util.h" #include "tensorflow/lite/toco/model.h" diff --git a/tensorflow/lite/toco/graph_transformations/identify_util.cc b/tensorflow/lite/toco/graph_transformations/identify_util.cc index e8605114529139..6ed8e33152b67e 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_util.cc +++ b/tensorflow/lite/toco/graph_transformations/identify_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/identify_util.h b/tensorflow/lite/toco/graph_transformations/identify_util.h index 1a79231ff012c7..6c59b0b03cb326 100644 --- a/tensorflow/lite/toco/graph_transformations/identify_util.h +++ b/tensorflow/lite/toco/graph_transformations/identify_util.h @@ -17,6 +17,7 @@ limitations under the License. #include #include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/runtime/types.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/lstm_utils.cc b/tensorflow/lite/toco/graph_transformations/lstm_utils.cc index 7a979b73f5b157..676aa752fc7a3d 100644 --- a/tensorflow/lite/toco/graph_transformations/lstm_utils.cc +++ b/tensorflow/lite/toco/graph_transformations/lstm_utils.cc @@ -16,6 +16,9 @@ limitations under the License. #include +#include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/tooling_util.h" + namespace toco { void CreateOptionalArray(Model* model, std::string* input_array_buffer, diff --git a/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc index 290dc7fe457323..0726b32632668f 100644 --- a/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc +++ b/tensorflow/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -17,12 +17,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/quantization_util.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc index b07815ea350b18..a292b97f002010 100644 --- a/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc @@ -18,12 +18,14 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/runtime/types.h" +#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc index 85384139b638cc..588a03445d4df8 100644 --- a/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc +++ b/tensorflow/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc @@ -15,6 +15,9 @@ #include #include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc index 79d8229da8ea2e..fffdde0a571cf9 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc @@ -17,12 +17,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/remove_trivial_passthrough.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/runtime/types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc index af801c3cbf1d63..ef0a5205bd867a 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -17,9 +17,11 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc index 0f9197cd485ea5..54b76fb89bbbda 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_default_min_max.cc @@ -17,11 +17,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/quantization_util.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 10968a93211ece..62d8715b808491 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -17,11 +17,13 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/quantization_util.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc index ab6f40765d8854..5136bc0012a8af 100644 --- a/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -21,11 +21,14 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/strings/str_join.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/kernels/internal/strided_slice_logic.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" +#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc index 36d6819ecf31e0..9e5e58017afd00 100644 --- a/tensorflow/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/lite/toco/graph_transformations/quantize.cc @@ -21,11 +21,15 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/graph_transformations/quantization_util.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_flags.pb.h" +#include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" namespace toco { @@ -47,6 +51,7 @@ bool SupportsQuantization(Model* model, const Operator& op) { static const std::set supported_ops{ OperatorType::kAdd, OperatorType::kArgMax, + OperatorType::kArgMin, OperatorType::kAveragePool, OperatorType::kBatchToSpaceND, OperatorType::kConcatenation, diff --git a/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc index 5e867eab33e8d4..bf9334f2a86793 100644 --- a/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc +++ b/tensorflow/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" -#include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc index 438c7a63aba1e8..fc15e8ed7cd406 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_final_dequantize_op.cc @@ -17,11 +17,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/model_flags.pb.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc index fdc4d274b33b5f..79e6b68c99978a 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_successive_transpose.cc @@ -15,7 +15,8 @@ limitations under the License. #include #include -#include "tensorflow/core/platform/logging.h" +#include "absl/status/status.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" diff --git a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc index 88402f0092e509..45de603fdc20a7 100644 --- a/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc +++ b/tensorflow/lite/toco/graph_transformations/remove_tensorflow_assert.cc @@ -16,10 +16,12 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/lite/toco/model.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/core/platform/logging.h" namespace toco { diff --git a/tensorflow/lite/toco/tflite/BUILD b/tensorflow/lite/toco/tflite/BUILD index 6bd43f8091d32f..7377ec00d6b666 100644 --- a/tensorflow/lite/toco/tflite/BUILD +++ b/tensorflow/lite/toco/tflite/BUILD @@ -112,6 +112,7 @@ cc_library( deps = [ ":operator", ":types", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantize_weights", "//tensorflow/compiler/mlir/lite/schema:schema_conversion_utils", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/platform:status", @@ -122,7 +123,6 @@ cc_library( "//tensorflow/lite/toco:model", "//tensorflow/lite/toco:toco_port", "//tensorflow/lite/toco:tooling_util", - "//tensorflow/lite/tools/optimize:quantize_weights", "//tensorflow/lite/tools/versioning", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", diff --git a/tensorflow/lite/toco/tflite/export.cc b/tensorflow/lite/toco/tflite/export.cc index e9124e89f2a892..44223eac63c130 100644 --- a/tensorflow/lite/toco/tflite/export.cc +++ b/tensorflow/lite/toco/tflite/export.cc @@ -23,6 +23,7 @@ limitations under the License. #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers #include "flatbuffers/string.h" // from @flatbuffers +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_conversion_utils.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" @@ -32,7 +33,6 @@ limitations under the License. #include "tensorflow/lite/toco/tflite/types.h" #include "tensorflow/lite/toco/toco_types.h" #include "tensorflow/lite/toco/tooling_util.h" -#include "tensorflow/lite/tools/optimize/quantize_weights.h" #include "tensorflow/lite/tools/versioning/runtime_version.h" #include "tensorflow/lite/util.h" #include "tensorflow/lite/version.h" @@ -670,19 +670,19 @@ tensorflow::Status Export( flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240); const uint8_t* buffer = builder.GetBufferPointer(); const ::tflite::Model* input_model = ::tflite::GetModel(buffer); - ::tflite::optimize::BufferType quantized_type; + ::mlir::lite::toco_legacy::BufferType quantized_type; if (params.quantize_weights == QuantizedBufferType::INT8) { - quantized_type = ::tflite::optimize::BufferType::QUANTIZED_INT8; + quantized_type = ::mlir::lite::toco_legacy::BufferType::QUANTIZED_INT8; } else if (params.quantize_weights == QuantizedBufferType::FLOAT16) { - quantized_type = ::tflite::optimize::BufferType::QUANTIZED_FLOAT16; + quantized_type = ::mlir::lite::toco_legacy::BufferType::QUANTIZED_FLOAT16; } else { return tensorflow::errors::InvalidArgument( "Quantized type not recognized"); } - if (!::tflite::optimize::QuantizeWeights( + if (!::mlir::lite::toco_legacy::QuantizeWeights( &q_builder, input_model, quantized_type, !params.disable_per_channel, - ::tflite::optimize::QuantizerType::OLD_QUANTIZER) + ::mlir::lite::toco_legacy::QuantizerType::OLD_QUANTIZER) .ok()) { return tensorflow::errors::InvalidArgument( "Quantize weights transformation failed."); diff --git a/tensorflow/lite/toco/toco_flags.proto b/tensorflow/lite/toco/toco_flags.proto index 1760841a333f6a..ac5ed8c3ef6ae2 100644 --- a/tensorflow/lite/toco/toco_flags.proto +++ b/tensorflow/lite/toco/toco_flags.proto @@ -41,7 +41,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 64. +// Next ID to use: 65. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -360,4 +360,16 @@ message TocoFlags { // Enables the attempt to directly lower composites into tflite ops. // WARNING: Experimental interface, subject to change. optional bool enable_composite_direct_lowering = 63 [default = false]; + + // The source model framework. + enum ModelOriginFramework { + UNSET = 0; + TENSORFLOW = 1; + KERAS = 2; + JAX = 3; + PYTORCH = 4; + } + + // The source model type. + optional ModelOriginFramework model_origin_framework = 64 [default = UNSET]; } diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD index b08a2d913b6ec7..b9260be0b9eac3 100644 --- a/tensorflow/lite/tools/BUILD +++ b/tensorflow/lite/tools/BUILD @@ -414,6 +414,19 @@ cc_library( "//tensorflow/lite/c:common", "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels:kernel_util", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + copts = tflite_copts(), + deps = [ + ":utils", + "//tensorflow/lite/c:common", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", ], ) diff --git a/tensorflow/lite/tools/benchmark/BUILD b/tensorflow/lite/tools/benchmark/BUILD index 63aae4ff6b0029..b26dbde5a742d9 100644 --- a/tensorflow/lite/tools/benchmark/BUILD +++ b/tensorflow/lite/tools/benchmark/BUILD @@ -162,6 +162,7 @@ cc_library( ":benchmark_model_lib", ":benchmark_utils", ":profiling_listener", + "//tensorflow/core/example:example_protos_cc_impl", "//tensorflow/lite:framework", "//tensorflow/lite:simple_memory_arena_debug_dump", "//tensorflow/lite:string_util", @@ -180,6 +181,7 @@ cc_library( "//tensorflow/lite/tools/delegates:tflite_execution_providers", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@ruy//ruy/profiler", ], ) diff --git a/tensorflow/lite/tools/benchmark/CMakeLists.txt b/tensorflow/lite/tools/benchmark/CMakeLists.txt index 56794382ff45a8..eb0862f58aea00 100644 --- a/tensorflow/lite/tools/benchmark/CMakeLists.txt +++ b/tensorflow/lite/tools/benchmark/CMakeLists.txt @@ -47,6 +47,8 @@ list(APPEND TFLITE_BENCHMARK_LIBS list(APPEND TFLITE_BENCHMARK_LIBS profiling_info_proto + feature_proto + example_proto protobuf::libprotobuf ) diff --git a/tensorflow/lite/tools/benchmark/README.md b/tensorflow/lite/tools/benchmark/README.md index e92d841b9c6a87..4b2f82fed258d0 100644 --- a/tensorflow/lite/tools/benchmark/README.md +++ b/tensorflow/lite/tools/benchmark/README.md @@ -90,6 +90,15 @@ and the following optional parameters: and the path to include the name of the output CSV; otherwise results are printed to `stdout`. +* `output_filepath`: `str` (default="") \ + File path to save output tensor data to. If specified, the output tensor + values are saved as binary data in the file. + +* `output_proto_filepath`: `str` (default="") \ + File path to save output tensor data as tensorflow example proto. If + specified, the output tensor values are saved in tensorflow example and then + serialized to the file. + * `print_preinvoke_state`: `bool` (default=false) \ Whether to print out the TfLite interpreter internals just before calling tflite::Interpreter::Invoke. The internals will include allocated memory diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc index 8fb5b23b7860d9..15a18a6f5c4196 100644 --- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc +++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc @@ -27,6 +27,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -36,7 +37,10 @@ limitations under the License. #include "absl/strings/str_replace.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "ruy/profiler/profiler.h" // from @ruy +#include "tensorflow/core/example/example.pb.h" +#include "tensorflow/core/example/feature.pb.h" #include "tensorflow/lite/core/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/core/kernels/register.h" @@ -87,6 +91,49 @@ const char* kOpProfilingOutputModes[] = {kOpProfilingOutputModeStdout, kOpProfilingOutputModeCsv, kOpProfilingOutputModeProto}; +// Sets feature values in the tensorflow::Example proto from the tflite tensor. +// Returns an error if the tensor type is not supported or the tensor dime is a +// nullptr. +TfLiteStatus MaybeSetFeatureValuesFromTensor(const TfLiteTensor& tensor, + tensorflow::Example& example) { + if (tensor.dims == nullptr) { + return kTfLiteError; + } + + int total_elements = 1; + for (int i = 0; i < tensor.dims->size; i++) { + total_elements *= tensor.dims->data[i]; + } + tensorflow::Feature& feature = + (*example.mutable_features()->mutable_feature())[tensor.name]; + switch (tensor.type) { + case kTfLiteFloat32: + case kTfLiteFloat64: + feature.mutable_float_list()->mutable_value()->Resize(total_elements, 0); + return utils::TfLiteTensorToFloat32Array( + tensor, + absl::MakeSpan( + feature.mutable_float_list()->mutable_value()->mutable_data(), + feature.float_list().value_size())); + case kTfLiteUInt8: + case kTfLiteInt8: + case kTfLiteUInt16: + case kTfLiteInt16: + case kTfLiteInt32: + case kTfLiteUInt32: + case kTfLiteUInt64: + case kTfLiteInt64: + feature.mutable_int64_list()->mutable_value()->Resize(total_elements, 0); + return utils::TfLiteTensorToInt64Array( + tensor, + absl::MakeSpan( + feature.mutable_int64_list()->mutable_value()->mutable_data(), + feature.int64_list().value_size())); + default: + return kTfLiteError; + } +} + // Dumps ruy profiling events if the ruy profiler is enabled. class RuyProfileListener : public BenchmarkListener { public: @@ -153,17 +200,37 @@ class OutputSaver : public BenchmarkListener { } void OnBenchmarkEnd(const BenchmarkResults& results) override { - std::string path = params_->Get("output_filepath"); - if (path.empty()) return; + // If the output_filepath is specified, save the output tensors to the file. + const std::string path = params_->Get("output_filepath"); + if (!path.empty()) { + std::ofstream ofs(path, std::ofstream::out); + if (ofs.good()) { + for (int i = 0; i < interpreter_runner_->outputs().size(); i++) { + int tensor_index = interpreter_runner_->outputs()[i]; + ofs.write(interpreter_runner_->tensor(tensor_index)->data.raw, + interpreter_runner_->tensor(tensor_index)->bytes); + } + ofs.close(); + } + } - std::ofstream ofs(path, std::ofstream::out); - if (ofs.good()) { + // If the output_proto_filepath is specified, save the output tensors as + // tensorflow::Example proto and serialize it to the file. + const std::string output_proto_path = + params_->Get("output_proto_filepath"); + if (!output_proto_path.empty()) { + tensorflow::Example example; for (int i = 0; i < interpreter_runner_->outputs().size(); i++) { - int tensor_index = interpreter_runner_->outputs()[i]; - ofs.write(interpreter_runner_->tensor(tensor_index)->data.raw, - interpreter_runner_->tensor(tensor_index)->bytes); + const int tensor_index = interpreter_runner_->outputs()[i]; + const TfLiteTensor& tensor = + *(interpreter_runner_->tensor(tensor_index)); + MaybeSetFeatureValuesFromTensor(tensor, example); + } + std::ofstream ofs(output_proto_path, std::ios::out | std::ios::binary); + if (ofs.good()) { + example.SerializeToOstream(&ofs); + ofs.close(); } - ofs.close(); } } @@ -518,6 +585,8 @@ BenchmarkParams BenchmarkTfLiteModel::DefaultParams() { BenchmarkParam::Create(false)); default_params.AddParam("output_filepath", BenchmarkParam::Create("")); + default_params.AddParam("output_proto_filepath", + BenchmarkParam::Create("")); default_params.AddParam("tensor_name_display_length", BenchmarkParam::Create(25)); @@ -622,6 +691,9 @@ std::vector BenchmarkTfLiteModel::GetFlags() { CreateFlag( "output_filepath", ¶ms_, "File path to export outputs layer as binary data."), + CreateFlag( + "output_proto_filepath", ¶ms_, + "File path to export outputs layer as tf example proto."), CreateFlag( "tensor_name_display_length", ¶ms_, "The number of characters to show for the tensor's name when " @@ -700,6 +772,9 @@ void BenchmarkTfLiteModel::LogParams() { "Constant CAST output cache", verbose); LOG_BENCHMARK_PARAM(std::string, "output_filepath", "File path to export outputs layer to", verbose); + LOG_BENCHMARK_PARAM(std::string, "output_proto_filepath", + "File path to export outputs layer as tf example to", + verbose); LOG_BENCHMARK_PARAM(int32_t, "tensor_name_display_length", "Tensor name display length", verbose); LOG_BENCHMARK_PARAM(int32_t, "tensor_type_display_length", diff --git a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake index 0e24bf67849706..71bfa0de5fa4a3 100644 --- a/tensorflow/lite/tools/cmake/modules/xnnpack.cmake +++ b/tensorflow/lite/tools/cmake/modules/xnnpack.cmake @@ -23,7 +23,7 @@ OverridableFetchContent_Declare( xnnpack GIT_REPOSITORY https://github.com/google/XNNPACK # Sync with tensorflow/workspace2.bzl - GIT_TAG 488a695e3a10269755895da05c2711aadf08489b + GIT_TAG 9ddeb74f9f6866174d61888947e4aa9ffe963b1b GIT_PROGRESS TRUE PREFIX "${CMAKE_BINARY_DIR}" SOURCE_DIR "${CMAKE_BINARY_DIR}/xnnpack" diff --git a/tensorflow/lite/tools/optimize/BUILD b/tensorflow/lite/tools/optimize/BUILD index a88f69c8e541c4..bca1714ee6e93e 100644 --- a/tensorflow/lite/tools/optimize/BUILD +++ b/tensorflow/lite/tools/optimize/BUILD @@ -175,6 +175,7 @@ cc_library( ":model_utils", "//tensorflow/lite:framework", "//tensorflow/lite:minimal_logging", + "//tensorflow/lite/c:c_api_types", "//tensorflow/lite/core/api", "//tensorflow/lite/core/c:common", "//tensorflow/lite/kernels/internal:cppmath", @@ -182,7 +183,6 @@ cc_library( "//tensorflow/lite/kernels/internal:tensor_utils", "//tensorflow/lite/kernels/internal:types", "//tensorflow/lite/schema:schema_fbs", - "@com_google_absl//absl/memory", "@eigen_archive//:eigen3", ], ) diff --git a/tensorflow/lite/tools/optimize/modify_model_interface.cc b/tensorflow/lite/tools/optimize/modify_model_interface.cc index 2af2324d70c793..860fa05b99f17c 100644 --- a/tensorflow/lite/tools/optimize/modify_model_interface.cc +++ b/tensorflow/lite/tools/optimize/modify_model_interface.cc @@ -248,7 +248,7 @@ TfLiteStatus SetOutputTypeToUINT8(ModelT* model, TfLiteStatus RemoveInputTensor(ModelT* model, const std::vector& inputs, - int32 original_number_tensors) { + int32_t original_number_tensors) { // Consistency check to make sure that erase start from the end. int last_op_index = std::numeric_limits::max(); int last_tensor_index = std::numeric_limits::max(); @@ -274,7 +274,7 @@ TfLiteStatus RemoveInputTensor(ModelT* model, TfLiteStatus RemoveOutputTensor(ModelT* model, const std::vector& outputs, - int32 original_number_tensors) { + int32_t original_number_tensors) { // Consistency check to make sure that erase start from the end. int last_op_index = std::numeric_limits::max(); int last_tensor_index = std::numeric_limits::max(); @@ -298,7 +298,6 @@ TfLiteStatus RemoveOutputTensor(ModelT* model, return kTfLiteOk; } - int GetOriginalNumberOfTensors(const TensorType& input_type, const TensorType& output_type, ModelT* model, ErrorReporter* error_reporter) { diff --git a/tensorflow/lite/tools/optimize/quantization_utils.cc b/tensorflow/lite/tools/optimize/quantization_utils.cc index 996483be758fdd..e1e6daa43ff1d0 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.cc +++ b/tensorflow/lite/tools/optimize/quantization_utils.cc @@ -16,19 +16,21 @@ limitations under the License. #include #include +#include #include -#include +#include #include -#include +#include +#include -#include "absl/memory/memory.h" #include "Eigen/Core" // from @eigen_archive #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/cppmath.h" +#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/tensor_utils.h" -#include "tensorflow/lite/kernels/internal/types.h" +#include "tensorflow/lite/kernels/internal/runtime_shape.h" +#include "tensorflow/lite/logger.h" #include "tensorflow/lite/minimal_logging.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/tools/optimize/model_utils.h" @@ -39,16 +41,17 @@ namespace utils { namespace { +// LINT.IfChange(QuantizationUtilsConstants) const int8_t kMinQuantizedValue8bit = -127; const int8_t kMaxQuantizedValue8bit = 127; - const int8_t kMinQuantizedValue4bit = -7; const int8_t kMaxQuantizedValue4bit = 7; - // The maximum number of dimensions supported in per-channel quantization. constexpr int kPerChannelMaxDim = 4; +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:QuantizationUtilsConstants) } // namespace +// LINT.IfChange(NumElements) TfLiteStatus NumElements(const TensorT& tensor, uint64_t* num_elements) { *num_elements = 1; for (const int64_t dim : tensor.shape) { @@ -59,6 +62,7 @@ TfLiteStatus NumElements(const TensorT& tensor, uint64_t* num_elements) { } return kTfLiteOk; } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:NumElements) // Nudge min and max so that floating point 0 falls exactly on a quantized // value, returning the nudges scale and zero_point. @@ -139,6 +143,7 @@ void FillSingleMinMax(const float* const input, const uint64_t input_size, quantization_params->max.assign(1, *minmax.second); } +// LINT.IfChange(FillPerChannelMinMax) TfLiteStatus FillPerChannelMinMax(const float* const input, const std::vector& dimension, int32_t channel_dim_index, @@ -202,6 +207,7 @@ TfLiteStatus FillPerChannelMinMax(const float* const input, } return kTfLiteOk; } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:FillPerChannelMinMax) // Populates the scales vector based on max and min values of quant_params TfLiteStatus GetSymmetricScalesFromMaxMin(QuantizationParametersT* quant_params, @@ -300,6 +306,7 @@ TfLiteStatus AdjustWeightsForBiasScale(QuantizationParametersT* quant_params, return kTfLiteOk; } +// LINT.IfChange(SymmetricPerChannelQuantization) // Per-channel quantize a tensor at the given index and fills both scales and // quantized values. TfLiteStatus SymmetricPerChannelQuantization(TensorT* tensor, @@ -343,6 +350,7 @@ TfLiteStatus SymmetricPerChannelQuantization(TensorT* tensor, channel_dim_index, output_value); return kTfLiteOk; } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:SymmetricPerChannelQuantization) std::vector SymmetricQuantizeFloatsToInt16(const float* data, uint64_t num_elements, @@ -381,6 +389,7 @@ TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor, error_reporter); } +// LINT.IfChange(SymmetricPerChannelQuantizeValues) void SymmetricPerChannelQuantizeValues(const float* const input, const std::vector& scales_inv, const std::vector& dimension, @@ -417,6 +426,7 @@ void SymmetricPerChannelQuantizeValues(const float* const input, } } } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:SymmetricPerChannelQuantizeValues) // Quantize the tensor using the max and min values recorded in its quantization // parameters. Applies per-layer quantization. @@ -470,6 +480,7 @@ TfLiteStatus SymmetricQuantizeTensorFromMinMax(ModelT* model, TensorT* tensor, return kTfLiteOk; } +// LINT.IfChange(SymmetricQuantizeTensor) TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) { if (model == nullptr || tensor == nullptr) { TFLITE_LOG(TFLITE_LOG_ERROR, "No tensor to quantize."); @@ -508,7 +519,9 @@ TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) { return kTfLiteOk; } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:SymmetricQuantizeTensor) +// LINT.IfChange(QuantizeTensorFloat16) TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor) { if (model == nullptr || tensor == nullptr) { TFLITE_LOG(TFLITE_LOG_ERROR, "No tensor to quantize."); @@ -551,7 +564,9 @@ TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor) { return kTfLiteOk; } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:QuantizeTensorFloat16) +// LINT.IfChange(AddQuantizationParams) TfLiteStatus AddQuantizationParams(const std::vector& scales, const std::vector& zero_point, int quantized_dimension, @@ -579,7 +594,9 @@ TfLiteStatus AddQuantizationParams(const std::vector& scales, tensor->type = output_type; return kTfLiteOk; } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:AddQuantizationParams) +// LINT.IfChange(SymmetricQuantizeTensorPerChannel) TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor, int32_t channel_dim_index, ErrorReporter* error_reporter) { @@ -619,6 +636,7 @@ TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor, uint8_buffer, buffer_size, TensorType_INT8, model, tensor, error_reporter); } +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.cc:SymmetricQuantizeTensorPerChannel) template std::vector SymmetricBiasQuantize(const float* data, diff --git a/tensorflow/lite/tools/optimize/quantization_utils.h b/tensorflow/lite/tools/optimize/quantization_utils.h index dc6ba7d7299042..7f58971a5350c1 100644 --- a/tensorflow/lite/tools/optimize/quantization_utils.h +++ b/tensorflow/lite/tools/optimize/quantization_utils.h @@ -15,10 +15,11 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZATION_UTILS_H_ #define TENSORFLOW_LITE_TOOLS_OPTIMIZE_QUANTIZATION_UTILS_H_ +#include #include #include -#include "tensorflow/lite/context.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/api/error_reporter.h" #include "tensorflow/lite/schema/schema_generated.h" @@ -26,8 +27,10 @@ namespace tflite { namespace optimize { namespace utils { +// LINT.IfChange(num_elements) // Returns the number of elements in the given tensor. TfLiteStatus NumElements(const TensorT& tensor, uint64_t* num_elements); +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h:num_elements) // Populates the scale and zero point for quantization parameters. // @@ -41,13 +44,16 @@ void GetAsymmetricQuantizationParams( void FillSingleMinMax(const float* const input, const uint64_t input_size, QuantizationParametersT* quantization_params); +// LINT.IfChange(fill_per_channel_min_max) // Populates the max and min values for per channel quantization. TfLiteStatus FillPerChannelMinMax(const float* const input, const std::vector& dimension, int32_t channel_dim_index, QuantizationParametersT* quantization_params, ErrorReporter* error_reporter); +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h:fill_per_channel_min_max) +// LINT.IfChange(symmetric_per_channel_quantization) // Per-channel quantize a tensor at the given index and returns both scales and // quantized values. // Parameters: @@ -66,7 +72,9 @@ TfLiteStatus SymmetricPerChannelQuantization(TensorT* tensor, std::vector* output_scales, std::vector* output_value, ErrorReporter* error_reporter); +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h:symmetric_per_channel_quantization) +// LINT.IfChange(symmetric_per_channel_quantize_values) // Quantize the values given an array of scales. void SymmetricPerChannelQuantizeValues(const float* const input, const std::vector& scales_inv, @@ -74,14 +82,20 @@ void SymmetricPerChannelQuantizeValues(const float* const input, int32_t channel_dim_index, std::vector* output_value, TfLiteType type = kTfLiteNoType); +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h:symmetric_per_channel_quantize_values) +// LINT.IfChange(symmetric_quantize_tensor) // Quantizes tensor using symmetric quantization with the min and max elements // of the tensor. TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor); +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h:symmetric_quantize_tensor) +// LINT.IfChange(quantize_tensor_float16) // Quantizes tensor to float16. TfLiteStatus QuantizeTensorFloat16(ModelT* model, TensorT* tensor); +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h:quantize_tensor_float16) +// LINT.IfChange(add_quantization_params) // Add quantization parameters. TfLiteStatus AddQuantizationParams(const std::vector& scales, const std::vector& zero_point, @@ -90,6 +104,7 @@ TfLiteStatus AddQuantizationParams(const std::vector& scales, size_t buffer_size, TensorType output_type, ModelT* model, TensorT* tensor, ErrorReporter* error_reporter); +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h:add_quantization_params) // Populates the scales vector based on max and min values of quant_params TfLiteStatus GetSymmetricScalesFromMaxMin(QuantizationParametersT* quant_params, @@ -104,10 +119,12 @@ TfLiteStatus AdjustWeightsForBiasScale(QuantizationParametersT* quant_params, const float input_scale, ErrorReporter* error_reporter); +// LINT.IfChange(symmetric_quantize_tensor_per_channel) // Quantizes tensor with per channel. TfLiteStatus SymmetricQuantizeTensorPerChannel(ModelT* model, TensorT* tensor, int32_t channel_dim_index, ErrorReporter* error_reporter); +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantization_utils.h:symmetric_quantize_tensor_per_channel) // Symmetrically quantizes float to 16bits. TfLiteStatus SymmetricQuantizeFloatsToInt16(ModelT* model, TensorT* tensor, diff --git a/tensorflow/lite/tools/utils.cc b/tensorflow/lite/tools/utils.cc index 12396ed7c3ce05..b8c18b24c5cce6 100644 --- a/tensorflow/lite/tools/utils.cc +++ b/tensorflow/lite/tools/utils.cc @@ -17,8 +17,10 @@ limitations under the License. #include #include +#include #include +#include "absl/types/span.h" #include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" #include "tensorflow/lite/kernels/kernel_util.h" @@ -55,6 +57,30 @@ inline InputTensorData CreateInputTensorData(int num_elements, return tmp; } +// Converts a TfLiteTensor to a float array. Returns an error if the tensor +// dimension is a null pointer. +template +TfLiteStatus ConvertToArray(const TfLiteTensor& tflite_tensor, + absl::Span& values) { + if (tflite_tensor.dims == nullptr) { + return kTfLiteError; + } + + int total_elements = 1; + for (int i = 0; i < tflite_tensor.dims->size; i++) { + total_elements *= tflite_tensor.dims->data[i]; + } + if (total_elements != values.size()) { + return kTfLiteError; + } + const TensorType* tensor_data = + reinterpret_cast(tflite_tensor.data.data); + for (int i = 0; i < total_elements; i++) { + values[i] = static_cast(tensor_data[i]); + } + return kTfLiteOk; +} + } // namespace InputTensorData CreateRandomTensorData(const TfLiteTensor& tensor, @@ -168,5 +194,41 @@ void GetDataRangesForType(TfLiteType type, float* low_range, } } +TfLiteStatus TfLiteTensorToFloat32Array(const TfLiteTensor& tensor, + absl::Span values) { + switch (tensor.type) { + case kTfLiteFloat32: + return ConvertToArray(tensor, values); + case kTfLiteFloat64: + return ConvertToArray(tensor, values); + default: + return kTfLiteError; + } +} + +TfLiteStatus TfLiteTensorToInt64Array(const TfLiteTensor& tensor, + absl::Span values) { + switch (tensor.type) { + case kTfLiteUInt8: + return ConvertToArray(tensor, values); + case kTfLiteInt8: + return ConvertToArray(tensor, values); + case kTfLiteUInt16: + return ConvertToArray(tensor, values); + case kTfLiteInt16: + return ConvertToArray(tensor, values); + case kTfLiteInt32: + return ConvertToArray(tensor, values); + case kTfLiteUInt32: + return ConvertToArray(tensor, values); + case kTfLiteUInt64: + return ConvertToArray(tensor, values); + case kTfLiteInt64: + return ConvertToArray(tensor, values); + default: + return kTfLiteError; + } +} + } // namespace utils } // namespace tflite diff --git a/tensorflow/lite/tools/utils.h b/tensorflow/lite/tools/utils.h index 2fc9c62de119d2..12d69e29dc2dd6 100644 --- a/tensorflow/lite/tools/utils.h +++ b/tensorflow/lite/tools/utils.h @@ -16,8 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_LITE_TOOLS_UTILS_H_ #define TENSORFLOW_LITE_TOOLS_UTILS_H_ +#include #include +#include +#include "absl/types/span.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/core/c/common.h" namespace tflite { @@ -43,6 +47,14 @@ InputTensorData CreateRandomTensorData(const TfLiteTensor& tensor, // benchmarking and/or testing purposes. void GetDataRangesForType(TfLiteType type, float* low_range, float* high_range); +// Converts TfLiteTensor to float array. Returns an error if the tensor type is +// not supported or the values size is not equal to the tensor dimension. +TfLiteStatus TfLiteTensorToFloat32Array(const TfLiteTensor& tensor, + absl::Span values); + +// Same as above, but converts to int64_t array. +TfLiteStatus TfLiteTensorToInt64Array(const TfLiteTensor& tensor, + absl::Span values); } // namespace utils } // namespace tflite diff --git a/tensorflow/lite/tools/utils_test.cc b/tensorflow/lite/tools/utils_test.cc new file mode 100644 index 00000000000000..ce519827aaf12f --- /dev/null +++ b/tensorflow/lite/tools/utils_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/lite/tools/utils.h" + +#include + +#include +#include + +#include +#include +#include "absl/types/span.h" +#include "tensorflow/lite/c/common.h" + +namespace tflite::tools { +namespace { +using ::testing::FloatEq; + +// Helper function to test TfLiteTensorToFloat32Array. +template +void TestTfLiteTensorToFloat32Array(TfLiteType type) { + T data[] = {1, 2, 3, 4}; + TfLiteTensor tensor; + tensor.data.data = data; + tensor.type = type; + // Create an int array with 1 dimension and the array size is 4. + tensor.dims = TfLiteIntArrayCreate(1); + tensor.dims->data[0] = 4; + std::vector result(4, 0.0); + const auto status = + utils::TfLiteTensorToFloat32Array(tensor, absl::MakeSpan(result)); + TfLiteIntArrayFree(tensor.dims); + ASSERT_EQ(status, kTfLiteOk); + ASSERT_EQ(result.size(), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_THAT(result[i], FloatEq(static_cast(data[i]))); + } +} + +// Helper function to test TfLiteTensorToFloat32Array. +template +void TestTfLiteTensorToInt64Array(TfLiteType type) { + T data[] = {1, 2, 3, 4}; + TfLiteTensor tensor; + tensor.data.data = data; + tensor.type = type; + // Create an int array with 1 dimension and the array size is 4. + tensor.dims = TfLiteIntArrayCreate(1); + tensor.dims->data[0] = 4; + std::vector result(4, 0); + const auto status = + utils::TfLiteTensorToInt64Array(tensor, absl::MakeSpan(result)); + TfLiteIntArrayFree(tensor.dims); + ASSERT_EQ(status, kTfLiteOk); + ASSERT_EQ(result.size(), 4); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(result[i], static_cast(data[i])); + } +} + +// Tests TfLiteTensorToFloat32Array for supported TfLiteTypes. +TEST(Utils, TfLiteTensorToFloat32Array) { + TestTfLiteTensorToFloat32Array(kTfLiteFloat32); + TestTfLiteTensorToFloat32Array(kTfLiteFloat64); +} + +TEST(Utils, TfLiteTensorToInt64Array) { + TestTfLiteTensorToInt64Array(kTfLiteInt8); + TestTfLiteTensorToInt64Array(kTfLiteUInt8); + TestTfLiteTensorToInt64Array(kTfLiteInt16); + TestTfLiteTensorToInt64Array(kTfLiteUInt16); + TestTfLiteTensorToInt64Array(kTfLiteInt32); + TestTfLiteTensorToInt64Array(kTfLiteUInt32); + TestTfLiteTensorToInt64Array(kTfLiteInt64); + TestTfLiteTensorToInt64Array(kTfLiteUInt64); +} + +} // namespace +} // namespace tflite::tools diff --git a/tensorflow/lite/tools/versioning/gpu_compatibility.cc b/tensorflow/lite/tools/versioning/gpu_compatibility.cc index dd8658bda26a28..061eaca7a3c05b 100644 --- a/tensorflow/lite/tools/versioning/gpu_compatibility.cc +++ b/tensorflow/lite/tools/versioning/gpu_compatibility.cc @@ -1085,7 +1085,8 @@ absl::Status CheckGpuDelegateCompatibility(const OpSignature& op_sig, /*required_const_inputs=*/0, /*required_outputs=*/1)); - // Two arguments elemenetwise operations + // Two arguments elementwise operations + case kTfLiteBuiltinAtan2: case kTfLiteBuiltinDiv: case kTfLiteBuiltinEqual: case kTfLiteBuiltinFloorDiv: diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files index ab15889a196aad..8e09aa303c21a4 100644 --- a/tensorflow/opensource_only.files +++ b/tensorflow/opensource_only.files @@ -235,6 +235,7 @@ tf_staging/third_party/googleapis/build_rules.bzl: tf_staging/third_party/googleapis/googleapis.BUILD: tf_staging/third_party/googleapis/repository_rules.bzl: tf_staging/third_party/gpus/BUILD: +tf_staging/third_party/gpus/compiler_common_tools.bzl: tf_staging/third_party/gpus/crosstool/BUILD.rocm.tpl: tf_staging/third_party/gpus/crosstool/BUILD.sycl.tpl: tf_staging/third_party/gpus/crosstool/BUILD.tpl: @@ -252,6 +253,27 @@ tf_staging/third_party/gpus/cuda/LICENSE: tf_staging/third_party/gpus/cuda/build_defs.bzl.tpl: tf_staging/third_party/gpus/cuda/cuda_config.h.tpl: tf_staging/third_party/gpus/cuda/cuda_config.py.tpl: +tf_staging/third_party/gpus/cuda/hermetic/BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/BUILD: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_configure.bzl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl: +tf_staging/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl: tf_staging/third_party/gpus/cuda_configure.bzl: tf_staging/third_party/gpus/find_cuda_config:.py tf_staging/third_party/gpus/rocm/BUILD.tpl: @@ -284,6 +306,9 @@ tf_staging/third_party/nccl/archive.BUILD: tf_staging/third_party/nccl/archive.patch: tf_staging/third_party/nccl/build_defs.bzl.tpl: tf_staging/third_party/nccl/generated_names.bzl.tpl: +tf_staging/third_party/nccl/hermetic/BUILD: +tf_staging/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl: +tf_staging/third_party/nccl/hermetic/nccl_configure.bzl: tf_staging/third_party/nccl/nccl_configure.bzl: tf_staging/third_party/nccl/system.BUILD.tpl: tf_staging/third_party/nlohmann_json.BUILD: @@ -321,6 +346,7 @@ tf_staging/third_party/remote_config/remote_platform_configure.bzl: tf_staging/third_party/repo.bzl: tf_staging/third_party/six.BUILD: tf_staging/third_party/snappy.BUILD: +tf_staging/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD: tf_staging/third_party/sqlite.BUILD: tf_staging/third_party/stablehlo/BUILD: tf_staging/third_party/systemlibs/BUILD.tpl: diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f915f123d135cc..8aa243aafc84e5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -932,8 +932,6 @@ filegroup( "//tensorflow/core/util/tensor_bundle", # checkpoint_reader "//tensorflow/dtensor/cc:dtensor_device_cc", # DTensor "//tensorflow/dtensor/cc:tensor_layout", # DTensor - "//tensorflow/lite/kernels/shim:shape", # tf_text - "//tensorflow/lite/kernels/shim:tf_op_shim", # tf_text "//tensorflow/lite/toco/python:toco_python_api", # toco "//tensorflow/python/client:tf_session_helper", # tf_session "//tensorflow/python/eager:pywrap_tfe_lib", # pywrap_tfe_lib diff --git a/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb b/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb index 8b7b3e9b350d13..44c2ea60c2a5b4 100644 --- a/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb +++ b/tensorflow/python/autograph/g3doc/pyct_tutorial.ipynb @@ -164,7 +164,7 @@ "source": [ "### Helpful static analysis passes\n", "\n", - "The `static_analysis` module contains various helper passes for dataflow analyis.\n", + "The `static_analysis` module contains various helper passes for dataflow analysis.\n", "\n", "All these passes annotate the AST. These annotations can be extracted using [anno.getanno](https://github.com/tensorflow/tensorflow/blob/40802bcdb5c8a4379da2145441f51051402bd29b/tensorflow/python/autograph/pyct/anno.py#L111). Most of them rely on the `qual_names` annotations, which just simplify the way more complex identifiers like `a.b.c` are accessed.\n", "\n", @@ -253,7 +253,7 @@ "\n", "\n", "def f(a):\n", - " if a \u003e 0:\n", + " if a > 0:\n", " return a\n", " b = -a\n", "\n", diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py index 08f73422667de1..ffacbe46e9f52f 100644 --- a/tensorflow/python/autograph/pyct/origin_info.py +++ b/tensorflow/python/autograph/pyct/origin_info.py @@ -172,7 +172,7 @@ def __init__(self, root_node, source_lines, comments_map, hasattr(root_node.decorator_list[0], 'lineno')): # Typical case: functions. The line number of the first decorator # is more accurate than the line number of the function itself in - # 3.8+. In earier versions they coincide. + # 3.8+. In earlier versions they coincide. self._lineno_offset = context_lineno - root_node.decorator_list[0].lineno else: # Fall back to the line number of the root node. diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index 8af0f7ae9477d0..5d6a8725c7cea2 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -570,7 +570,7 @@ def visit_FunctionDef(self, node): node.decorator_list = self.visit_block(node.decorator_list) if node.returns: node.returns = self._process_annotation(node.returns) - # Argument annotartions (includeing defaults) affect the defining context. + # Argument annotartions (including defaults) affect the defining context. node = self._visit_arg_annotations(node) function_name = qual_names.QN(node.name) diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py index cdeddaac7fd7ce..ad373a5808974f 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py @@ -249,7 +249,7 @@ def foo(): inner_fn_body = fn_body[1].body[1].body def_of_a_in_foo = inner_fn_body[0].value - # Even though `a` is visible in the inner functio above, the late binding + # Even though `a` is visible in the inner function above, the late binding # makes it impossible to assume that the same value will be visible at # call time. self.assertHasDefs(def_of_a_in_foo, 0) diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py index 5b59a5a18f5edf..d5ab1f5c541ce9 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py +++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py @@ -63,9 +63,10 @@ def res_name(self, ns, types_ns, name): ns: namespace types_ns: types namespace name: symbol name + Returns: Tuple (type, static_value). The first element is the type to use for - inferrence. The second is the static value to use. Return None to treat it + inference. The second is the static value to use. Return None to treat it as unknown. """ raise NotImplementedError('subclasses must implement') @@ -383,7 +384,7 @@ def _resolve_typed_callable(self, f_types, arg_types, keyword_types): for t in f_types: if isinstance(t, Callable): - # Note: these are undocummented - may be version-specific! + # Note: these are undocumented - may be version-specific! # Callable[[x], y]: __args__ are (x, y) args = t.__args__ if args: diff --git a/tensorflow/python/autograph/pyct/transformer.py b/tensorflow/python/autograph/pyct/transformer.py index c19009784a5620..005135e07b7786 100644 --- a/tensorflow/python/autograph/pyct/transformer.py +++ b/tensorflow/python/autograph/pyct/transformer.py @@ -314,13 +314,13 @@ def after_visit(node): in nodes after_visit: optional callable that takes in an AST node and returns a tuple (new_node, new_destination). It is called after visiting each item - in nodes. Is used in the same was as the - visit_* methods: new_node will replace the node; if not None, - new_destination must be a list, and subsequent nodes will be placed - in this list instead of the list returned by visit_block. + in nodes. Is used in the same was as the visit_* methods: new_node will + replace the node; if not None, new_destination must be a list, and + subsequent nodes will be placed in this list instead of the list + returned by visit_block. Returns: - A list of AST node objects containing the transformed items fron nodes, + A list of AST node objects containing the transformed items from nodes, except those nodes that have been relocated using after_visit. """ if nodes is None: diff --git a/tensorflow/python/autograph/pyct/transpiler.py b/tensorflow/python/autograph/pyct/transpiler.py index 013ccc562aab72..f7b9150e728fc9 100644 --- a/tensorflow/python/autograph/pyct/transpiler.py +++ b/tensorflow/python/autograph/pyct/transpiler.py @@ -238,7 +238,7 @@ def transform_ast(self, node, ctx): result = <> return result - transformer = MyTransfomer() + transformer = MyTransformer() result = transformer.transform(f, ...) # result is the output @@ -381,7 +381,7 @@ def transform_ast(self, node, ctx): node = <> return node - transformer = MyTransfomer() + transformer = MyTransformer() new_f, module, source_map = transformer.transform_function(f, ...) # new_f is a function with signature identical to f @@ -430,7 +430,7 @@ def _cached_factory(self, fn, cache_subkey): return cached_factory def transform_function(self, fn, user_context): - """Transforms a function. See GenericTranspiler.trasnform_function. + """Transforms a function. See GenericTranspiler.transform_function. This overload wraps the parent's `transform_function`, adding caching and facilities to instantiate the output as a Python object. It also @@ -441,6 +441,7 @@ def transform_function(self, fn, user_context): fn: A function or lambda. user_context: An opaque object (may be None) that is forwarded to transform_ast, through the ctx.user attribute. + Returns: A tuple: * A function or lambda with the same signature and closure as `fn` diff --git a/tensorflow/python/client/BUILD b/tensorflow/python/client/BUILD index c66d1c75782dda..76c4ccad009a29 100644 --- a/tensorflow/python/client/BUILD +++ b/tensorflow/python/client/BUILD @@ -292,6 +292,7 @@ py_strict_library( "//tensorflow/python/util:compat", "//tensorflow/python/util:deprecation", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", "@pypi_wrapt//:pkg", diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index d42e18551808d6..87b794fe094156 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -41,6 +41,7 @@ from tensorflow.python.util import compat from tensorflow.python.util import deprecation from tensorflow.python.util import nest +from tensorflow.python.util import numpy_compat from tensorflow.python.util.compat import collections_abc from tensorflow.python.util.tf_export import tf_export @@ -140,14 +141,19 @@ def _get_feeds_for_indexed_slices(feed, feed_val): def _convert_to_numpy_obj(numpy_dtype, obj): """Explicitly convert obj based on numpy type except for string type.""" - return numpy_dtype(obj) if numpy_dtype is not object else str(obj) + return ( + numpy_dtype(np.array(obj).astype(numpy_dtype)) + if numpy_dtype is not object + else str(obj) + ) def register_session_run_conversion_functions( tensor_type, fetch_function, feed_function=None, - feed_function_for_partial_run=None): + feed_function_for_partial_run=None, +): """Register fetch and feed conversion functions for `tf.Session.run()`. This function registers a triple of conversion functions for fetching and/or @@ -1181,7 +1187,7 @@ def _feed_fn(feed, feed_val): np_val = subfeed_val.to_numpy_array() feed_handles[subfeed_t.ref()] = subfeed_val else: - np_val = np.asarray(subfeed_val, dtype=subfeed_dtype) + np_val = numpy_compat.np_asarray(subfeed_val, subfeed_dtype) if (not is_tensor_handle_feed and not subfeed_t.get_shape().is_compatible_with(np_val.shape)): diff --git a/tensorflow/python/client/tf_session_wrapper.cc b/tensorflow/python/client/tf_session_wrapper.cc index b2d3492f99dfd5..00baa132c6f036 100644 --- a/tensorflow/python/client/tf_session_wrapper.cc +++ b/tensorflow/python/client/tf_session_wrapper.cc @@ -334,7 +334,7 @@ class tf_handle { tf_handle(const tf_handle& other) { Reset(other.obj_); } - tf_handle& operator=(tf_handle&& other) { + tf_handle& operator=(tf_handle&& other) noexcept { if (this == &other) { return *this; } diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 8ed45e8a382b2b..29b1611b41ef35 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 7, 29) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2024, 8, 20) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None diff --git a/tensorflow/python/compiler/tensorrt/README.md b/tensorflow/python/compiler/tensorrt/README.md index 4c1d96bbed7e99..ec95cb6de69d30 100644 --- a/tensorflow/python/compiler/tensorrt/README.md +++ b/tensorflow/python/compiler/tensorrt/README.md @@ -1,5 +1,7 @@ # Using TensorRT in TensorFlow (TF-TRT) +Note: Starting from v.2.18.0, TensorFlow doesn't support TensorRT. + This module provides necessary bindings and introduces `TRTEngineOp` operator that wraps a subgraph in TensorRT. This module is under active development. diff --git a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py index 857c6f70470723..b7c062b194f8fc 100644 --- a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py @@ -86,13 +86,6 @@ def testIncorrectCardinality(self, num_elements, asserted_cardinality, @combinations.generate( combinations.times( test_base.default_test_combinations(), - combinations.combine( - num_elements=10, - asserted_cardinality=1, - expected_error=errors.FailedPreconditionError, - expected_error_message=( - "Input dataset was expected to contain 1 element but " - "contained at least 2 elements.")) + combinations.combine( num_elements=10, asserted_cardinality=100, diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index d84439f8fb1ab7..8db171dde251b2 100644 --- a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -253,8 +253,8 @@ def __init__(self, protocol: The protocol to use for communicating with the tf.data service, e.g. "grpc". data_transfer_protocol: (Optional.) The protocol to use for transferring - data with the tf.data service. By default, data is transferred using - gRPC. + data with the tf.data service. If not provided, a protocol is determined + at runtime. job_name: (Optional.) The name of the job. If provided, it must be a non-empty string or Tensor. This argument makes it possible for multiple datasets to share the same job. The default behavior is that the dataset @@ -280,7 +280,7 @@ def __init__(self, provided, dataset iteration will be shared across concurrently running trainers. See https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers - for details. + for details. target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data runtime decides which workers to read from. If `"ANY"`, reads from any tf.data service workers. If `"LOCAL"`, only reads from local @@ -464,8 +464,8 @@ def _distribute( service: A string or a tuple indicating how to connect to the tf.data service. If it's a string, it should be in the format `[://]
`, where `
` identifies the dispatcher - address and `` can optionally be used to override the default - protocol to use. If it's a tuple, it should be (protocol, address). + address and `` can optionally be used to override the default + protocol to use. If it's a tuple, it should be (protocol, address). job_name: (Optional.) The name of the job. If provided, it must be a non-empty string. This argument makes it possible for multiple datasets to share the same job. The default behavior is that the dataset creates @@ -488,7 +488,8 @@ def _distribute( task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the dispatcher for task changes. data_transfer_protocol: (Optional.) The protocol to use for transferring - data with the tf.data service. By default, data is transferred using gRPC. + data with the tf.data service. If not provided, a protocol is determined + at runtime. compression: How to compress the dataset's elements before transferring them over the network. "AUTO" leaves the decision of how to compress up to the tf.data service runtime. `None` indicates not to compress. @@ -496,7 +497,7 @@ def _distribute( provided, dataset iteration will be shared across concurrently running trainers. See https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers - for details. + for details. target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data runtime decides which workers to read from. If `"ANY"`, reads from any tf.data service workers. If `"LOCAL"`, only reads from local in-processs @@ -724,8 +725,8 @@ def distribute( service: A string or a tuple indicating how to connect to the tf.data service. If it's a string, it should be in the format `[://]
`, where `
` identifies the dispatcher - address and `` can optionally be used to override the default - protocol to use. If it's a tuple, it should be (protocol, address). + address and `` can optionally be used to override the default + protocol to use. If it's a tuple, it should be (protocol, address). job_name: (Optional.) The name of the job. If provided, it must be a non-empty string. This argument makes it possible for multiple datasets to share the same job. The default behavior is that the dataset creates @@ -746,7 +747,8 @@ def distribute( of memory used, since `distribute` won't use more than `element_size` * `max_outstanding_requests` of memory. data_transfer_protocol: (Optional.) The protocol to use for transferring - data with the tf.data service. By default, data is transferred using gRPC. + data with the tf.data service. If not provided, a protocol is determined + at runtime. compression: How to compress the dataset's elements before transferring them over the network. "AUTO" leaves the decision of how to compress up to the tf.data service runtime. `None` indicates not to compress. @@ -754,7 +756,7 @@ def distribute( provided, dataset iteration will be shared across concurrently running trainers. See https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers - for details. + for details. target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data runtime decides which workers to read from. If `"ANY"`, reads from any tf.data service workers. If `"LOCAL"`, only reads from local in-processs @@ -925,8 +927,8 @@ def _from_dataset_id(processing_mode, service: A string or a tuple indicating how to connect to the tf.data service. If it's a string, it should be in the format `[://]
`, where `
` identifies the dispatcher - address and `` can optionally be used to override the default - protocol to use. If it's a tuple, it should be (protocol, address). + address and `` can optionally be used to override the default + protocol to use. If it's a tuple, it should be (protocol, address). dataset_id: The id of the dataset to read from. This id is returned by `register_dataset` when the dataset is registered with the tf.data service. @@ -956,12 +958,13 @@ def _from_dataset_id(processing_mode, task_refresh_interval_hint_ms: (Optional.) A hint for how often to query the dispatcher for task changes. data_transfer_protocol: (Optional.) The protocol to use for transferring - data with the tf.data service. By default, data is transferred using gRPC. + data with the tf.data service. If not provided, a protocol is determined + at runtime. cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is provided, dataset iteration will be shared across concurrently running trainers. See https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers - for details. + for details. target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data runtime decides which workers to read from. If `"ANY"`, reads from any tf.data service workers. If `"LOCAL"`, only reads from local in-processs @@ -1110,8 +1113,8 @@ def from_dataset_id(processing_mode, service: A string or a tuple indicating how to connect to the tf.data service. If it's a string, it should be in the format `[://]
`, where `
` identifies the dispatcher - address and `` can optionally be used to override the default - protocol to use. If it's a tuple, it should be (protocol, address). + address and `` can optionally be used to override the default + protocol to use. If it's a tuple, it should be (protocol, address). dataset_id: The id of the dataset to read from. This id is returned by `register_dataset` when the dataset is registered with the tf.data service. @@ -1139,12 +1142,13 @@ def from_dataset_id(processing_mode, of memory used, since `distribute` won't use more than `element_size` * `max_outstanding_requests` of memory. data_transfer_protocol: (Optional.) The protocol to use for transferring - data with the tf.data service. By default, data is transferred using gRPC. + data with the tf.data service. If not provided, a protocol is determined + at runtime. cross_trainer_cache: (Optional.) If a `CrossTrainerCache` object is provided, dataset iteration will be shared across concurrently running trainers. See https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers - for details. + for details. target_workers: (Optional.) Which workers to read from. If `"AUTO"`, tf.data runtime decides which workers to read from. If `"ANY"`, reads from any tf.data service workers. If `"LOCAL"`, only reads from local in-processs diff --git a/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi b/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi index e88ec5672773ef..29126c1902939e 100644 --- a/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi +++ b/tensorflow/python/data/experimental/service/_pywrap_utils_exp.pyi @@ -14,4 +14,3 @@ # ============================================================================== def TF_DATA_DefaultProtocol() -> str: ... -def TF_DATA_DisableCompressionAtRegistrationTime() -> bool: ... diff --git a/tensorflow/python/data/experimental/service/utils_wrapper.cc b/tensorflow/python/data/experimental/service/utils_wrapper.cc index f94982931e148b..c725ff3f58ec13 100644 --- a/tensorflow/python/data/experimental/service/utils_wrapper.cc +++ b/tensorflow/python/data/experimental/service/utils_wrapper.cc @@ -23,8 +23,4 @@ limitations under the License. PYBIND11_MODULE(_pywrap_utils_exp, m) { m.def("TF_DATA_DefaultProtocol", []() -> std::string { return tensorflow::data::DefaultProtocol(); }); - - m.def("TF_DATA_DisableCompressionAtRegistrationTime", []() -> bool { - return tensorflow::data::DisableCompressionAtRegistrationTime(); - }); }; diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index 7c20b1fae0bbeb..88de9cd6a5c6b8 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -213,9 +213,11 @@ tf_py_strict_test( name = "concatenate_test", size = "medium", srcs = ["concatenate_test.py"], + shard_count = 20, deps = [ ":checkpoint_test_base", ":test_base", + "//tensorflow/python/data/experimental/ops:global_shuffle_op", "//tensorflow/python/data/experimental/ops:random_access", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:options", diff --git a/tensorflow/python/data/kernel_tests/concatenate_test.py b/tensorflow/python/data/kernel_tests/concatenate_test.py index be0859425f15ca..51c7ca461410f5 100644 --- a/tensorflow/python/data/kernel_tests/concatenate_test.py +++ b/tensorflow/python/data/kernel_tests/concatenate_test.py @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Tests for `tf.data.Dataset.concatenate().""" +from typing import Callable, Tuple from absl.testing import parameterized import numpy as np +from tensorflow.python.data.experimental.ops import global_shuffle_op from tensorflow.python.data.experimental.ops import random_access from tensorflow.python.data.kernel_tests import checkpoint_test_base from tensorflow.python.data.kernel_tests import test_base @@ -248,5 +250,301 @@ def testConcatenateTwoNonEmptyDatasets(self): self.evaluate(random_access.at(concatenated, index=5)) +class GlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): + """Tests for global shuffling of tf.data datasets.""" + + @combinations.generate(test_base.default_test_combinations()) + def testShuffledOutput(self): + dataset1 = dataset_ops.Dataset.range(0, 5) + dataset2 = dataset_ops.Dataset.range(5, 17) + + dataset = dataset1.concatenate(dataset2) + + dataset = global_shuffle_op._global_shuffle(dataset) + + output = self.getDatasetOutput(dataset, requires_initialization=True) + self.assertCountEqual(output, range(0, 17)) + + @combinations.generate(test_base.default_test_combinations()) + def testShuffledWithBatchOutput(self): + """Testing with `.batch()` ensures the global shuffle map is stateless.""" + dataset1 = dataset_ops.Dataset.range(0, 4) + dataset2 = dataset_ops.Dataset.range(4, 10) + + dataset = dataset1.concatenate(dataset2) + dataset = dataset.batch(3, drop_remainder=True) + + dataset = global_shuffle_op._global_shuffle(dataset) + + got = self.getDatasetOutput(dataset, requires_initialization=True) + expected = [ + np.array([0, 1, 2], dtype=np.int32), + np.array([3, 4, 5], dtype=np.int32), + np.array([6, 7, 8], dtype=np.int32), + ] + + self.assertIsInstance(got, list) + # Converts to tuples for lexicographically sort + got.sort(key=tuple) + + self.assertLen(got, len(expected)) + + for element_got, element_expected in zip(got, expected): + self.assertAllEqual(element_got, element_expected) + + @combinations.generate(test_base.default_test_combinations()) + def testNestedConcatenateShuffledOutput(self): + dataset1 = dataset_ops.Dataset.range(0, 3) + dataset2 = dataset_ops.Dataset.range(3, 6) + dataset3 = dataset_ops.Dataset.range(6, 9) + + dataset = dataset1.concatenate(dataset2) + dataset = dataset.concatenate(dataset3) + + dataset = global_shuffle_op._global_shuffle(dataset) + + output = self.getDatasetOutput(dataset, requires_initialization=True) + self.assertCountEqual(output, range(0, 9)) + + +class ConcatenateGlobalShuffleCheckpointTest( + checkpoint_test_base.CheckpointTestBase, parameterized.TestCase +): + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + checkpoint_test_base.default_test_combinations(), + combinations.combine( + dataset_ranges=[(10, 8), (9, 5), (4, 7), (5, 8)], + reshuffle_each_iteration=[True, False], + symbolic_checkpoint=[True, False], + ), + ) + ) + def testConcatenate( + self, + verify_fn: Callable[..., None], + dataset_ranges: Tuple[int, int], + reshuffle_each_iteration: bool, + symbolic_checkpoint: bool, + ): + + def _build_dataset(): + first_dataset = dataset_ops.Dataset.range(dataset_ranges[0]) + second_dataset = dataset_ops.Dataset.range( + dataset_ranges[0], dataset_ranges[0] + dataset_ranges[1] + ) + dataset = first_dataset.concatenate(second_dataset) + dataset = global_shuffle_op._global_shuffle( + dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration + ) + + options = options_lib.Options() + options.experimental_optimization.apply_default_optimizations = False + options.experimental_symbolic_checkpoint = symbolic_checkpoint + return dataset.with_options(options) + + verify_fn( + self, + _build_dataset, + num_outputs=sum(dataset_ranges), + assert_items_equal=reshuffle_each_iteration, + ) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + checkpoint_test_base.default_test_combinations(), + combinations.combine( + dataset_ranges=[(10, 8, 11), (9, 5, 3)], + reshuffle_each_iteration=[True, False], + symbolic_checkpoint=[True, False], + ), + ) + ) + def testNestedConcatenate( + self, + verify_fn: Callable[..., None], + dataset_ranges: Tuple[int, int], + reshuffle_each_iteration: bool, + symbolic_checkpoint: bool, + ): + + def _build_dataset(): + first_dataset = dataset_ops.Dataset.range(dataset_ranges[0]) + second_dataset = dataset_ops.Dataset.range( + dataset_ranges[0], dataset_ranges[0] + dataset_ranges[1] + ) + third_dataset = dataset_ops.Dataset.range( + sum(dataset_ranges[:2]), sum(dataset_ranges[:3]) + ) + + dataset = first_dataset.concatenate(second_dataset) + dataset = dataset.concatenate(third_dataset) + + dataset = global_shuffle_op._global_shuffle( + dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration + ) + + options = options_lib.Options() + options.experimental_optimization.apply_default_optimizations = False + options.experimental_symbolic_checkpoint = symbolic_checkpoint + return dataset.with_options(options) + + verify_fn( + self, + _build_dataset, + num_outputs=sum(dataset_ranges), + assert_items_equal=reshuffle_each_iteration, + ) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + checkpoint_test_base.default_test_combinations(), + combinations.combine( + dataset_ranges=[(3, 4, 6, 5)], + reshuffle_each_iteration=[True, False], + symbolic_checkpoint=[True, False], + ), + ) + ) + def testFourNestedConcatenate( + self, + verify_fn: Callable[..., None], + dataset_ranges: Tuple[int, int], + reshuffle_each_iteration: bool, + symbolic_checkpoint: bool, + ): + def _build_dataset(): + first_dataset = dataset_ops.Dataset.range(dataset_ranges[0]) + second_dataset = dataset_ops.Dataset.range( + dataset_ranges[0], sum(dataset_ranges[:2]) + ) + third_dataset = dataset_ops.Dataset.range( + sum(dataset_ranges[:2]), sum(dataset_ranges[:3]) + ) + fourth_dataset = dataset_ops.Dataset.range( + sum(dataset_ranges[:3]), sum(dataset_ranges) + ) + + left = first_dataset.concatenate(second_dataset) + right = third_dataset.concatenate(fourth_dataset) + + dataset = left.concatenate(right) + dataset = global_shuffle_op._global_shuffle( + dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration + ) + + options = options_lib.Options() + options.experimental_optimization.apply_default_optimizations = False + options.experimental_symbolic_checkpoint = symbolic_checkpoint + return dataset.with_options(options) + + verify_fn( + self, + _build_dataset, + num_outputs=sum(dataset_ranges), + assert_items_equal=reshuffle_each_iteration, + ) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + checkpoint_test_base.default_test_combinations(), + combinations.combine( + dataset_ranges=[(1, 2, 3, 4, 5, 6)], + reshuffle_each_iteration=[True, False], + symbolic_checkpoint=[True, False], + ), + ) + ) + def testDeepConcatenate( + self, + verify_fn: Callable[..., None], + dataset_ranges: Tuple[int, ...], + reshuffle_each_iteration: bool, + symbolic_checkpoint: bool, + ): + def _build_dataset(): + prefix_sums = [0] * (len(dataset_ranges) + 1) + for i, value in enumerate(dataset_ranges): + prefix_sums[i + 1] = prefix_sums[i] + value + + dataset = dataset_ops.Dataset.range(prefix_sums[0], prefix_sums[1]) + for i in range(1, len(dataset_ranges)): + to_concat = dataset_ops.Dataset.range( + prefix_sums[i], prefix_sums[i + 1] + ) + dataset = dataset.concatenate(to_concat) + + dataset = global_shuffle_op._global_shuffle( + dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration + ) + + options = options_lib.Options() + options.experimental_optimization.apply_default_optimizations = False + options.experimental_symbolic_checkpoint = symbolic_checkpoint + return dataset.with_options(options) + + verify_fn( + self, + _build_dataset, + num_outputs=sum(dataset_ranges), + assert_items_equal=reshuffle_each_iteration, + ) + + @combinations.generate( + combinations.times( + test_base.default_test_combinations(), + checkpoint_test_base.default_test_combinations(), + combinations.combine( + dataset_ranges=[(1, 2, 3, 4, 5, 6)], + reshuffle_each_iteration=[True, False], + symbolic_checkpoint=[True, False], + ), + ) + ) + def testDeepConcatenateWithBatchAndPrefetch( + self, + verify_fn: Callable[..., None], + dataset_ranges: Tuple[int, ...], + reshuffle_each_iteration: bool, + symbolic_checkpoint: bool, + ): + def _build_dataset(): + prefix_sums = [0] * (len(dataset_ranges) + 1) + for i, value in enumerate(dataset_ranges): + prefix_sums[i + 1] = prefix_sums[i] + value + + dataset = dataset_ops.Dataset.range(prefix_sums[0], prefix_sums[1]) + for i in range(1, len(dataset_ranges)): + to_concat = dataset_ops.Dataset.range( + prefix_sums[i], prefix_sums[i + 1] + ) + dataset = dataset.concatenate(to_concat) + + dataset = dataset.batch(2, drop_remainder=True) + dataset = dataset.prefetch(buffer_size=dataset_ops.AUTOTUNE) + + dataset = global_shuffle_op._global_shuffle( + dataset, seed=10, reshuffle_each_iteration=reshuffle_each_iteration + ) + dataset = dataset.unbatch() + + options = options_lib.Options() + options.experimental_optimization.apply_default_optimizations = False + options.experimental_symbolic_checkpoint = symbolic_checkpoint + return dataset.with_options(options) + + verify_fn( + self, + _build_dataset, + num_outputs=(sum(dataset_ranges) // 2) * 2, + assert_items_equal=reshuffle_each_iteration, + ) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py index 00299dd6002edb..c2c5f266acbb72 100644 --- a/tensorflow/python/data/kernel_tests/flat_map_test.py +++ b/tensorflow/python/data/kernel_tests/flat_map_test.py @@ -15,6 +15,7 @@ """Tests for `tf.data.Dataset.flat_map()`.""" import random from typing import Callable, Optional +import unittest from absl.testing import parameterized import numpy as np @@ -466,6 +467,10 @@ def my_map(x): verify_fn(self, build_dataset, num_outputs=3 * 4 - num_skips) +@unittest.skip( + "TODO: b/355241367 - `flat_map_dataset_op.cc` still needs to be fixed." + " Please use concatenate dataset op plus global shuffling instead." +) class FlatMapGlobalShuffleTest( test_base.DatasetTestBase, parameterized.TestCase): @@ -511,6 +516,10 @@ def testInputCardinalityTooLarge(self): self.getDatasetOutput(dataset, requires_initialization=True) +@unittest.skip( + "TODO: b/355241367 - `flat_map_dataset_op.cc` still needs to be fixed." + " Please use concatenate dataset op plus global shuffling instead." +) class FlatMapGlobalShuffleCheckpointTest( checkpoint_test_base.CheckpointTestBase, parameterized.TestCase ): diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py index 2d00e6bcaf7aa9..2e58aebc0698b6 100644 --- a/tensorflow/python/data/kernel_tests/map_test.py +++ b/tensorflow/python/data/kernel_tests/map_test.py @@ -262,13 +262,15 @@ def iterator_thread(): self.assertAllEqual(component[i]**2, result_component) def _parallel_map_dataset_factory(self, components, apply_map, count, - num_parallel_calls, buffer_size): + num_parallel_calls, buffer_size, + use_unbounded_threadpool=False): def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) dataset = dataset_ops.Dataset.from_tensor_slices(components) - dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls) + dataset = apply_map(dataset, _map_fn, num_parallel_calls=num_parallel_calls, + use_unbounded_threadpool=use_unbounded_threadpool) dataset = dataset.prefetch(buffer_size).repeat(count) self.assertEqual( @@ -284,8 +286,10 @@ def _map_fn(x, y, z): combinations.combine(num_parallel_calls=2, buffer_size=2) + combinations.combine(num_parallel_calls=2, buffer_size=4) + combinations.combine(num_parallel_calls=8, buffer_size=8) + - combinations.combine(num_parallel_calls=8, buffer_size=16))) - def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size): + combinations.combine(num_parallel_calls=8, buffer_size=16), + combinations.combine(use_unbounded_threadpool=[None, True, False]))) + def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size, + use_unbounded_threadpool): """Test an dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) -> @@ -296,7 +300,8 @@ def testParallelMapDataset(self, apply_map, num_parallel_calls, buffer_size): # Test single-threaded access to the iterator. get_next = self.getNext( self._parallel_map_dataset_factory(components, apply_map, 14, - num_parallel_calls, buffer_size)) + num_parallel_calls, buffer_size, + use_unbounded_threadpool)) for _ in range(14): for i in range(7): result = self.evaluate(get_next()) @@ -1535,6 +1540,20 @@ def testCheckpointLargeBuffer(self): del iterator manager.restore_or_initialize() + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + combinations.combine( + use_unbounded_threadpool=[True, False]))) + def testAutotuneUseUnboundedThreadpool(self, use_unbounded_threadpool): + dataset = dataset_ops.Dataset.range(100) + dataset = dataset.map( + lambda x: x * 2, + num_parallel_calls=dataset_ops.AUTOTUNE, + use_unbounded_threadpool=use_unbounded_threadpool, + deterministic=True, + name="map") + self.assertDatasetProduces(dataset, [x * 2 for x in range(100)]) + @combinations.generate( combinations.times(test_base.default_test_combinations(), combinations.combine(num_parallel_calls=[None, 1]))) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index b2580e31d96701..c8d01f21c9d0d3 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2157,6 +2157,7 @@ def map( num_parallel_calls=None, deterministic=None, synchronous=None, + use_unbounded_threadpool=False, name=None, ) -> "DatasetV2": """Maps `map_func` across the elements of this dataset. @@ -2313,6 +2314,11 @@ def map( saving memory, since even setting `num_parallel_calls=1` will cause one batch to be buffered, while with `synchronous=True` the map transformation doesn't buffer anything. + use_unbounded_threadpool: (Optional.) By default, map functions run in a + limited threadpool based on the number of cores on the machine. This + efficient for CPU-heavy processing, but if the map function performs IO + it is better to use an unbounded threadpool by setting it to `True`. It + is `False` by default. name: (Optional.) A name for the tf.data operation. Returns: @@ -2329,6 +2335,7 @@ def map( num_parallel_calls=num_parallel_calls, deterministic=deterministic, synchronous=synchronous, + use_unbounded_threadpool=use_unbounded_threadpool, name=name, ) # pylint: enable=g-import-not-at-top,protected-access @@ -4092,6 +4099,7 @@ def map( num_parallel_calls=None, deterministic=None, synchronous=None, + use_unbounded_threadpool=False, name=None, ): # Loaded lazily due to a circular dependency (dataset_ops -> map_op -> @@ -4105,12 +4113,17 @@ def map( num_parallel_calls=num_parallel_calls, deterministic=deterministic, synchronous=synchronous, + use_unbounded_threadpool=use_unbounded_threadpool, ) # pylint: enable=g-import-not-at-top,protected-access @deprecation.deprecated(None, "Use `tf.data.Dataset.map()") def map_with_legacy_function( - self, map_func, num_parallel_calls=None, deterministic=None + self, + map_func, + num_parallel_calls=None, + deterministic=None, + use_unbounded_threadpool=False, ) -> "DatasetV1Adapter": """Maps `map_func` across the elements of this dataset. @@ -4133,6 +4146,11 @@ def map_with_legacy_function( elements out of order to trade determinism for performance. If not specified, the `tf.data.Options.deterministic` option (`True` by default) controls the behavior. + use_unbounded_threadpool: (Optional.) By default, map functions run in a + limited threadpool based on the number of cores on the machine. This + efficient for CPU-heavy processing, but if the map function performs IO + it is better to use an unbounded threadpool by setting it to `True`. It + is `False` by default. Returns: Dataset: A `Dataset`. diff --git a/tensorflow/python/data/ops/map_op.py b/tensorflow/python/data/ops/map_op.py index 0a056abcddd78d..f301dee2557eca 100644 --- a/tensorflow/python/data/ops/map_op.py +++ b/tensorflow/python/data/ops/map_op.py @@ -30,6 +30,7 @@ def _map_v2( num_parallel_calls=None, deterministic=None, synchronous=None, + use_unbounded_threadpool=None, name=None, ): """See `Dataset.map()` for details.""" @@ -59,6 +60,7 @@ def _map_v2( num_parallel_calls=num_parallel_calls, deterministic=deterministic, preserve_cardinality=True, + use_unbounded_threadpool=use_unbounded_threadpool, name=name) @@ -68,6 +70,7 @@ def _map_v1( num_parallel_calls=None, deterministic=None, synchronous=None, + use_unbounded_threadpool=None, # pylint: disable=unused-argument ): """See `Dataset.map()` for details.""" if num_parallel_calls is None or debug_mode.DEBUG_MODE: @@ -92,7 +95,8 @@ def _map_v1( map_func, num_parallel_calls, deterministic, - preserve_cardinality=False)) + preserve_cardinality=False, + use_unbounded_threadpool=False)) def _map_v1_with_legacy_function( # pylint: disable=unused-private-name @@ -130,7 +134,8 @@ def _map_v1_with_legacy_function( # pylint: disable=unused-private-name num_parallel_calls, deterministic, preserve_cardinality=False, - use_legacy_function=True)) + use_legacy_function=True, + use_unbounded_threadpool=False)) class _MapDataset(dataset_ops.UnaryDataset): @@ -189,6 +194,7 @@ def __init__(self, use_inter_op_parallelism=True, preserve_cardinality=False, use_legacy_function=False, + use_unbounded_threadpool=False, name=None): """See `Dataset.map()` for details.""" self._input_dataset = input_dataset @@ -207,6 +213,7 @@ def __init__(self, self._preserve_cardinality = preserve_cardinality self._num_parallel_calls = ops.convert_to_tensor( num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls") + self._use_unbounded_threadpool = use_unbounded_threadpool self._name = name variant_tensor = gen_dataset_ops.parallel_map_dataset_v2( input_dataset._variant_tensor, # pylint: disable=protected-access @@ -216,6 +223,7 @@ def __init__(self, deterministic=self._deterministic, use_inter_op_parallelism=self._use_inter_op_parallelism, preserve_cardinality=self._preserve_cardinality, + use_unbounded_threadpool=self._use_unbounded_threadpool, **self._common_args) super().__init__(input_dataset, variant_tensor) diff --git a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc index cd3d2eeb1d1c8a..1f334a09464de4 100644 --- a/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc +++ b/tensorflow/python/distribute/parallel_device/pywrap_parallel_device.cc @@ -19,15 +19,9 @@ limitations under the License. #include "Python.h" #include "pybind11/pybind11.h" // from @pybind11 #include "pybind11/stl.h" // from @pybind11 -#include "tensorflow/c/c_api.h" -#include "tensorflow/c/c_api_experimental.h" -#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/parallel_device/parallel_device.h" -#include "tensorflow/c/safe_ptr.h" -#include "tensorflow/python/lib/core/py_exception_registry.h" #include "tensorflow/python/lib/core/pybind11_lib.h" -#include "tensorflow/python/lib/core/pybind11_status.h" #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" namespace py = pybind11; diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py index 532d7f1555521f..8c49758d560dcd 100644 --- a/tensorflow/python/eager/tensor_test.py +++ b/tensorflow/python/eager/tensor_test.py @@ -508,11 +508,12 @@ def testEagerTensorFormatForVariant(self): f"{t!r}", ">") def testNumpyTooManyDimensions(self): - t = constant_op.constant(1., shape=[1] * 33) + max_dims = 64 if np.lib.NumpyVersion(np.__version__) >= "2.0.0.dev0" else 32 + t = constant_op.constant(1., shape=[1] * (max_dims + 1)) with self.assertRaisesRegex( errors.InvalidArgumentError, - "Cannot convert tensor with 33 dimensions to NumPy array. NumPy arrays " - "can have at most 32 dimensions"): + "Cannot convert tensor with %d dimensions to NumPy array. NumPy arrays " + "can have at most %d dimensions"% (max_dims + 1, max_dims)): t.numpy() def testNumpyDimsTooBig(self): diff --git a/tensorflow/python/flags_pybind.pyi b/tensorflow/python/flags_pybind.pyi index b34ed2f4b68c19..7c450b682a40a8 100644 --- a/tensorflow/python/flags_pybind.pyi +++ b/tensorflow/python/flags_pybind.pyi @@ -24,6 +24,7 @@ class Flags: enable_function_pruning_before_inlining: Flag enable_nested_function_shape_inference: Flag enable_quantized_dtypes_training: Flag + enable_skip_encapsulation_for_non_tpu_graphs: Flag enable_tf2min_ici_weight: Flag graph_building_optimization: Flag more_stack_traces: Flag diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 42414fff7cd1ed..9f4996a0c2c59c 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -6,6 +6,7 @@ load("//tensorflow:strict.default.bzl", "py_strict_library", "py_strict_test") # Placeholder: load py_proto_library load( "//tensorflow:tensorflow.bzl", + "if_hermetic_cuda_tools", "if_not_windows", "if_oss", "if_xla_available", @@ -1046,6 +1047,13 @@ tf_python_pybind_extension( "python_api_dispatcher.h", "//tensorflow/python/lib/core:safe_pyobject_ptr_required_hdrs", ], + # This data is needed to add hermetic CUDA tools in python runfiles. + data = if_hermetic_cuda_tools( + [ + "@cuda_nvcc//:ptxas", + "@cuda_nvcc//:nvvm", + ], + ), enable_stub_generation = True, pytype_srcs = [ "_pywrap_python_api_dispatcher.pyi", @@ -2051,6 +2059,7 @@ py_strict_library( "//tensorflow/python/types:internal", "//tensorflow/python/util:compat", "//tensorflow/python/util:nest", + "//tensorflow/python/util:numpy_compat", "//tensorflow/python/util:tf_export", "//third_party/py/numpy", ], diff --git a/tensorflow/python/framework/extension_type_test.py b/tensorflow/python/framework/extension_type_test.py index 0169690eaf3c33..a97180fab1ce8a 100644 --- a/tensorflow/python/framework/extension_type_test.py +++ b/tensorflow/python/framework/extension_type_test.py @@ -130,7 +130,7 @@ def _masked_array_repr(values, mask): """Returns a string representation for a masked numpy array.""" assert len(values) == len(mask) if len(values.shape) == 1: - items = [repr(v) if m else '_' for (v, m) in zip(values, mask)] + items = [repr(v.item()) if m else '_' for (v, m) in zip(values, mask)] else: items = [_masked_array_repr(v, m) for (v, m) in zip(values, mask)] return '[%s]' % ', '.join(items) diff --git a/tensorflow/python/framework/offset_counter_helper_test.cc b/tensorflow/python/framework/offset_counter_helper_test.cc index dcf6f7c3f20dd4..ef616a311a2306 100644 --- a/tensorflow/python/framework/offset_counter_helper_test.cc +++ b/tensorflow/python/framework/offset_counter_helper_test.cc @@ -18,10 +18,10 @@ limitations under the License. #include #include "absl/strings/str_format.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/python/framework/op_reg_offset.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 34b1eed754bbed..823ced42bf766e 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -213,7 +213,15 @@ def numpy_text(tensor, is_repr=False) -> str: """Human readable representation of a tensor's numpy value.""" if tensor.dtype.is_numpy_compatible: # pylint: disable=protected-access - text = repr(tensor._numpy()) if is_repr else str(tensor._numpy()) + tensor_numpy = tensor._numpy() + if is_repr: + if np.isscalar(tensor_numpy) and not isinstance(tensor_numpy, bytes): + # .item() converts the numpy scalars to python items. + text = repr(tensor_numpy.item()) + else: + text = repr(tensor_numpy) + else: + text = str(tensor_numpy) # pylint: enable=protected-access else: text = "" diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 59fbeb3429c68d..d629fcdbf1787d 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -13,8 +13,10 @@ # limitations under the License. # ============================================================================== """Utilities to create TensorProtos.""" + import typing from typing import Protocol + import numpy as np from tensorflow.core.framework import tensor_pb2 @@ -27,8 +29,10 @@ from tensorflow.python.types import internal from tensorflow.python.util import compat from tensorflow.python.util import nest +from tensorflow.python.util import numpy_compat from tensorflow.python.util.tf_export import tf_export + # Fallback in case fast_tensor_util is not properly compiled. # pylint: disable=g-import-not-at-top try: @@ -519,7 +523,7 @@ def make_tensor_proto(values, dtype=None, shape=None, verify_shape=False, nparray = np.empty(shape, dtype=np_dt) else: _AssertCompatible(values, dtype) - nparray = np.array(values, dtype=np_dt) + nparray = numpy_compat.np_array(values, np_dt) # check to them. # We need to pass in quantized values as tuples, so don't apply the shape if (list(nparray.shape) != _GetDenseDimensions(values) and diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py index 7304f262b720b1..b4d4ed25a950b3 100644 --- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py @@ -377,7 +377,7 @@ def testReverse0DimAuto(self): self.assertAllEqual(x_tf, x_np) def _reverse1DimAuto(self, np_dtype): - x_np = np.array([1, 200, 3, 40, 5], dtype=np_dtype) + x_np = np.array([1, 120, 3, 40, 5], dtype=np_dtype) for use_gpu in [False, True]: for axis_dtype in [dtypes.int32, dtypes.int64]: @@ -388,7 +388,7 @@ def _reverse1DimAuto(self, np_dtype): self.assertAllEqual(x_tf, np.asarray(x_np)[::-1]) def _reverse2DimAuto(self, np_dtype): - x_np = np.array([[1, 200, 3], [4, 5, 60]], dtype=np_dtype) + x_np = np.array([[1, 120, 3], [4, 5, 60]], dtype=np_dtype) for reverse_f in [array_ops.reverse_v2, array_ops.reverse]: for use_gpu in [False, True]: diff --git a/tensorflow/python/kernel_tests/array_ops/shape_ops_test.py b/tensorflow/python/kernel_tests/array_ops/shape_ops_test.py index 083df45d903f5f..210a244f592234 100644 --- a/tensorflow/python/kernel_tests/array_ops/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops/shape_ops_test.py @@ -115,27 +115,29 @@ def _compareSizeSparse(self, x_np, use_gpu=False): self.assertAllEqual(np_ans, result) self.assertShapeEqual(np_ans, tf_ans) - def _testCpu(self, x): + def _testCpu(self, x, compare_sparse): self._compareShape(x, use_gpu=False) self._compareShapeN(x, use_gpu=False) self._compareRank(x, use_gpu=False) self._compareSize(x, use_gpu=False) - self._compareShapeSparse(x, use_gpu=False) - self._compareRankSparse(x, use_gpu=False) - self._compareSizeSparse(x, use_gpu=False) + if compare_sparse: + self._compareShapeSparse(x, use_gpu=False) + self._compareRankSparse(x, use_gpu=False) + self._compareSizeSparse(x, use_gpu=False) - def _testGpu(self, x): + def _testGpu(self, x, compare_sparse): self._compareShape(x, use_gpu=True) self._compareShapeN(x, use_gpu=True) self._compareRank(x, use_gpu=True) self._compareSize(x, use_gpu=True) - self._compareShapeSparse(x, use_gpu=True) - self._compareRankSparse(x, use_gpu=True) - self._compareSizeSparse(x, use_gpu=True) + if compare_sparse: + self._compareShapeSparse(x, use_gpu=True) + self._compareRankSparse(x, use_gpu=True) + self._compareSizeSparse(x, use_gpu=True) - def _testAll(self, x): - self._testCpu(x) - self._testGpu(x) + def _testAll(self, x, compare_sparse=True): + self._testCpu(x, compare_sparse) + self._testGpu(x, compare_sparse) def testBasic(self): self._testAll(np.random.randn(2)) @@ -153,6 +155,29 @@ def testBool(self): self._testAll(np.random.choice((False, True), size=(2, 3, 5, 7, 11))) self._testAll(np.random.choice((False, True), size=(2, 3, 5, 7, 11, 13))) + def testString(self): + self._testAll( + np.random.choice(["abcd", "efgh"], size=(2,)), compare_sparse=False + ) + self._testAll( + np.random.choice(["abcd", "efgh"], size=(2, 3)), compare_sparse=False + ) + self._testAll( + np.random.choice(["abcd", "efgh"], size=(2, 3, 5)), compare_sparse=False + ) + self._testAll( + np.random.choice(["abcd", "efgh"], size=(2, 3, 5, 7)), + compare_sparse=False, + ) + self._testAll( + np.random.choice(["abcd", "efgh"], size=(2, 3, 5, 7, 11)), + compare_sparse=False, + ) + self._testAll( + np.random.choice(["abcd", "efgh"], size=(2, 3, 5, 7, 11, 13)), + compare_sparse=False, + ) + # Disabled because it takes too long to run, but manually verified # as passing at time of writing. def _test64BitOutput(self): diff --git a/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py b/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py index 88d51257b517be..67eb28739df0b9 100644 --- a/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py +++ b/tensorflow/python/kernel_tests/linalg/linalg_ops_test.py @@ -336,7 +336,7 @@ def expected_pinv(self, a, rcond): a_pinv = np.zeros(s, dtype=a.dtype) for i in np.ndindex(a.shape[:(a.ndim - 2)]): a_pinv[i] = np.linalg.pinv( - a[i], rcond=rcond if isinstance(rcond, float) else rcond[i]) + a[i], rcond=rcond if isinstance(rcond.tolist(), float) else rcond[i]) return a_pinv def test_symmetric(self): diff --git a/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py index 7785b1bff2cfa4..226e0bd03daff8 100644 --- a/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/pooling_ops_test.py @@ -1255,7 +1255,7 @@ def _ConstructAndTestGradient(self, err_tolerance = 1e-4 else: if x_init_value is None: - x_init_value = np.asfarray( + x_init_value = np.asarray( np.arange(1, total_size + 1), dtype=np.float32).reshape(input_sizes) func_name = "max_pool" @@ -1333,7 +1333,7 @@ def _ConstructAndTestSecondGradient(self, err_tolerance = 1e-3 else: if x_init_value is None: - x_init_value = np.asfarray( + x_init_value = np.asarray( np.arange(1, total_size + 1), dtype=np.float32).reshape(input_sizes) func_name = "max_pool" diff --git a/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py index 4273c209d42213..5e01d981a90062 100644 --- a/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/variables/resource_variable_ops_test.py @@ -348,8 +348,8 @@ def testRepr(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable(1) text = "%r" % v - self.assertEqual( - "", text) + error_msg = "" + self.assertEqual(error_msg, text) def testReprUnavailable(self): with context.eager_mode(): diff --git a/tensorflow/python/lib/core/ndarray_tensor.cc b/tensorflow/python/lib/core/ndarray_tensor.cc index 3b350c1633678c..a54719a4e8f91c 100644 --- a/tensorflow/python/lib/core/ndarray_tensor.cc +++ b/tensorflow/python/lib/core/ndarray_tensor.cc @@ -67,7 +67,6 @@ char const* numpy_type_name(int numpy_type) { TYPE_CASE(NPY_DATETIME); TYPE_CASE(NPY_TIMEDELTA); TYPE_CASE(NPY_HALF); - TYPE_CASE(NPY_NTYPES); TYPE_CASE(NPY_NOTYPE); TYPE_CASE(NPY_CHAR); TYPE_CASE(NPY_USERDEF); @@ -76,6 +75,10 @@ char const* numpy_type_name(int numpy_type) { } } +#if NPY_ABI_VERSION < 0x02000000 +#define PyDataType_FIELDS(descr) ((descr)->fields) +#endif // NPY_ABI_VERSION < 0x02000000 + Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr, TF_DataType* out_tf_datatype) { PyObject* key; @@ -84,11 +87,11 @@ Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr, // Return an error if the fields attribute is null. // Occurs with an improper conversion attempt to resource. - if (descr->fields == nullptr) { + if (PyDataType_FIELDS(descr) == nullptr) { return errors::Internal("Unexpected numpy data type"); } - if (PyDict_Next(descr->fields, &pos, &key, &value)) { + if (PyDict_Next(PyDataType_FIELDS(descr), &pos, &key, &value)) { // In Python 3, the keys of numpy custom struct types are unicode, unlike // Python 2, where the keys are bytes. const char* key_string = diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc index 4cd43ae0c37d28..1c81b35e48cc5e 100644 --- a/tensorflow/python/lib/core/py_seq_tensor.cc +++ b/tensorflow/python/lib/core/py_seq_tensor.cc @@ -716,16 +716,22 @@ TFE_TensorHandle* PySeqToTFE_TensorHandle(TFE_Context* ctx, PyObject* obj, // These objects are efficiently handled by Numpy. We transform them into // Numpy arrays and handle them in the Numpy case below. Note that Tensors // implement the __array__ function, and will be handled in this shortcut. - Safe_PyObjectPtr array = - make_safe(PyArray_FromArrayAttr(obj, nullptr, nullptr)); - if (array == nullptr) { - return nullptr; + // We used to call PyArray_FromArrayAttr here, but NumPy 2.0 changed its + // semantics such that it errors if a copy of the array is required. + // (Ideally no copy would be needed here, but that would be a larger change.) + Safe_PyObjectPtr array; + if (PyObject_HasAttrString(obj, "__array__")) { + array = make_safe(PyObject_CallMethod(obj, "__array__", nullptr)); + if (array == nullptr) { + return nullptr; + } + if (!PyArray_Check(array.get())) { + PyErr_SetString(PyExc_ValueError, + "Value returned by __array__ is not a NumPy array"); + return nullptr; + } } - if (array.get() == Py_NotImplemented) { - // The Py_NotImplemented returned from PyArray_FromArrayAttr is not - // Py_INCREF'ed, so we don't want the Safe_PyObjectPtr to Py_DECREF it. - array.release(); - + if (!array) { // Try __array_interface__ objects (such as PIL Image). array = make_safe(PyArray_FromInterface(obj)); if (array == nullptr) { diff --git a/tensorflow/python/lib/io/BUILD b/tensorflow/python/lib/io/BUILD index d5e583d642962c..b30b1c521e6dfc 100644 --- a/tensorflow/python/lib/io/BUILD +++ b/tensorflow/python/lib/io/BUILD @@ -62,7 +62,6 @@ py_strict_library( "//tensorflow:__subpackages__", "//tensorflow:internal", "//third_party/cloud_tpu/convergence_tools:__subpackages__", - "//third_party/proto_splitter:__subpackages__", # TODO(b/277279227): remove this dep from proto_splitter "//third_party/py/tf_slim:__subpackages__", ], deps = [ diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 0cb65da0a185a4..29f16a77bbca1d 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -396,7 +396,7 @@ def expand_dims_v2(input, axis, name=None): Given a tensor `input`, this operation inserts a dimension of length 1 at the dimension index `axis` of `input`'s shape. The dimension index follows Python - indexing rules: It's zero-based, a negative index it is counted backward + indexing rules: It's zero-based, and a negative index is counted backward from the end. This operation is useful to: @@ -1276,7 +1276,39 @@ def _maybe_cast(elem): return _maybe_cast -_NON_AUTOPACKABLE_TYPES = set(np.core.numerictypes.ScalarType) +_NON_AUTOPACKABLE_TYPES = set(( + int, + float, + complex, + bool, + bytes, + str, + memoryview, + np.bool_, + np.complex64, + np.clongdouble, + np.complex128, + np.float16, + np.float32, + np.float64, + np.longdouble, + np.int8, + np.int16, + np.int32, + np.int64, + np.longlong, + np.timedelta64, + np.datetime64, + np.object_, + np.bytes_, + np.str_, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + np.ulonglong, + np.void, +)) _NON_AUTOPACKABLE_TYPES.add(np.ndarray) diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py index 83e464d53d1a8d..f1b679b3de40d1 100644 --- a/tensorflow/python/ops/bitwise_ops_test.py +++ b/tensorflow/python/ops/bitwise_ops_test.py @@ -60,7 +60,7 @@ def count_bits(x): for dtype in dtype_list: with self.cached_session(): print("PopulationCount test: ", dtype) - inputs = np.array(raw_inputs, dtype=dtype.as_numpy_dtype) + inputs = np.array(raw_inputs).astype(dtype.as_numpy_dtype) truth = [count_bits(x) for x in inputs] input_tensor = constant_op.constant(inputs, dtype=dtype) popcnt_result = self.evaluate( diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 80463a67efc9ac..4b6b11853d4c4c 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1416,8 +1416,15 @@ def assert_rank_in( except ValueError as e: if e.args[0] == 'Static rank condition failed': raise ValueError( - '%sTensor %s must have rank in %s. Received rank %d, ' - 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) + '%sTensor %s must have rank in %s. Received rank %d, shape %s' + % ( + message, + name, + tuple(r.item() for r in e.args[2]), + e.args[1], + x.get_shape(), + ) + ) else: raise diff --git a/tensorflow/python/ops/gradient_checker_v2_test.py b/tensorflow/python/ops/gradient_checker_v2_test.py index 19835aeb09e4cb..362feab73b70cb 100644 --- a/tensorflow/python/ops/gradient_checker_v2_test.py +++ b/tensorflow/python/ops/gradient_checker_v2_test.py @@ -255,7 +255,8 @@ def f(x): *gradient_checker.compute_gradient(f, [x])) # Typical test would assert error < max_err, so assert this test would # raise AssertionError, since NaN is not < 1.0. - with self.assertRaisesRegex(AssertionError, "nan not less than 1.0"): + error_msg = r"(nan|np.float32\(nan\)) not less than 1.0" + with self.assertRaisesRegex(AssertionError, error_msg): self.assertLess(error, 1.0) def testGradGrad(self): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 510865596fe4b5..d55366762b8a11 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2010,7 +2010,12 @@ def range(start, limit=None, delta=1, dtype=None, name="range"): # pylint: disa # infer dtype if not explicitly provided if dtype is None: dtype_hierarchy = [ - dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 + dtypes.int32, + dtypes.int64, + dtypes.float16, + dtypes.bfloat16, + dtypes.float32, + dtypes.float64, ] assert all(arg.dtype in dtype_hierarchy for arg in [start, limit, delta]) inferred_dtype = max([arg.dtype for arg in [start, limit, delta]], diff --git a/tensorflow/python/ops/ragged/convert_to_tensor_or_ragged_tensor_op_test.py b/tensorflow/python/ops/ragged/convert_to_tensor_or_ragged_tensor_op_test.py index 92a58e8190b275..4957ed02cce806 100644 --- a/tensorflow/python/ops/ragged/convert_to_tensor_or_ragged_tensor_op_test.py +++ b/tensorflow/python/ops/ragged/convert_to_tensor_or_ragged_tensor_op_test.py @@ -130,7 +130,8 @@ def testConvertRaggedTensorValue(self, value=ragged_factory_ops.constant_value([['a', 'b'], ['c']], dtype=str), dtype=dtypes.int32, - message=r"invalid literal for int\(\) with base 10: 'a'"), + message=(r"invalid literal for int\(\) with base 10: " + r"('a'|np.str_\('a'\))")), ]) def testConvertRaggedTensorValueError(self, value, @@ -216,7 +217,8 @@ def testConvertNumpyArray(self, dict( value=np.array([['a', 'b'], ['c', 'd']], dtype=str), dtype=dtypes.int32, - message=r"invalid literal for int\(\) with base 10: 'a'"), + message=(r"invalid literal for int\(\) with base 10: " + r"('a'|np.str_\('a'\))")), ]) def testConvertNumpyArrayError(self, value, diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py index 9e096e01b56d7a..215304c867507c 100644 --- a/tensorflow/python/ops/ragged/ragged_factory_ops.py +++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py @@ -150,14 +150,19 @@ def _ragged_factory(values, row_splits): return ragged_tensor_value.RaggedTensorValue(values, row_splits) def _inner_factory(pylist, dtype, shape, name=None): # pylint: disable=unused-argument - return np.reshape(np.array(pylist, dtype=dtype), shape) + if dtype is object or dtype is None: + return np.reshape(np.array(pylist, dtype=dtype), shape) + else: + return np.reshape(np.array(pylist).astype(dtype), shape) - return _constant_value(_ragged_factory, _inner_factory, pylist, dtype, - ragged_rank, inner_shape) + return _constant_value( + _ragged_factory, _inner_factory, pylist, dtype, ragged_rank, inner_shape + ) -def _constant_value(ragged_factory, inner_factory, pylist, dtype, ragged_rank, - inner_shape): +def _constant_value( + ragged_factory, inner_factory, pylist, dtype, ragged_rank, inner_shape +): """Constructs a constant RaggedTensor or RaggedTensorValue. Args: diff --git a/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py b/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py index d6b7d12999ba52..03b864b01d86dd 100644 --- a/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py +++ b/tensorflow/python/ops/v1_compat_tests/gradient_checker_test.py @@ -192,7 +192,8 @@ def testNaNGradFails(self): error = gradient_checker.compute_gradient_error(x, (), y, ()) # Typical test would assert error < max_err, so assert this test would # raise AssertionError, since NaN is not < 1.0. - with self.assertRaisesRegex(AssertionError, "False is not true"): + error_msg = "(False|np.False_) is not true" + with self.assertRaisesRegex(AssertionError, error_msg): self.assertTrue(error < 1.0) diff --git a/tensorflow/python/platform/BUILD b/tensorflow/python/platform/BUILD index 0ca7e7bfae738f..7c6c086871fcae 100644 --- a/tensorflow/python/platform/BUILD +++ b/tensorflow/python/platform/BUILD @@ -225,7 +225,6 @@ py_strict_library( "//tensorflow_models:__subpackages__", "//third_party/cloud_tpu/convergence_tools:__subpackages__", "//third_party/mlperf:__subpackages__", - "//third_party/proto_splitter:__subpackages__", # TODO(b/277279227): remove this dep from proto_splitter "//third_party/py/tf_slim:__subpackages__", ], deps = [ @@ -301,7 +300,7 @@ py_strict_library( py_strict_library( name = "gfile", srcs = ["gfile.py"], - visibility = visibility + ["//third_party/py/tf_slim/training:__pkg__"], + visibility = visibility, deps = [ "//tensorflow/python/lib/io:file_io", "//tensorflow/python/util:deprecation", diff --git a/tensorflow/python/util/BUILD b/tensorflow/python/util/BUILD index cfc2790371eed4..ee56bb821a2f30 100644 --- a/tensorflow/python/util/BUILD +++ b/tensorflow/python/util/BUILD @@ -889,6 +889,20 @@ py_strict_library( ], ) +py_strict_library( + name = "numpy_compat", + srcs = ["numpy_compat.py"], + compatible_with = get_compatible_with_portable(), + visibility = util_subpackage_visibility, + deps = [ + # global_test_configuration is added here because all major tests depend on this + # library. It isn't possible to add these test dependencies via tensorflow.bzl's + # py test because not all tensorflow tests use tensorflow.bzl's py test. + "//tensorflow/python:global_test_configuration", + "//third_party/py/numpy", + ], +) + py_strict_library( name = "object_identity", srcs = ["object_identity.py"], diff --git a/tensorflow/python/util/numpy_compat.py b/tensorflow/python/util/numpy_compat.py new file mode 100644 index 00000000000000..87a705066a8273 --- /dev/null +++ b/tensorflow/python/util/numpy_compat.py @@ -0,0 +1,66 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Functions for NumPy 1.x vs. 2.x compatibility.""" + +import numpy as np + + +def np_array(values, dtype): + """Creates a NumPy array containing input values. + + It will make a copy of the object. + + In NumPy 2.x and later, strict type casting can lead to errors when values + overflow the specified dtype. This function addresses this by replacing direct + np.array(..., dtype=...) calls with np.array(...).astype(...). This allows for + intended overflows, aligning with the behavior of older NumPy versions. + + Args: + values: Array_like objects. E.g., a python list, tuple, or an object + whose __array__ method returns an array. + dtype: The desired numpy data type for the array. + + Returns: + A NumPy array with the specified data type. + """ + if dtype is not None and np.issubdtype(dtype, np.number): + return np.array(values).astype(dtype) + else: + return np.array(values, dtype=dtype) + + +def np_asarray(values, dtype): + """Converts input values to a NumPy array. + + It will not make a copy. + + In NumPy 2.x and later, strict type casting can lead to errors when values + overflow the specified dtype. This function addresses this by replacing direct + np.array(..., dtype=...) calls with np.array(...).astype(...). This allows for + intended overflows, aligning with the behavior of older NumPy versions. + + Args: + values: Array_like objects. E.g., a python list, tuple, or an object + whose __array__ method returns an array. + dtype: The desired numpy data type for the array. + + Returns: + A NumPy array with the specified data type. + """ + if dtype is not None and np.issubdtype(dtype, np.number): + return np.asarray(values).astype(dtype) + else: + return np.asarray(values, dtype=dtype) diff --git a/tensorflow/python/util/tf_stack.cc b/tensorflow/python/util/tf_stack.cc index 851bbe57079485..5cfaf5145155b3 100644 --- a/tensorflow/python/util/tf_stack.cc +++ b/tensorflow/python/util/tf_stack.cc @@ -117,6 +117,21 @@ class StackTraceWrapper : public AbstractStackTrace { return cache_->ToFrames(); } + std::vector ToUncachedFrames() const override { + std::vector frames = captured_->ToStackFrames( + *source_map_, [&](const char* f) { return StackTraceFiltering(f); }, + /*reverse_traversal=*/false, /*limit=*/-1); + + // Drop last stack frames. + int newsize = frames.size() - stacklevel_; + if (newsize < 0) { + newsize = 0; + } + frames.resize(newsize); + + return frames; + } + std::vector GetUserFrames(int limit) const override { ComputeFrozen(); return cache_->GetUserFrames(limit); @@ -139,16 +154,7 @@ class StackTraceWrapper : public AbstractStackTrace { return; } - std::vector frames = captured_->ToStackFrames( - *source_map_, [&](const char* f) { return StackTraceFiltering(f); }, - /*reverse_traversal=*/false, /*limit=*/-1); - - // Drop last stack frames. - int newsize = frames.size() - stacklevel_; - if (newsize < 0) { - newsize = 0; - } - frames.resize(newsize); + std::vector frames = ToUncachedFrames(); std::vector user_frames = captured_->ToStackFrames( *source_map_, diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index d1e35f02a0f190..c1c50767cdad69 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -70,6 +70,7 @@ load( "tsl_gpu_library", _cc_header_only_library = "cc_header_only_library", _if_cuda_or_rocm = "if_cuda_or_rocm", + _if_hermetic_cuda_tools = "if_hermetic_cuda_tools", _if_nccl = "if_nccl", _transitive_hdrs = "transitive_hdrs", ) @@ -800,7 +801,7 @@ def tf_cc_shared_object( testonly = kwargs.pop("testonly", False) for name_os, name_os_major, name_os_full in names: - # Windows DLLs cant be versioned + # Windows DLLs can't be versioned if name_os.endswith(".dll"): name_os_major = name_os name_os_full = name_os @@ -1075,7 +1076,8 @@ def tf_cc_binary( ], ), tags = tags, - data = depset(data + added_data_deps), + data = depset(data + added_data_deps).to_list() + + tf_binary_additional_srcs(fullversion = True), linkopts = linkopts + _rpath_linkopts(name_os), visibility = visibility, **kwargs @@ -1568,7 +1570,7 @@ def tf_cc_test( ), data = data + tf_binary_dynamic_kernel_dsos() + - tf_binary_additional_srcs(), + tf_binary_additional_srcs(fullversion = True), exec_properties = tf_exec_properties(kwargs), **kwargs ) @@ -1733,6 +1735,7 @@ def tf_gpu_only_cc_test( tf_gpu_kernel_library( name = gpu_lib_name, srcs = srcs + tf_binary_additional_srcs(), + data = tf_binary_additional_srcs(fullversion = True), deps = deps, testonly = 1, features = features, @@ -3574,3 +3577,6 @@ def replace_with_portable_tf_lib_when_required(non_portable_tf_deps, use_lib_wit def tf_python_framework_friends(): return ["//tensorflow:__subpackages__"] + +def if_hermetic_cuda_tools(if_true, if_false = []): + return _if_hermetic_cuda_tools(if_true, if_false) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt index 17a2776dbd3f05..b43fcc3ed54bdf 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt @@ -129,11 +129,11 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "map_with_legacy_function" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt index 88a2e4f5bd80f9..e4b1c197f66654 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -131,11 +131,11 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "map_with_legacy_function" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt index b71ecb13e4622a..999d0c06283825 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt @@ -131,11 +131,11 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "map_with_legacy_function" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt index e4f1c852a4196a..123c3b7035b83a 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt @@ -131,11 +131,11 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "map_with_legacy_function" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt index 266a8f1b533347..2b8fdc367b98e1 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -131,11 +131,11 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "map_with_legacy_function" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt index aff5ee5f057b45..77b396a622ceec 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt @@ -131,11 +131,11 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "map_with_legacy_function" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt index 6986dab3471f78..48c85404aebb4f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -131,11 +131,11 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "map_with_legacy_function" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'use_unbounded_threadpool\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 81efb179bd9a62..967dc41d14a5b3 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -3030,7 +3030,7 @@ tf_module { } member_method { name: "ParallelMapDatasetV2" - argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'\', \'None\'], " + argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'use_unbounded_threadpool\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'False\', \'\', \'None\'], " } member_method { name: "ParameterizedTruncatedNormal" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt index 1852f769ed71f5..00f5d54a037b33 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt @@ -100,7 +100,7 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt index 5f1f368791d423..9230f451321399 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt @@ -102,7 +102,7 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt index 85eb6963b7a9b0..e89d811f33777b 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt @@ -101,7 +101,7 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt index 42a293db5da768..c936ff4467c911 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt @@ -102,7 +102,7 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt index d376170fef7615..f31da9c3f73785 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt @@ -102,7 +102,7 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt index 190a21f261d5a5..8faf54322ea3de 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt @@ -103,7 +103,7 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt index e19f932c068ab6..490b7f7d640ff3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt @@ -102,7 +102,7 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt index d6d643970fce20..c2194abf7ff3ba 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.dtensor.-d-tensor-dataset.pbtxt @@ -103,7 +103,7 @@ tf_class { } member_method { name: "map" - argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\', \'deterministic\', \'synchronous\', \'use_unbounded_threadpool\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " } member_method { name: "options" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 81efb179bd9a62..967dc41d14a5b3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -3030,7 +3030,7 @@ tf_module { } member_method { name: "ParallelMapDatasetV2" - argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'\', \'None\'], " + argspec: "args=[\'input_dataset\', \'other_arguments\', \'num_parallel_calls\', \'f\', \'output_types\', \'output_shapes\', \'use_inter_op_parallelism\', \'deterministic\', \'preserve_cardinality\', \'use_unbounded_threadpool\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'default\', \'False\', \'False\', \'\', \'None\'], " } member_method { name: "ParameterizedTruncatedNormal" diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc index c6e42840c6a689..1b1b443eb3c14b 100644 --- a/tensorflow/tools/benchmark/benchmark_model_test.cc +++ b/tensorflow/tools/benchmark/benchmark_model_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/math_ops.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -29,7 +30,6 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/stat_summarizer.h" -#include "tsl/lib/core/status_test_util.h" namespace tensorflow { namespace { diff --git a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh index 72a228f92051d9..992aa6d94c89b0 100644 --- a/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh +++ b/tensorflow/tools/ci_build/build_scripts/ARM_SKIP_TESTS.sh @@ -16,5 +16,4 @@ set -x ARM_SKIP_TESTS="-//tensorflow/lite/... \ --//tensorflow/core/kernels/image:resize_bicubic_op_test \ " diff --git a/tensorflow/tools/graph_transforms/backports_test.cc b/tensorflow/tools/graph_transforms/backports_test.cc index 80a954e062b069..155ec29e93687c 100644 --- a/tensorflow/tools/graph_transforms/backports_test.cc +++ b/tensorflow/tools/graph_transforms/backports_test.cc @@ -192,7 +192,7 @@ TEST(BackportTensorArrayV3Test, TestBackportTensorArrayV3Subtypes) { std::map node_lookup; MapNamesToNodes(result, &node_lookup); ASSERT_EQ(1, node_lookup.count("v3_node")); - EXPECT_TRUE(str_util::EndsWith(node_lookup.at("v3_node")->op(), "V2")); + EXPECT_TRUE(absl::EndsWith(node_lookup.at("v3_node")->op(), "V2")); } } diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index dcdc3c29069c21..3d388cd665499f 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -210,10 +210,10 @@ class ConstantFoldingTest : public ::testing::Test { for (const NodeDef& node : graph_def.node()) { const StringPiece name(node.name()); const int occurrence_count = folded_node_map.count(node.name()); - if (str_util::EndsWith(name, "expect_removed")) { + if (absl::EndsWith(name, "expect_removed")) { EXPECT_EQ(0, occurrence_count) << "node.name()=" << node.name(); } - if (str_util::EndsWith(name, "expect_remains")) { + if (absl::EndsWith(name, "expect_remains")) { EXPECT_EQ(1, occurrence_count) << "node.name()=" << node.name(); } } diff --git a/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt b/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt index 9ac7ee9b800fd7..d78fcfe8a5a965 100644 --- a/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt +++ b/tensorflow/tools/pip_package/THIRD_PARTY_NOTICES.txt @@ -9376,4 +9376,28 @@ OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --------------------------------------------------------------------------------- \ No newline at end of file +-------------------------------------------------------------------------------- + +-------------------------------------------------------------------------------- +== libuv + +Copyright (c) 2015-present libuv project contributors. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to +deal in the Software without restriction, including without limitation the +rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +sell copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +IN THE SOFTWARE. +-------------------------------------------------------------------------------- diff --git a/tensorflow/tools/pip_package/build_pip_package.py b/tensorflow/tools/pip_package/build_pip_package.py index 9588fc19e3d4e9..1846082b8147b8 100644 --- a/tensorflow/tools/pip_package/build_pip_package.py +++ b/tensorflow/tools/pip_package/build_pip_package.py @@ -69,6 +69,36 @@ def prepare_headers(headers: list[str], srcs_dir: str) -> None: srcs_dir: target directory where headers are copied to. """ path_to_exclude = [ + "cuda_cccl/_virtual_includes", + "cuda_cublas/_virtual_includes", + "cuda_cudart/_virtual_includes", + "cuda_cudnn/_virtual_includes", + "cuda_cufft/_virtual_includes", + "cuda_cupti/_virtual_includes", + "cuda_curand/_virtual_includes", + "cuda_cusolver/_virtual_includes", + "cuda_cusparse/_virtual_includes", + "cuda_nccl/_virtual_includes", + "cuda_nvcc/_virtual_includes", + "cuda_nvjitlink/_virtual_includes", + "cuda_nvml/_virtual_includes", + "cuda_nvrtc/_virtual_includes", + "cuda_nvtx/_virtual_includes", + "external/cuda_cccl", + "external/cuda_cublas", + "external/cuda_cudart", + "external/cuda_cudnn", + "external/cuda_cufft", + "external/cuda_cupti", + "external/cuda_curand", + "external/cuda_cusolver", + "external/cuda_cusparse", + "external/cuda_nccl", + "external/cuda_nvcc", + "external/cuda_nvjitlink", + "external/cuda_nvml", + "external/cuda_nvrtc", + "external/cuda_nvtx", "external/pypi", "external/jsoncpp_git/src", "local_config_cuda/cuda/_virtual_includes", diff --git a/tensorflow/tools/proto_splitter/cc/BUILD b/tensorflow/tools/proto_splitter/cc/BUILD index 9672bf1c590975..105cecfae4465c 100644 --- a/tensorflow/tools/proto_splitter/cc/BUILD +++ b/tensorflow/tools/proto_splitter/cc/BUILD @@ -97,9 +97,9 @@ tf_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_matchers", + "@local_xla//xla/tsl/lib/core:status_test_util", "@riegeli//riegeli/base:initializer", "@riegeli//riegeli/bytes:cord_reader", "@riegeli//riegeli/bytes:fd_reader", @@ -163,11 +163,11 @@ tf_cc_test( "//tensorflow/tools/proto_splitter/testdata:test_message_proto_cc", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/tsl/lib/core:status_test_util", ] + if_oss([ "//tensorflow/tools/proto_splitter:protos_impl", ]), @@ -179,11 +179,11 @@ cc_library( hdrs = ["graph_def_splitter.h"], deps = [ ":composable_splitter", + ":composable_splitter_base", ":large_node_splitter", ":max_size", ":repeated_field_splitter", ":size_splitter", - ":split", ":util", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", @@ -309,18 +309,15 @@ tf_cc_test( cc_library( name = "large_node_splitter", - srcs = ["large_node_splitter.cc"], hdrs = ["large_node_splitter.h"], deps = [ ":composable_splitter", + ":composable_splitter_base", ":max_size", ":size_splitter", ":util", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc index 55aac6a77821ff..d62acc5eead679 100644 --- a/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/composable_splitter_test.cc @@ -30,6 +30,7 @@ limitations under the License. #include "riegeli/bytes/fd_reader.h" // from @riegeli #include "riegeli/bytes/string_reader.h" // from @riegeli #include "riegeli/records/record_reader.h" // from @riegeli +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/file_system_helper.h" @@ -39,7 +40,6 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter.cc index 81e8d5d9a3aec4..7f274734a6b76e 100644 --- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter.cc +++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter.cc @@ -31,18 +31,18 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/tools/proto_splitter/cc/composable_splitter.h" +#include "tensorflow/tools/proto_splitter/cc/composable_splitter_base.h" #include "tensorflow/tools/proto_splitter/cc/large_node_splitter.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" #include "tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h" #include "tensorflow/tools/proto_splitter/cc/size_splitter.h" -#include "tensorflow/tools/proto_splitter/cc/split.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" -namespace tensorflow { -namespace tools::proto_splitter { +namespace tensorflow::tools::proto_splitter { namespace { @@ -144,7 +144,7 @@ class FunctionDefSplitter : public SizeSplitter { LargeNodeSplitterFactory large_node_splitter_factory; std::vector factories = { &constant_splitter_factory, &large_node_splitter_factory}; - auto ret = RepeatedFieldSplitters::Create( + auto ret = RepeatedFieldSplitter::Create( message(), this, &fields, "node_def"s, &factories); if (!ret.ok()) return ret.status(); auto splitter = ret.value(); @@ -184,7 +184,7 @@ absl::Status GraphDefSplitter::BuildChunks() { LargeNodeSplitterFactory large_node_splitter_factory; std::vector factories = {&constant_splitter_factory, &large_node_splitter_factory}; - auto node_splitter_ret = RepeatedFieldSplitters::Create( + auto node_splitter_ret = RepeatedFieldSplitter::Create( g, this, &field_in_parent, "node"s, &factories); if (!node_splitter_ret.ok()) return node_splitter_ret.status(); auto node_splitter = node_splitter_ret.value(); @@ -193,7 +193,7 @@ absl::Status GraphDefSplitter::BuildChunks() { std::vector library_field = {"library"s}; std::vector fn_factories = {&function_splitter_factory}; auto library_splitter_ret = - RepeatedFieldSplitters::Create( + RepeatedFieldSplitter::Create( g->mutable_library(), this, &library_field, "function"s, &fn_factories); if (!library_splitter_ret.ok()) return library_splitter_ret.status(); @@ -238,5 +238,4 @@ absl::Status GraphDefSplitter::BuildChunks() { return absl::OkStatus(); } -} // namespace tools::proto_splitter -} // namespace tensorflow +} // namespace tensorflow::tools::proto_splitter diff --git a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc index 1d98a3a390bd33..b5c27118cf0cbc 100644 --- a/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/graph_def_splitter_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include "absl/strings/cord.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -34,7 +35,6 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/cc/test_util.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" @@ -44,23 +44,6 @@ namespace { using ::tensorflow::proto_splitter::ChunkedMessage; -// Ensures that all Messages are less than the max size. std::string chunks are -// not limited by the max size, so they are ignored in this check. -#define EXPECT_CHUNK_SIZES(chunks, max_size) \ - do { \ - for (auto chunk : *chunks) { \ - if (std::holds_alternative>( \ - chunk)) { \ - EXPECT_LE(std::get>(chunk) \ - ->ByteSizeLong(), \ - max_size); \ - } else if (std::holds_alternative(chunk)) { \ - EXPECT_LE(std::get(chunk)->ByteSizeLong(), \ - max_size); \ - } \ - } \ - } while (0) - TEST(GraphDefSplitterTest, TestLargeConstant) { GraphDef proto; const std::string graph_def_path = diff --git a/tensorflow/tools/proto_splitter/cc/large_node_splitter.cc b/tensorflow/tools/proto_splitter/cc/large_node_splitter.cc deleted file mode 100644 index cf0ff26f51f985..00000000000000 --- a/tensorflow/tools/proto_splitter/cc/large_node_splitter.cc +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/tools/proto_splitter/cc/large_node_splitter.h" - -#include -#include - -#include "absl/memory/memory.h" -#include "tensorflow/core/framework/function.pb.h" -#include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" -#include "tensorflow/tools/proto_splitter/cc/max_size.h" -#include "tensorflow/tools/proto_splitter/cc/size_splitter.h" -#include "tensorflow/tools/proto_splitter/cc/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/protobuf.h" - -namespace tensorflow { -namespace tools::proto_splitter { - -template -absl::StatusOr LargeNodeSplitter::BuildChunksReturnSize() { - MessageType* msg = - tsl::protobuf::DynamicCastToGenerated(message()); - int initial_size = GetInitialSize(); - std::shared_ptr new_msg = std::make_shared(); - msg->Swap(new_msg.get()); - std::vector fields = {}; - auto x = std::make_unique(new_msg); - TF_RETURN_IF_ERROR(AddChunk(std::move(x), &fields, index_)); - return initial_size; -} - -template -absl::StatusOr> -LargeNodeSplitterFactory::CreateSplitter( - tsl::protobuf::Message* message, ComposableSplitterBase* parent_splitter, - std::vector* fields_in_parent, int size) { - if (!(LARGE_SIZE_CHECK(size, GetMaxSize()))) return nullptr; - LargeNodeSplitter* splitter = new LargeNodeSplitter( - message, parent_splitter, fields_in_parent); - return absl::WrapUnique(splitter); -} - -template class LargeNodeSplitter; -template class LargeNodeSplitter; -template class LargeNodeSplitter; -template class LargeNodeSplitter; -template class LargeNodeSplitterFactory; -template class LargeNodeSplitterFactory; -template class LargeNodeSplitterFactory; -template class LargeNodeSplitterFactory; - -} // namespace tools::proto_splitter -} // namespace tensorflow diff --git a/tensorflow/tools/proto_splitter/cc/large_node_splitter.h b/tensorflow/tools/proto_splitter/cc/large_node_splitter.h index e5969cf652dd37..15c9964fa44644 100644 --- a/tensorflow/tools/proto_splitter/cc/large_node_splitter.h +++ b/tensorflow/tools/proto_splitter/cc/large_node_splitter.h @@ -20,7 +20,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/tools/proto_splitter/cc/composable_splitter.h" +#include "tensorflow/tools/proto_splitter/cc/composable_splitter_base.h" +#include "tensorflow/tools/proto_splitter/cc/max_size.h" #include "tensorflow/tools/proto_splitter/cc/size_splitter.h" +#include "tensorflow/tools/proto_splitter/cc/util.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace tools::proto_splitter { @@ -40,6 +44,19 @@ class LargeNodeSplitter : public SizeSplitter { int* index_ = nullptr; }; +template +absl::StatusOr LargeNodeSplitter::BuildChunksReturnSize() { + MessageType* msg = + tsl::protobuf::DynamicCastToGenerated(message()); + int initial_size = GetInitialSize(); + std::shared_ptr new_msg = std::make_shared(); + msg->Swap(new_msg.get()); + std::vector fields = {}; + auto x = std::make_unique(new_msg); + TF_RETURN_IF_ERROR(AddChunk(std::move(x), &fields, index_)); + return initial_size; +} + template class LargeNodeSplitterFactory : public SizeSplitterFactory { public: @@ -50,6 +67,17 @@ class LargeNodeSplitterFactory : public SizeSplitterFactory { std::vector* fields_in_parent, int size) override; }; +template +absl::StatusOr> +LargeNodeSplitterFactory::CreateSplitter( + tsl::protobuf::Message* message, ComposableSplitterBase* parent_splitter, + std::vector* fields_in_parent, int size) { + if (!(LARGE_SIZE_CHECK(size, GetMaxSize()))) return nullptr; + LargeNodeSplitter* splitter = new LargeNodeSplitter( + message, parent_splitter, fields_in_parent); + return absl::WrapUnique(splitter); +} + } // namespace tools::proto_splitter } // namespace tensorflow diff --git a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc index 01601c7e22a1fc..552009f3916e61 100644 --- a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc +++ b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h" +#include #include #include #include @@ -24,67 +25,63 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/tools/proto_splitter/cc/composable_splitter.h" #include "tensorflow/tools/proto_splitter/cc/max_size.h" -#include "tensorflow/tools/proto_splitter/cc/split.h" +#include "tensorflow/tools/proto_splitter/cc/size_splitter.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" -namespace tensorflow { -namespace tools::proto_splitter { +namespace tensorflow::tools::proto_splitter { // Additional bytes added to each node to account for the extra info needed to // encode the field key (realistically 3 but making it 5 for some wiggle room). constexpr int kExtraBytes = 5; template -absl::StatusOr> -RepeatedFieldSplitters::Create( +absl::StatusOr> +RepeatedFieldSplitter::Create( tsl::protobuf::Message* message, ComposableSplitter* parent_splitter, std::vector* fields_in_parent, const FieldType& repeated_field, std::vector* splitter_factories) { - // std::vector all_fields = *fields_in_parent; - // all_fields.push_back(repeated_field); - // std::vector - TF_ASSIGN_OR_RETURN(auto field_ret, GetField(*message, {repeated_field})); if (!field_ret.field->is_repeated()) { return absl::FailedPreconditionError("Unable to split non-repeated field."); } - auto ret = RepeatedFieldSplitters( + auto ret = RepeatedFieldSplitter( message, parent_splitter, fields_in_parent, repeated_field, splitter_factories); return ret; } template -absl::StatusOr RepeatedFieldSplitters< - ParentMessage, RepeatedMessage>::BuildChunksReturnSize() { - // std::vector all_fields = *fields_in_parent(); - // all_fields.push_back(repeated_field_); - - TF_ASSIGN_OR_RETURN(auto ret, GetMutableField(message(), {repeated_field_})); +absl::StatusOr +RepeatedFieldSplitter::BuildChunksReturnSize() { + TF_ASSIGN_OR_RETURN(MutableFieldResult mfr, + GetMutableField(message(), {repeated_field_})); + tsl::protobuf::Message* parent = mfr.parent; + const tsl::protobuf::FieldDescriptor* repeated_field = mfr.field; uint64_t max_size = GetMaxSize(); size_t initial_size = GetInitialSize(); // List of indices at which to split the repeated field. For example, [3, 5] // means that the field list is split into: [:3], [3:5], [5:] - std::vector repeated_msg_split = {0}; + std::vector repeated_msg_split; // Track the total byte size of the current node split. uint64_t total_size = 0; // Linearly iterate through all nodes. It may be possible to optimize this // further by making best guesses as to where to split the nodes, since // most nodes (aside from constants) are relatively small. - int repeated_field_size = - ret.parent->GetReflection()->FieldSize(*ret.parent, ret.field); - for (int i = 0; i < repeated_field_size; ++i) { + int repeated_field_length = + parent->GetReflection()->FieldSize(*parent, repeated_field); + for (int i = 0; i < repeated_field_length; ++i) { tsl::protobuf::Message* node = - ret.parent->GetReflection()->MutableRepeatedMessage(ret.parent, - ret.field, i); + parent->GetReflection()->MutableRepeatedMessage(parent, repeated_field, + i); auto node_size = node->ByteSizeLong(); std::vector new_fields = {repeated_field_, i}; @@ -106,25 +103,20 @@ absl::StatusOr RepeatedFieldSplitters< total_size += node_size + kExtraBytes; } - if (repeated_msg_split.size() > 1) { + if (!repeated_msg_split.empty()) { auto repeated_nodes_ptrs = - ret.parent->GetReflection() - ->template MutableRepeatedPtrField(ret.parent, - ret.field); - - int start = repeated_msg_split[0]; + parent->GetReflection() + ->template MutableRepeatedPtrField(parent, + repeated_field); - std::vector extracted_nodes; - extracted_nodes.resize(repeated_field_size - start); - repeated_nodes_ptrs->ExtractSubrange(start, repeated_field_size - start, + std::vector extracted_nodes(repeated_field_length); + repeated_nodes_ptrs->ExtractSubrange(0, repeated_field_length, &extracted_nodes.at(0)); - repeated_msg_split.push_back(repeated_field_size); - auto extracted_node = extracted_nodes.begin(); - - for (int i = 1; i < repeated_msg_split.size(); ++i) { - start = repeated_msg_split[i - 1]; - int end = repeated_msg_split[i]; + // Last range end is the size of the repeated field. + repeated_msg_split.push_back(repeated_field_length); + int range_start = 0; + for (int range_end : repeated_msg_split) { auto new_msg = std::make_shared(); std::vector empty_fields; auto x = std::make_unique(new_msg); @@ -134,10 +126,12 @@ absl::StatusOr RepeatedFieldSplitters< TF_ASSIGN_OR_RETURN(auto new_ret, GetMutableField(new_msg.get(), repeated_field_)); - for (int j = 0; j < end - start; ++j) { + for (int j = range_start; j < range_end; ++j) { new_msg->GetReflection()->AddAllocatedMessage( - new_msg.get(), new_ret.field, *extracted_node++); + new_msg.get(), new_ret.field, extracted_nodes[j]); } + + range_start = range_end; } } @@ -147,9 +141,8 @@ absl::StatusOr RepeatedFieldSplitters< } // Declare template classes to fix linking error. -template class RepeatedFieldSplitters; -template class RepeatedFieldSplitters; -template class RepeatedFieldSplitters; +template class RepeatedFieldSplitter; +template class RepeatedFieldSplitter; +template class RepeatedFieldSplitter; -} // namespace tools::proto_splitter -} // namespace tensorflow +} // namespace tensorflow::tools::proto_splitter diff --git a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h index eef7247a1925ef..5395f76ad9b69f 100644 --- a/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h +++ b/tensorflow/tools/proto_splitter/cc/repeated_field_splitter.h @@ -20,25 +20,24 @@ limitations under the License. #include "absl/status/statusor.h" #include "tensorflow/tools/proto_splitter/cc/composable_splitter.h" #include "tensorflow/tools/proto_splitter/cc/size_splitter.h" +#include "tensorflow/tools/proto_splitter/cc/util.h" #include "tsl/platform/protobuf.h" -namespace tensorflow { -namespace tools::proto_splitter { +namespace tensorflow::tools::proto_splitter { // Splitter that works on repeated message fields. template -class RepeatedFieldSplitters : public SizeSplitter { +class RepeatedFieldSplitter : public SizeSplitter { public: - static absl::StatusOr Create( + static absl::StatusOr Create( tsl::protobuf::Message* message, ComposableSplitter* parent_splitter, std::vector* fields_in_parent, const FieldType& repeated_field, std::vector* splitter_factories); absl::StatusOr BuildChunksReturnSize() override; - FieldType repeated_field_; private: - explicit RepeatedFieldSplitters( + explicit RepeatedFieldSplitter( tsl::protobuf::Message* message, ComposableSplitter* parent_splitter, std::vector* fields_in_parent, const FieldType& repeated_field, std::vector* splitter_factories) @@ -46,10 +45,10 @@ class RepeatedFieldSplitters : public SizeSplitter { repeated_field_(repeated_field), splitter_factories_(splitter_factories) {} + FieldType repeated_field_; std::vector* splitter_factories_; }; -} // namespace tools::proto_splitter -} // namespace tensorflow +} // namespace tensorflow::tools::proto_splitter #endif // TENSORFLOW_TOOLS_PROTO_SPLITTER_CC_REPEATED_FIELD_SPLITTER_H_ diff --git a/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc b/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc index b03bcc118f77c3..1712421dfa45bf 100644 --- a/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc +++ b/tensorflow/tools/proto_splitter/cc/saved_model_splitter_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/cc/max_size.h" #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/tools/proto_splitter/cc/test_util.h b/tensorflow/tools/proto_splitter/cc/test_util.h index 9187521fc14712..dd73cbd3bd1b00 100644 --- a/tensorflow/tools/proto_splitter/cc/test_util.h +++ b/tensorflow/tools/proto_splitter/cc/test_util.h @@ -28,6 +28,23 @@ limitations under the License. namespace tensorflow { namespace tools::proto_splitter { +// Ensures that all Messages are less than the max size. std::string chunks are +// not limited by the max size, so they are ignored in this check. +#define EXPECT_CHUNK_SIZES(chunks, max_size) \ + do { \ + for (auto chunk : *chunks) { \ + if (std::holds_alternative>( \ + chunk)) { \ + EXPECT_LE(std::get>(chunk) \ + ->ByteSizeLong(), \ + max_size); \ + } else if (std::holds_alternative(chunk)) { \ + EXPECT_LE(std::get(chunk)->ByteSizeLong(), \ + max_size); \ + } \ + } \ + } while (0) + inline std::string SerializeAsString(const tsl::protobuf::Message& message) { std::string result; { diff --git a/tensorflow/tools/proto_splitter/cc/util_test.cc b/tensorflow/tools/proto_splitter/cc/util_test.cc index e318f7c16f144d..7880519006312c 100644 --- a/tensorflow/tools/proto_splitter/cc/util_test.cc +++ b/tensorflow/tools/proto_splitter/cc/util_test.cc @@ -21,12 +21,12 @@ limitations under the License. #include #include #include "absl/status/status.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/tools/proto_splitter/cc/test_util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" diff --git a/tensorflow/tools/proto_splitter/merge_test.cc b/tensorflow/tools/proto_splitter/merge_test.cc index 5f78f3f2ca7468..06d40d5a55741e 100644 --- a/tensorflow/tools/proto_splitter/merge_test.cc +++ b/tensorflow/tools/proto_splitter/merge_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include #include "absl/strings/str_cat.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/test.h" @@ -31,7 +32,6 @@ limitations under the License. #include "tensorflow/tools/proto_splitter/cc/util.h" #include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tensorflow/tools/proto_splitter/testdata/test_message.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile index 7659bd622579a0..b6d1acab6ae6a6 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile +++ b/tensorflow/tools/tf_sig_build_dockerfiles/Dockerfile @@ -1,5 +1,5 @@ ################################################################################ -FROM ubuntu:22.04@sha256:19478ce7fc2ffbce89df29fea5725a8d12e57de52eb9ea570890dc5852aac1ac as builder +FROM ubuntu:22.04@sha256:340d9b015b194dc6e2a13938944e0d016e57b9679963fdeb9ce021daac430221 as builder ################################################################################ # Install devtoolset build dependencies diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/code_check_full.bats b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/code_check_full.bats index ecb4e8ec60a7dc..d9157f4e435578 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/code_check_full.bats +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.usertools/code_check_full.bats @@ -216,6 +216,8 @@ EOF --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ "somepath(//tensorflow/tools/pip_package:wheel, " \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ "@local_config_cuda//cuda:cudart + "\ "@local_config_cuda//cuda:cudart + "\ "@local_config_cuda//cuda:cuda_driver + "\ @@ -236,6 +238,8 @@ EOF bazel cquery \ --experimental_cc_shared_library \ --@local_config_cuda//:enable_cuda \ + --repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ + --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" \ --define framework_shared_object=false \ "somepath(//tensorflow/tools/pip_package:wheel, " \ "@local_config_cuda//cuda:cudart + "\ diff --git a/tensorflow/tools/tfg_graph_transforms/utils.cc b/tensorflow/tools/tfg_graph_transforms/utils.cc index 4e8c030a3c1262..a5b6f0af916518 100644 --- a/tensorflow/tools/tfg_graph_transforms/utils.cc +++ b/tensorflow/tools/tfg_graph_transforms/utils.cc @@ -30,9 +30,9 @@ namespace graph_transforms { namespace { -tsl::StringPiece GetNameWithoutExtension(tsl::StringPiece filename) { +absl::string_view GetNameWithoutExtension(absl::string_view filename) { auto pos = filename.rfind('.'); - if (pos == tsl::StringPiece::npos) return filename; + if (pos == absl::string_view::npos) return filename; return filename.substr(0, pos); } diff --git a/tensorflow/tools/toolchains/remote_config/configs.bzl b/tensorflow/tools/toolchains/remote_config/configs.bzl index f0fa44c759b346..abf72cbc605e91 100644 --- a/tensorflow/tools/toolchains/remote_config/configs.bzl +++ b/tensorflow/tools/toolchains/remote_config/configs.bzl @@ -225,8 +225,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -236,8 +236,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "9.1", + cuda_version = "12.3.2", + cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -248,8 +248,8 @@ def initialize_rbe_configs(): name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -258,8 +258,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -270,8 +270,8 @@ def initialize_rbe_configs(): name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -479,7 +479,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -558,7 +558,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -710,11 +710,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -749,11 +749,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -788,12 +788,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -826,12 +826,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -864,12 +864,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "9.1", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "9.1.1", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "10.0", }, ) diff --git a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl index ae776c2a2fd388..9c4c93c988901e 100644 --- a/tensorflow/tools/toolchains/remote_config/rbe_config.bzl +++ b/tensorflow/tools/toolchains/remote_config/rbe_config.bzl @@ -1,9 +1,9 @@ """Macro that creates external repositories for remote config.""" load("//tensorflow/tools/toolchains/remote_config:containers.bzl", "containers") -load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure") +load("//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") @@ -42,7 +42,7 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_CUDNN_VERSION": cudnn_version, "TF_CUDA_VERSION": cuda_version, "CUDNN_INSTALL_PATH": cudnn_install_path if cudnn_install_path != None else "/usr/lib/x86_64-linux-gnu", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": tensorrt_version if tensorrt_version != None else "", "TENSORRT_INSTALL_PATH": tensorrt_install_path if tensorrt_install_path != None else "/usr/lib/x86_64-linux-gnu", "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", @@ -51,20 +51,26 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_SYSROOT": sysroot if sysroot else "", }) - container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os) + cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) + cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) + container_name = "cuda%s-cudnn%s-%s" % ( + cuda_version_in_container, + cudnn_version_in_container, + os, + ) container_image = _container_image_uri(container_name) exec_properties = { "container-image": container_image, "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, @@ -175,13 +181,13 @@ def sigbuild_tf_configs(name_container_map, env): "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, diff --git a/tensorflow/workspace2.bzl b/tensorflow/workspace2.bzl index 16c1229536abb0..db96e3fc4383b6 100644 --- a/tensorflow/workspace2.bzl +++ b/tensorflow/workspace2.bzl @@ -29,7 +29,6 @@ load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") load("//third_party/FP16:workspace.bzl", FP16 = "repo") load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo") load("//third_party/git:git_configure.bzl", "git_configure") -load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/gpus:sycl_configure.bzl", "sycl_configure") load("//third_party/hexagon:workspace.bzl", hexagon_nn = "repo") @@ -42,7 +41,6 @@ load("//third_party/kissfft:workspace.bzl", kissfft = "repo") load("//third_party/libprotobuf_mutator:workspace.bzl", libprotobuf_mutator = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") load("//third_party/nasm:workspace.bzl", nasm = "repo") -load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") load("//third_party/opencl_headers:workspace.bzl", opencl_headers = "repo") load("//third_party/pasta:workspace.bzl", pasta = "repo") load("//third_party/py:python_configure.bzl", "python_configure") @@ -106,9 +104,7 @@ def _tf_toolchains(): # Note that we check the minimum bazel version in WORKSPACE. clang6_configure(name = "local_config_clang6") cc_download_clang_toolchain(name = "local_config_download_clang") - cuda_configure(name = "local_config_cuda") tensorrt_configure(name = "local_config_tensorrt") - nccl_configure(name = "local_config_nccl") git_configure(name = "local_config_git") syslibs_configure(name = "local_config_syslibs") python_configure(name = "local_config_python") @@ -154,12 +150,20 @@ def _tf_repositories(): # LINT.IfChange tf_http_archive( name = "XNNPACK", - sha256 = "60a504f285fe529e85f3530d8b9c0e7e42e9c78b87b095e71a4e41b0c6412227", - strip_prefix = "XNNPACK-488a695e3a10269755895da05c2711aadf08489b", - urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/488a695e3a10269755895da05c2711aadf08489b.zip"), + sha256 = "0e5d5c16686beff813e3946b26ca412f28acaf611228d20728ffb6479264fe19", + strip_prefix = "XNNPACK-9ddeb74f9f6866174d61888947e4aa9ffe963b1b", + urls = tf_mirror_urls("https://github.com/google/XNNPACK/archive/9ddeb74f9f6866174d61888947e4aa9ffe963b1b.zip"), ) # LINT.ThenChange(//tensorflow/lite/tools/cmake/modules/xnnpack.cmake) + # XNNPack dependency. + tf_http_archive( + name = "KleidiAI", + sha256 = "88233e427be6579560073267575f00f3b5fc370a31a43bbdd87a1810bd4bf1b6", + strip_prefix = "kleidiai-cddf991af5de49fd34949fa39690e4e906e04074", + urls = tf_mirror_urls("https://gitlab.arm.com/kleidi/kleidiai/-/archive/cddf991af5de49fd34949fa39690e4e906e04074/kleidiai-cddf991af5de49fd34949fa39690e4e906e04074.zip"), + ) + tf_http_archive( name = "FXdiv", sha256 = "3d7b0e9c4c658a84376a1086126be02f9b7f753caa95e009d9ac38d11da444db", @@ -781,9 +785,9 @@ def _tf_repositories(): tf_http_archive( name = "pybind11", - urls = tf_mirror_urls("https://github.com/pybind/pybind11/archive/v2.10.4.tar.gz"), - sha256 = "832e2f309c57da9c1e6d4542dedd34b24e4192ecb4d62f6f4866a737454c9970", - strip_prefix = "pybind11-2.10.4", + urls = tf_mirror_urls("https://github.com/pybind/pybind11/archive/v2.13.4.tar.gz"), + sha256 = "efc901aa0aab439a3fea6efeaf930b5a349fb06394bf845c64ce15a9cf8f0240", + strip_prefix = "pybind11-2.13.4", build_file = "//third_party:pybind11.BUILD", system_build_file = "//third_party/systemlibs:pybind11.BUILD", ) diff --git a/third_party/absl/nvidia_jetson.patch b/third_party/absl/nvidia_jetson.patch new file mode 100644 index 00000000000000..5328c3a0d605c7 --- /dev/null +++ b/third_party/absl/nvidia_jetson.patch @@ -0,0 +1,35 @@ +From 372124e6af36a540e74a2ec31d79d7297a831f98 Mon Sep 17 00:00:00 2001 +From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= +Date: Thu, 1 Aug 2024 12:38:52 -0700 +Subject: [PATCH] PR #1732: Fix build on NVIDIA Jetson board. Fix #1665 + +Imported from GitHub PR https://github.com/abseil/abseil-cpp/pull/1732 + +Fix build on NVIDIA Jetson board. Fix #1665 + +This patch is already used by the spark project. +I'm fixing this as this break the build of Tensorflow and JAX on Jetson board. +Merge 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff into 6b8ebb35c0414ef5a2b6fd4a0f59057e41beaff9 + +Merging this change closes #1732 + +COPYBARA_INTEGRATE_REVIEW=https://github.com/abseil/abseil-cpp/pull/1732 from nouiz:fix_neon_on_jetson 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff +PiperOrigin-RevId: 658501520 +Change-Id: If502ede4efc8c877fb3fed227eca6dc7622dd181 +--- + absl/base/config.h | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/absl/base/config.h b/absl/base/config.h +index 97c9a22a109..ab1e9860a91 100644 +--- a/absl/base/config.h ++++ b/absl/base/config.h +@@ -926,7 +926,7 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' || + // https://llvm.org/docs/CompileCudaWithLLVM.html#detecting-clang-vs-nvcc-from-code + #ifdef ABSL_INTERNAL_HAVE_ARM_NEON + #error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set +-#elif defined(__ARM_NEON) && !defined(__CUDA_ARCH__) ++#elif defined(__ARM_NEON) && !(defined(__NVCC__) && defined(__CUDACC__)) + #define ABSL_INTERNAL_HAVE_ARM_NEON 1 + #endif + diff --git a/third_party/absl/workspace.bzl b/third_party/absl/workspace.bzl index 06f75166ce4bb6..9565a82c331946 100644 --- a/third_party/absl/workspace.bzl +++ b/third_party/absl/workspace.bzl @@ -44,4 +44,5 @@ def repo(): system_link_files = SYS_LINKS, strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT), urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)), + patch_file = ["//third_party/absl:nvidia_jetson.patch"], ) diff --git a/third_party/gloo/gloo.BUILD b/third_party/gloo/gloo.BUILD index 99a8e32c69c8f6..2de0c852ebf007 100644 --- a/third_party/gloo/gloo.BUILD +++ b/third_party/gloo/gloo.BUILD @@ -22,7 +22,7 @@ substitions = { "#cmakedefine01 GLOO_USE_REDIS": "#define GLOO_USE_REDIS 0", "#cmakedefine01 GLOO_USE_IBVERBS": "#define GLOO_USE_IBVERBS 0", "#cmakedefine01 GLOO_USE_MPI": "#define GLOO_USE_MPI 0", - "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV 0", + "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV (__APPLE__ ? 1 : 0)", "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP": "#define GLOO_HAVE_TRANSPORT_TCP 1", "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS": "#define GLOO_HAVE_TRANSPORT_TCP_TLS 0", "#cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS": "#define GLOO_HAVE_TRANSPORT_IBVERBS 0", @@ -95,3 +95,14 @@ cc_library( copts = ["-fexceptions"], deps = [":gloo"], ) + +cc_library( + name = "transport_uv", + srcs = glob(["gloo/transport/uv/*.cc"]), + hdrs = glob(["gloo/transport/uv/*.h"]), + copts = ["-fexceptions"], + deps = [ + ":gloo", + "@uv", + ], +) diff --git a/third_party/gpus/check_cuda_libs.py b/third_party/gpus/check_cuda_libs.py index afd6380b0ac203..b1a10a86b9aac6 100644 --- a/third_party/gpus/check_cuda_libs.py +++ b/third_party/gpus/check_cuda_libs.py @@ -14,6 +14,9 @@ # ============================================================================== """Verifies that a list of libraries is installed on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + Takes a list of arguments with every two subsequent arguments being a logical tuple of (path, check_soname). The path to the library and either True or False to indicate whether to check the soname field on the shared library. diff --git a/third_party/gpus/compiler_common_tools.bzl b/third_party/gpus/compiler_common_tools.bzl new file mode 100644 index 00000000000000..bd07f49ec457bb --- /dev/null +++ b/third_party/gpus/compiler_common_tools.bzl @@ -0,0 +1,174 @@ +"""Common compiler functions. """ + +load( + "//third_party/remote_config:common.bzl", + "err_out", + "raw_exec", + "realpath", +) + +def to_list_of_strings(elements): + """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. + + This is to be used to put a list of strings into the bzl file templates + so it gets interpreted as list of strings in Starlark. + + Args: + elements: list of string elements + + Returns: + single string of elements wrapped in quotes separated by a comma.""" + quoted_strings = ["\"" + element + "\"" for element in elements] + return ", ".join(quoted_strings) + +_INC_DIR_MARKER_BEGIN = "#include <...>" + +# OSX add " (framework directory)" at the end of line, strip it. +_OSX_FRAMEWORK_SUFFIX = " (framework directory)" +_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) + +# TODO(dzc): Once these functions have been factored out of Bazel's +# cc_configure.bzl, load them from @bazel_tools instead. +def _cxx_inc_convert(path): + """Convert path returned by cc -E xc++ in a complete path.""" + path = path.strip() + if path.endswith(_OSX_FRAMEWORK_SUFFIX): + path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() + return path + +def _normalize_include_path(repository_ctx, path): + """Normalizes include paths before writing them to the crosstool. + + If path points inside the 'crosstool' folder of the repository, a relative + path is returned. + If path points outside the 'crosstool' folder, an absolute path is returned. + """ + path = str(repository_ctx.path(path)) + crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) + + if path.startswith(crosstool_folder): + # We drop the path to "$REPO/crosstool" and a trailing path separator. + return path[len(crosstool_folder) + 1:] + return path + +def _is_compiler_option_supported(repository_ctx, cc, option): + """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" + result = repository_ctx.execute([ + cc, + option, + "-o", + "/dev/null", + "-c", + str(repository_ctx.path("tools/cpp/empty.cc")), + ]) + return result.stderr.find(option) == -1 + +def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sys_root): + """Compute the list of default C or C++ include directories.""" + if lang_is_cpp: + lang = "c++" + else: + lang = "c" + sysroot = [] + if tf_sys_root: + sysroot += ["--sysroot", tf_sys_root] + result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + + sysroot) + stderr = err_out(result) + index1 = stderr.find(_INC_DIR_MARKER_BEGIN) + if index1 == -1: + return [] + index1 = stderr.find("\n", index1) + if index1 == -1: + return [] + index2 = stderr.rfind("\n ") + if index2 == -1 or index2 < index1: + return [] + index2 = stderr.find("\n", index2 + 1) + if index2 == -1: + inc_dirs = stderr[index1 + 1:] + else: + inc_dirs = stderr[index1 + 1:index2].strip() + + print_resource_dir_supported = _is_compiler_option_supported( + repository_ctx, + cc, + "-print-resource-dir", + ) + + if print_resource_dir_supported: + resource_dir = repository_ctx.execute( + [cc, "-print-resource-dir"], + ).stdout.strip() + "/share" + inc_dirs += "\n" + resource_dir + + compiler_includes = [ + _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) + for p in inc_dirs.split("\n") + ] + + # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc + # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) + # but Bazel might encounter either (usually reported by the compiler) + # especially when a compiler wrapper (e.g. ccache) is used. + # So we need to also include paths where symlinks are not resolved. + + # Try to find real path to CC installation to "see through" compiler wrappers + # GCC has the path to g++ + index1 = result.stderr.find("COLLECT_GCC=") + if index1 != -1: + index1 = result.stderr.find("=", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname + else: + # Clang has the directory + index1 = result.stderr.find("InstalledDir: ") + if index1 != -1: + index1 = result.stderr.find(" ", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname + else: + # Fallback to the CC path + cc_topdir = repository_ctx.path(cc).dirname.dirname + + # We now have the compiler installation prefix, e.g. /symlink/gcc + # And the resolved installation prefix, e.g. /opt/gcc + cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() + cc_topdir = str(cc_topdir).strip() + + # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. + # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] + # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] + if cc_topdir_resolved != cc_topdir: + unresolved_compiler_includes = [ + cc_topdir + inc[len(cc_topdir_resolved):] + for inc in compiler_includes + if inc.startswith(cc_topdir_resolved) + ] + compiler_includes = compiler_includes + unresolved_compiler_includes + return compiler_includes + +def get_cxx_inc_directories(repository_ctx, cc, tf_sys_root): + """Compute the list of default C and C++ include directories.""" + + # For some reason `clang -xc` sometimes returns include paths that are + # different from the ones from `clang -xc++`. (Symlink and a dir) + # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists + includes_cpp = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + True, + tf_sys_root, + ) + includes_c = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + False, + tf_sys_root, + ) + + return includes_cpp + [ + inc + for inc in includes_c + if inc not in includes_cpp + ] diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl index 8eda7a1cf6ac2b..b9553d9b99ecfe 100644 --- a/third_party/gpus/crosstool/BUILD.tpl +++ b/third_party/gpus/crosstool/BUILD.tpl @@ -2,6 +2,7 @@ # Update cuda_configure.bzl#verify_build_defines when adding new variables. load(":cc_toolchain_config.bzl", "cc_toolchain_config") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") licenses(["restricted"]) @@ -133,9 +134,17 @@ filegroup( srcs = [], ) +filegroup( + name = "cuda_nvcc_files", + srcs = %{cuda_nvcc_files}, +) + filegroup( name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + srcs = [ + ":cuda_nvcc_files", + ":clang/bin/crosstool_wrapper_driver_is_not_gcc" + ], ) filegroup( diff --git a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl index c46e09484fdfad..eb3a1d8c8ddf02 100644 --- a/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +++ b/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -181,6 +181,9 @@ def InvokeNvcc(argv, log=False): nvccopts += ['--keep', '--keep-dir', tempdir] # Force C++17 dialect (note, everything in just one string!) nvccopts += ['--std c++17'] + # This is so that nvcc does not complain about MSVC or CLANG. + nvccopts += ['-allow-unsupported-compiler'] + nvccopts += ['--expt-extended-lambda', '--expt-relaxed-constexpr'] if log: Log([NVCC_PATH] + nvccopts) diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl index 0b85e59231a374..094431dcedfc12 100644 --- a/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/gpus/cuda/BUILD.tpl @@ -1,6 +1,10 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Please use `hermetic/cuda_configure` instead. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like @@ -144,7 +148,6 @@ cc_library( name = "cusolver", srcs = ["cuda/lib/%{cusolver_lib}"], data = ["cuda/lib/%{cusolver_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -220,7 +223,6 @@ cc_library( name = "cusparse", srcs = ["cuda/lib/%{cusparse_lib}"], data = ["cuda/lib/%{cusparse_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -242,6 +244,41 @@ py_library( srcs = ["cuda/cuda_config.py"], ) +# Build setting that is always true (i.e. it can not be changed on the +# command line). It is used to create the config settings below that are +# always or never satisfied. +bool_setting( + name = "true_setting", + visibility = ["//visibility:private"], + build_setting_default = True, +) + +# Config settings whether TensorFlow is built with hermetic CUDA. +# These configs are never satisfied. +config_setting( + name = "hermetic_cuda_tools", + flag_values = {":true_setting": "False"}, +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":true_setting": "False"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + %{copy_rules} cc_library( @@ -249,3 +286,9 @@ cc_library( # to make bazel query happy. name = "nvptxcompiler", ) + +cc_library( + # This is not yet fully supported, but we need the rule + # to make bazel query happy. + name = "nvjitlink", +) \ No newline at end of file diff --git a/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/gpus/cuda/BUILD.windows.tpl index dee0e898d9ae7a..6b25c8398a7144 100644 --- a/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/gpus/cuda/BUILD.windows.tpl @@ -1,3 +1,7 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Hermetic CUDA repository rule doesn't support Windows. +# Please use `hermetic/cuda_configure`. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl index bc865cecb3240a..d1c50ea6377b9e 100644 --- a/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/gpus/cuda/build_defs.bzl.tpl @@ -104,9 +104,16 @@ def if_cuda_newer_than(wanted_ver, if_true, if_false = []): wanted_major = int(wanted_ver.split('_')[0]) wanted_minor = int(wanted_ver.split('_')[1]) - configured_version = "%{cuda_version}" - configured_major = int(configured_version.split('.')[0]) - configured_minor = int(configured_version.split('.')[1]) + # Strip "64_" which appears in the CUDA version on Windows. + configured_version = "%{cuda_version}".rsplit("_", 1)[-1] + configured_version_parts = configured_version.split('.') + + # On Windows, the major and minor versions are concatenated without a period and the minor only contains one digit. + if len(configured_version_parts) == 1: + configured_version_parts = [configured_version[0:-1], configured_version[-1:]] + + configured_major = int(configured_version_parts[0]) + configured_minor = int(configured_version_parts[1]) if %{cuda_is_configured} and (wanted_major, wanted_minor) <= (configured_major, configured_minor): return select({"//conditions:default": if_true}) @@ -142,9 +149,13 @@ def cuda_header_library( **kwargs ) -def cuda_library(copts = [], **kwargs): +def cuda_library(copts = [], tags = [],**kwargs): """Wrapper over cc_library which adds default CUDA options.""" - native.cc_library(copts = cuda_default_copts() + copts, **kwargs) + native.cc_library( + copts = cuda_default_copts() + copts, + tags = tags + ["gpu"], + **kwargs + ) def cuda_cc_test(copts = [], **kwargs): """Wrapper over cc_test which adds default CUDA options.""" diff --git a/third_party/gpus/cuda/hermetic/BUILD b/third_party/gpus/cuda/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/gpus/cuda/hermetic/BUILD.tpl new file mode 100644 index 00000000000000..ccf1b9a030d5ad --- /dev/null +++ b/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -0,0 +1,266 @@ +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cc_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + ], + deps = [":cudart_headers", + ":cublas_headers", + ":cccl_headers", + ":nvtx_headers", + ":nvcc_headers", + ":cusolver_headers", + ":cufft_headers", + ":cusparse_headers", + ":curand_headers", + ":cupti_headers", + ":nvml_headers"], +) + +cc_library( + name = "cudart_static", + srcs = ["@cuda_cudart//:static"], + linkopts = [ + "-ldl", + "-lpthread", + %{cudart_static_linkopt} + ], +) + +alias( + name = "cuda_driver", + actual = "@cuda_cudart//:cuda_driver", +) + +alias( + name = "cudart_headers", + actual = "@cuda_cudart//:headers", +) + +alias( + name = "cudart", + actual = "@cuda_cudart//:cudart", +) + +alias( + name = "nvtx_headers", + actual = "@cuda_nvtx//:headers", +) + +alias( + name = "nvml_headers", + actual = "@cuda_nvml//:headers", +) + +alias( + name = "nvcc_headers", + actual = "@cuda_nvcc//:headers", +) + +alias( + name = "cccl_headers", + actual = "@cuda_cccl//:headers", +) + +alias( + name = "cublas_headers", + actual = "@cuda_cublas//:headers", +) + +alias( + name = "cusolver_headers", + actual = "@cuda_cusolver//:headers", +) + +alias( + name = "cufft_headers", + actual = "@cuda_cufft//:headers", +) + +alias( + name = "cusparse_headers", + actual = "@cuda_cusparse//:headers", +) + +alias( + name = "curand_headers", + actual = "@cuda_curand//:headers", +) + +alias( + name = "cublas", + actual = "@cuda_cublas//:cublas", +) + +alias( + name = "cublasLt", + actual = "@cuda_cublas//:cublasLt", +) + +alias( + name = "cusolver", + actual = "@cuda_cusolver//:cusolver", +) + +alias( + name = "cudnn", + actual = "@cuda_cudnn//:cudnn", +) + +alias( + name = "cudnn_header", + actual = "@cuda_cudnn//:headers", +) + +alias( + name = "cufft", + actual = "@cuda_cufft//:cufft", +) + +alias( + name = "curand", + actual = "@cuda_curand//:curand", +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = ":cuda_headers", +) + +alias( + name = "cupti_headers", + actual = "@cuda_cupti//:headers", +) + +alias( + name = "cupti_dsos", + actual = "@cuda_cupti//:cupti", +) + +alias( + name = "cusparse", + actual = "@cuda_cusparse//:cusparse", +) + +alias( + name = "cuda-nvvm", + actual = "@cuda_nvcc//:nvvm", +) + +alias( + name = "nvjitlink", + actual = "@cuda_nvjitlink//:nvjitlink" +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +# Config setting whether TensorFlow is built with hermetic CUDA. +alias( + name = "hermetic_cuda_tools", + actual = "@local_config_cuda//:is_cuda_enabled", +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":include_hermetic_cuda_libs": "True"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + +cc_library( + # This is not yet fully supported, but we need the rule + # to make bazel query happy. + name = "nvptxcompiler", +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl new file mode 100644 index 00000000000000..85c0cbbb196fef --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl @@ -0,0 +1,15 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + hdrs = glob([ + %{comment}"include/cub/**", + %{comment}"include/cuda/**", + %{comment}"include/nv/**", + %{comment}"include/thrust/**", + ]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/gpus/cuda/hermetic/cuda_configure.bzl new file mode 100644 index 00000000000000..270b73c3884855 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -0,0 +1,521 @@ +"""Repository rule for hermetic CUDA autoconfiguration. + +`cuda_configure` depends on the following environment variables: + + * `TF_NEED_CUDA`: Whether to enable building with CUDA. + * `TF_NVCC_CLANG`: Whether to use clang for C++ and NVCC for Cuda compilation. + * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for + both host and device code compilation. + * `TF_SYSROOT`: The sysroot to use when compiling. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + * `HERMETIC_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default + is `3.5,5.2`. If not specified, the value will be determined by the + `TF_CUDA_COMPUTE_CAPABILITIES`. + * `PYTHON_BIN_PATH`: The python binary path +""" + +load( + "//third_party/gpus:compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", + "which", +) + +def _find_cc(repository_ctx): + """Find the C++ compiler.""" + cc_path_envvar = _CLANG_CUDA_COMPILER_PATH + cc_name = "clang" + + cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) + if cc_name_from_env: + cc_name = cc_name_from_env + if cc_name.startswith("/"): + # Return the absolute path. + return cc_name + cc = which(repository_ctx, cc_name) + if cc == None: + fail(("Cannot find {}, either correct your path or set the {}" + + " environment variable").format(cc_name, cc_path_envvar)) + return cc + +def _auto_configure_fail(msg): + """Output failure message when cuda configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _verify_build_defines(params): + """Verify all variables that crosstool/BUILD.tpl expects are substituted. + + Args: + params: dict of variables that will be passed to the BUILD.tpl template. + """ + missing = [] + for param in [ + "cxx_builtin_include_directories", + "extra_no_canonical_prefixes_flags", + "host_compiler_path", + "host_compiler_prefix", + "host_compiler_warnings", + "linker_bin_path", + "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", + "unfiltered_compile_flags", + "win_compiler_deps", + ]: + if ("%{" + param + "}") not in params: + missing.append(param) + + if missing: + _auto_configure_fail( + "BUILD.tpl template is missing these variables: " + + str(missing) + + ".\nWe only got: " + + str(params) + + ".", + ) + +def get_cuda_version(repository_ctx): + return (get_host_environ(repository_ctx, HERMETIC_CUDA_VERSION) or + get_host_environ(repository_ctx, TF_CUDA_VERSION)) + +def enable_cuda(repository_ctx): + """Returns whether to build with CUDA support.""" + return int(get_host_environ(repository_ctx, TF_NEED_CUDA, False)) + +def _flag_enabled(repository_ctx, flag_name): + return get_host_environ(repository_ctx, flag_name) == "1" + +def _use_nvcc_and_clang(repository_ctx): + # Returns the flag if we need to use clang for C++ and NVCC for Cuda. + return _flag_enabled(repository_ctx, _TF_NVCC_CLANG) + +def _tf_sysroot(repository_ctx): + return get_host_environ(repository_ctx, _TF_SYSROOT, "") + +def _py_tmpl_dict(d): + return {"%{cuda_config}": str(d)} + +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "\"\"," if cpu_value == "Darwin" else "\"-lrt\"," + +def _compute_capabilities(repository_ctx): + """Returns a list of strings representing cuda compute capabilities. + + Args: + repository_ctx: the repo rule's context. + + Returns: + list of cuda architectures to compile for. 'compute_xy' refers to + both PTX and SASS, 'sm_xy' refers to SASS only. + """ + capabilities = (get_host_environ( + repository_ctx, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + ) or + get_host_environ( + repository_ctx, + _TF_CUDA_COMPUTE_CAPABILITIES, + )) + capabilities = (capabilities or "compute_35,compute_52").split(",") + + # Map old 'x.y' capabilities to 'compute_xy'. + if len(capabilities) > 0 and all([len(x.split(".")) == 2 for x in capabilities]): + # If all capabilities are in 'x.y' format, only include PTX for the + # highest capability. + cc_list = sorted([x.replace(".", "") for x in capabilities]) + capabilities = [ + "sm_%s" % x + for x in cc_list[:-1] + ] + ["compute_%s" % cc_list[-1]] + for i, capability in enumerate(capabilities): + parts = capability.split(".") + if len(parts) != 2: + continue + capabilities[i] = "compute_%s%s" % (parts[0], parts[1]) + + # Make list unique + capabilities = dict(zip(capabilities, capabilities)).keys() + + # Validate capabilities. + for capability in capabilities: + if not capability.startswith(("compute_", "sm_")): + _auto_configure_fail("Invalid compute capability: %s" % capability) + for prefix in ["compute_", "sm_"]: + if not capability.startswith(prefix): + continue + if len(capability) == len(prefix) + 2 and capability[-2:].isdigit(): + continue + if len(capability) == len(prefix) + 3 and capability.endswith("90a"): + continue + _auto_configure_fail("Invalid compute capability: %s" % capability) + + return capabilities + +def _compute_cuda_extra_copts(compute_capabilities): + copts = ["--no-cuda-include-ptx=all"] + for capability in compute_capabilities: + if capability.startswith("compute_"): + capability = capability.replace("compute_", "sm_") + copts.append("--cuda-include-ptx=%s" % capability) + copts.append("--cuda-gpu-arch=%s" % capability) + + return str(copts) + +def _get_cuda_config(repository_ctx): + """Detects and returns information about the CUDA installation on the system. + + Args: + repository_ctx: The repository context. + + Returns: + A struct containing the following fields: + cuda_version: The version of CUDA on the system. + cudart_version: The CUDA runtime version on the system. + cudnn_version: The version of cuDNN on the system. + compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. + """ + + return struct( + cuda_version = get_cuda_version(repository_ctx), + cupti_version = repository_ctx.read(repository_ctx.attr.cupti_version), + cudart_version = repository_ctx.read(repository_ctx.attr.cudart_version), + cublas_version = repository_ctx.read(repository_ctx.attr.cublas_version), + cusolver_version = repository_ctx.read(repository_ctx.attr.cusolver_version), + curand_version = repository_ctx.read(repository_ctx.attr.curand_version), + cufft_version = repository_ctx.read(repository_ctx.attr.cufft_version), + cusparse_version = repository_ctx.read(repository_ctx.attr.cusparse_version), + cudnn_version = repository_ctx.read(repository_ctx.attr.cudnn_version), + compute_capabilities = _compute_capabilities(repository_ctx), + cpu_value = get_cpu_value(repository_ctx), + ) + +_DUMMY_CROSSTOOL_BZL_FILE = """ +def error_gpu_disabled(): + fail("ERROR: Building with --config=cuda but TensorFlow is not configured " + + "to build with GPU support. Please re-run ./configure and enter 'Y' " + + "at the prompt to build with GPU support.") + + native.genrule( + name = "error_gen_crosstool", + outs = ["CROSSTOOL"], + cmd = "echo 'Should not be run.' && exit 1", + ) + + native.filegroup( + name = "crosstool", + srcs = [":CROSSTOOL"], + output_licenses = ["unencumbered"], + ) +""" + +_DUMMY_CROSSTOOL_BUILD_FILE = """ +load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled") + +error_gpu_disabled() +""" + +def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + + # Set up BUILD file for cuda/. + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "False", + "%{cuda_extra_copts}": "[]", + "%{cuda_gpu_architectures}": "[]", + "%{cuda_version}": "0.0", + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + }, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": "", + "%{cudart_version}": "", + "%{cupti_version}": "", + "%{cublas_version}": "", + "%{cusolver_version}": "", + "%{curand_version}": "", + "%{cufft_version}": "", + "%{cusparse_version}": "", + "%{cudnn_version}": "", + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({}), + ) + + # If cuda_configure is not configured to build with GPU support, and the user + # attempts to build with --config=cuda, add a dummy build rule to intercept + # this and fail with an actionable error message. + repository_ctx.file( + "crosstool/error_gpu_disabled.bzl", + _DUMMY_CROSSTOOL_BZL_FILE, + ) + repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) + +def _create_local_cuda_repository(repository_ctx): + """Creates the repository containing files set up to build with CUDA.""" + cuda_config = _get_cuda_config(repository_ctx) + + # Set up BUILD file for cuda/ + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + cuda_config.compute_capabilities, + ), + "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), + "%{cuda_version}": cuda_config.cuda_version, + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt( + cuda_config.cpu_value, + ), + }, + ) + + is_nvcc_and_clang = _use_nvcc_and_clang(repository_ctx) + tf_sysroot = _tf_sysroot(repository_ctx) + + # Set up crosstool/ + cc = _find_cc(repository_ctx) + host_compiler_includes = get_cxx_inc_directories( + repository_ctx, + cc, + tf_sysroot, + ) + + cuda_defines = {} + + # We do not support hermetic CUDA on Windows. + # This ensures the CROSSTOOL file parser is happy. + cuda_defines.update({ + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + "%{win_compiler_deps}": ":empty", + }) + + cuda_defines["%{builtin_sysroot}"] = tf_sysroot + cuda_defines["%{cuda_toolkit_path}"] = repository_ctx.attr.nvcc_binary.workspace_root + cuda_defines["%{compiler}"] = "clang" + cuda_defines["%{host_compiler_prefix}"] = "/usr/bin" + cuda_defines["%{linker_bin_path}"] = "" + cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" + cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( + host_compiler_includes, + ) + cuda_defines["%{cuda_nvcc_files}"] = "if_cuda([\"@{nvcc_archive}//:bin\", \"@{nvcc_archive}//:nvvm\"])".format( + nvcc_archive = repository_ctx.attr.nvcc_binary.repo_name, + ) + + if not is_nvcc_and_clang: + cuda_defines["%{host_compiler_path}"] = str(cc) + cuda_defines["%{host_compiler_warnings}"] = """ + # Some parts of the codebase set -Werror and hit this warning, so + # switch it off for now. + "-Wno-invalid-partial-specialization" + """ + cuda_defines["%{compiler_deps}"] = ":cuda_nvcc_files" + repository_ctx.file( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + "", + ) + else: + cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{host_compiler_warnings}"] = "" + + nvcc_relative_path = "%s/%s" % ( + repository_ctx.attr.nvcc_binary.workspace_root, + repository_ctx.attr.nvcc_binary.name, + ) + cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + + wrapper_defines = { + "%{cpu_compiler}": str(cc), + "%{cuda_version}": cuda_config.cuda_version, + "%{nvcc_path}": nvcc_relative_path, + "%{host_compiler_path}": str(cc), + "%{use_clang_compiler}": "True", + } + repository_ctx.template( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + repository_ctx.attr.crosstool_wrapper_driver_is_not_gcc_tpl, + wrapper_defines, + ) + + _verify_build_defines(cuda_defines) + + # Only expand template variables in the BUILD file + repository_ctx.template( + "crosstool/BUILD", + repository_ctx.attr.crosstool_build_tpl, + cuda_defines, + ) + + # No templating of cc_toolchain_config - use attributes and templatize the + # BUILD file. + repository_ctx.template( + "crosstool/cc_toolchain_config.bzl", + repository_ctx.attr.cc_toolchain_config_tpl, + {}, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": cuda_config.cuda_version, + "%{cudart_version}": cuda_config.cudart_version, + "%{cupti_version}": cuda_config.cupti_version, + "%{cublas_version}": cuda_config.cublas_version, + "%{cusolver_version}": cuda_config.cusolver_version, + "%{curand_version}": cuda_config.curand_version, + "%{cufft_version}": cuda_config.cufft_version, + "%{cusparse_version}": cuda_config.cusparse_version, + "%{cudnn_version}": cuda_config.cudnn_version, + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": ", ".join([ + cc.split("_")[1] + for cc in cuda_config.compute_capabilities + ]), + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({ + "cuda_version": cuda_config.cuda_version, + "cudnn_version": cuda_config.cudnn_version, + "cuda_compute_capabilities": cuda_config.compute_capabilities, + "cpu_compiler": str(cc), + }), + ) + +def _cuda_autoconf_impl(repository_ctx): + """Implementation of the cuda_autoconf repository rule.""" + build_file = repository_ctx.attr.local_config_cuda_build_file + + if not enable_cuda(repository_ctx): + _create_dummy_repository(repository_ctx) + else: + _create_local_cuda_repository(repository_ctx) + + repository_ctx.symlink(build_file, "BUILD") + +_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH" +_PYTHON_BIN_PATH = "PYTHON_BIN_PATH" +_HERMETIC_CUDA_COMPUTE_CAPABILITIES = "HERMETIC_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +HERMETIC_CUDA_VERSION = "HERMETIC_CUDA_VERSION" +TF_CUDA_VERSION = "TF_CUDA_VERSION" +TF_NEED_CUDA = "TF_NEED_CUDA" +_TF_NVCC_CLANG = "TF_NVCC_CLANG" +_TF_SYSROOT = "TF_SYSROOT" + +_ENVIRONS = [ + _CLANG_CUDA_COMPILER_PATH, + TF_NEED_CUDA, + _TF_NVCC_CLANG, + TF_CUDA_VERSION, + HERMETIC_CUDA_VERSION, + _TF_CUDA_COMPUTE_CAPABILITIES, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + _TF_SYSROOT, + _PYTHON_BIN_PATH, + "TMP", + "TMPDIR", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", +] + +cuda_configure = repository_rule( + implementation = _cuda_autoconf_impl, + environ = _ENVIRONS, + attrs = { + "environ": attr.string_dict(), + "cublas_version": attr.label(default = Label("@cuda_cublas//:version.txt")), + "cudart_version": attr.label(default = Label("@cuda_cudart//:version.txt")), + "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.txt")), + "cufft_version": attr.label(default = Label("@cuda_cufft//:version.txt")), + "cupti_version": attr.label(default = Label("@cuda_cupti//:version.txt")), + "curand_version": attr.label(default = Label("@cuda_curand//:version.txt")), + "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.txt")), + "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.txt")), + "nvcc_binary": attr.label(default = Label("@cuda_nvcc//:bin/nvcc")), + "local_config_cuda_build_file": attr.label(default = Label("//third_party/gpus:local_config_cuda.BUILD")), + "build_defs_tpl": attr.label(default = Label("//third_party/gpus/cuda:build_defs.bzl.tpl")), + "cuda_build_tpl": attr.label(default = Label("//third_party/gpus/cuda/hermetic:BUILD.tpl")), + "cuda_config_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.h.tpl")), + "cuda_config_py_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.py.tpl")), + "crosstool_wrapper_driver_is_not_gcc_tpl": attr.label(default = Label("//third_party/gpus/crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl")), + "crosstool_build_tpl": attr.label(default = Label("//third_party/gpus/crosstool:BUILD.tpl")), + "cc_toolchain_config_tpl": attr.label(default = Label("//third_party/gpus/crosstool:cc_toolchain_config.bzl.tpl")), + }, +) +"""Detects and configures the hermetic CUDA toolchain. + +Add the following to your WORKSPACE file: + +```python +cuda_configure(name = "local_config_cuda") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl new file mode 100644 index 00000000000000..510235d801de4e --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -0,0 +1,44 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cublas_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublas.so.%{libcublas_version}", + deps = [":cublasLt"], +) + +cc_import( + name = "cublasLt_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublasLt.so.%{libcublaslt_version}", +) +%{multiline_comment} +cc_library( + name = "cublas", + visibility = ["//visibility:public"], + %{comment}deps = [":cublas_shared_library"], +) + +cc_library( + name = "cublasLt", + visibility = ["//visibility:public"], + %{comment}deps = [":cublasLt_shared_library"], +) + +cc_library( + name = "headers", + %{comment}hdrs = [ + %{comment}"include/cublas.h", + %{comment}"include/cublasLt.h", + %{comment}"include/cublas_api.h", + %{comment}"include/cublas_v2.h", + %{comment}], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl new file mode 100644 index 00000000000000..f7ba469b42b76a --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -0,0 +1,126 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) + +filegroup( + name = "static", + srcs = ["lib/libcudart_static.a"], + visibility = ["@local_config_cuda//cuda:__pkg__"], +) +%{multiline_comment} +# TODO: Replace system provided library with hermetic NVIDIA driver library. +cc_import( + name = "cuda_driver_shared_library", + interface_library = "lib/stubs/libcuda.so", + system_provided = 1, +) + +cc_import( + name = "cudart_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcudart.so.%{libcudart_version}", +) +%{multiline_comment} +cc_library( + name = "cuda_driver", + %{comment}deps = [":cuda_driver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart", + %{comment}deps = [ + %{comment}":cuda_driver", + %{comment}":cudart_shared_library", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/builtin_types.h", + %{comment}"include/channel_descriptor.h", + %{comment}"include/common_functions.h", + %{comment}"include/cooperative_groups/**", + %{comment}"include/cooperative_groups.h", + %{comment}"include/cuComplex.h", + %{comment}"include/cuda.h", + %{comment}"include/cudaEGL.h", + %{comment}"include/cudaEGLTypedefs.h", + %{comment}"include/cudaGL.h", + %{comment}"include/cudaGLTypedefs.h", + %{comment}"include/cudaProfilerTypedefs.h", + %{comment}"include/cudaTypedefs.h", + %{comment}"include/cudaVDPAU.h", + %{comment}"include/cudaVDPAUTypedefs.h", + %{comment}"include/cuda_awbarrier.h", + %{comment}"include/cuda_awbarrier_helpers.h", + %{comment}"include/cuda_awbarrier_primitives.h", + %{comment}"include/cuda_bf16.h", + %{comment}"include/cuda_bf16.hpp", + %{comment}"include/cuda_device_runtime_api.h", + %{comment}"include/cuda_egl_interop.h", + %{comment}"include/cuda_fp16.h", + %{comment}"include/cuda_fp16.hpp", + %{comment}"include/cuda_fp8.h", + %{comment}"include/cuda_fp8.hpp", + %{comment}"include/cuda_gl_interop.h", + %{comment}"include/cuda_occupancy.h", + %{comment}"include/cuda_pipeline.h", + %{comment}"include/cuda_pipeline_helpers.h", + %{comment}"include/cuda_pipeline_primitives.h", + %{comment}"include/cuda_runtime.h", + %{comment}"include/cuda_runtime_api.h", + %{comment}"include/cuda_surface_types.h", + %{comment}"include/cuda_texture_types.h", + %{comment}"include/cuda_vdpau_interop.h", + %{comment}"include/cudart_platform.h", + %{comment}"include/device_atomic_functions.h", + %{comment}"include/device_atomic_functions.hpp", + %{comment}"include/device_double_functions.h", + %{comment}"include/device_functions.h", + %{comment}"include/device_launch_parameters.h", + %{comment}"include/device_types.h", + %{comment}"include/driver_functions.h", + %{comment}"include/driver_types.h", + %{comment}"include/host_config.h", + %{comment}"include/host_defines.h", + %{comment}"include/library_types.h", + %{comment}"include/math_constants.h", + %{comment}"include/math_functions.h", + %{comment}"include/mma.h", + %{comment}"include/nvfunctional", + %{comment}"include/sm_20_atomic_functions.h", + %{comment}"include/sm_20_atomic_functions.hpp", + %{comment}"include/sm_20_intrinsics.h", + %{comment}"include/sm_20_intrinsics.hpp", + %{comment}"include/sm_30_intrinsics.h", + %{comment}"include/sm_30_intrinsics.hpp", + %{comment}"include/sm_32_atomic_functions.h", + %{comment}"include/sm_32_atomic_functions.hpp", + %{comment}"include/sm_32_intrinsics.h", + %{comment}"include/sm_32_intrinsics.hpp", + %{comment}"include/sm_35_atomic_functions.h", + %{comment}"include/sm_35_intrinsics.h", + %{comment}"include/sm_60_atomic_functions.h", + %{comment}"include/sm_60_atomic_functions.hpp", + %{comment}"include/sm_61_intrinsics.h", + %{comment}"include/sm_61_intrinsics.hpp", + %{comment}"include/surface_functions.h", + %{comment}"include/surface_indirect_functions.h", + %{comment}"include/surface_types.h", + %{comment}"include/texture_fetch_functions.h", + %{comment}"include/texture_indirect_functions.h", + %{comment}"include/texture_types.h", + %{comment}"include/vector_functions.h", + %{comment}"include/vector_functions.hpp", + %{comment}"include/vector_types.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl new file mode 100644 index 00000000000000..165c5b1579e73f --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -0,0 +1,73 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_infer.so.%{libcudnn_ops_infer_version}", +) + +cc_import( + name = "cudnn_cnn_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_infer.so.%{libcudnn_cnn_infer_version}", +) + +cc_import( + name = "cudnn_ops_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_train.so.%{libcudnn_ops_train_version}", +) + +cc_import( + name = "cudnn_cnn_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_train.so.%{libcudnn_cnn_train_version}", +) + +cc_import( + name = "cudnn_adv_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_infer.so.%{libcudnn_adv_infer_version}", +) + +cc_import( + name = "cudnn_adv_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_train.so.%{libcudnn_adv_train_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_ops_infer", + %{comment}":cudnn_ops_train", + %{comment}":cudnn_cnn_infer", + %{comment}":cudnn_cnn_train", + %{comment}":cudnn_adv_infer", + %{comment}":cudnn_adv_train", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl new file mode 100644 index 00000000000000..7f36054a51bb5b --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -0,0 +1,80 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops.so.%{libcudnn_ops_version}", +) + +cc_import( + name = "cudnn_cnn", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn.so.%{libcudnn_cnn_version}", +) + +cc_import( + name = "cudnn_adv", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv.so.%{libcudnn_adv_version}", +) + +cc_import( + name = "cudnn_graph", + hdrs = [":headers"], + shared_library = "lib/libcudnn_graph.so.%{libcudnn_graph_version}", +) + +cc_import( + name = "cudnn_engines_precompiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_precompiled.so.%{libcudnn_engines_precompiled_version}", +) + +cc_import( + name = "cudnn_engines_runtime_compiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_runtime_compiled.so.%{libcudnn_engines_runtime_compiled_version}", +) + +cc_import( + name = "cudnn_heuristic", + hdrs = [":headers"], + shared_library = "lib/libcudnn_heuristic.so.%{libcudnn_heuristic_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_engines_precompiled", + %{comment}":cudnn_ops", + %{comment}":cudnn_graph", + %{comment}":cudnn_cnn", + %{comment}":cudnn_adv", + %{comment}":cudnn_engines_runtime_compiled", + %{comment}":cudnn_heuristic", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl new file mode 100644 index 00000000000000..48ccb0ea3cd197 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -0,0 +1,29 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cufft_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcufft.so.%{libcufft_version}", +) +%{multiline_comment} +cc_library( + name = "cufft", + %{comment}deps = [":cufft_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudalibxt.h", + %{comment}"include/cufft*.h" + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl new file mode 100644 index 00000000000000..3efe76f470953f --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -0,0 +1,59 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cupti_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcupti.so.%{libcupti_version}", +) +%{multiline_comment} +cc_library( + name = "cupti", + %{comment}deps = [":cupti_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/Openacc/**", + %{comment}"include/Openmp/**", + %{comment}"include/cuda_stdint.h", + %{comment}"include/cupti.h", + %{comment}"include/cupti_activity.h", + %{comment}"include/cupti_activity_deprecated.h", + %{comment}"include/cupti_callbacks.h", + %{comment}"include/cupti_checkpoint.h", + %{comment}"include/cupti_driver_cbid.h", + %{comment}"include/cupti_events.h", + %{comment}"include/cupti_metrics.h", + %{comment}"include/cupti_nvtx_cbid.h", + %{comment}"include/cupti_pcsampling.h", + %{comment}"include/cupti_pcsampling_util.h", + %{comment}"include/cupti_profiler_target.h", + %{comment}"include/cupti_result.h", + %{comment}"include/cupti_runtime_cbid.h", + %{comment}"include/cupti_sass_metrics.h", + %{comment}"include/cupti_target.h", + %{comment}"include/cupti_version.h", + %{comment}"include/generated_cudaGL_meta.h", + %{comment}"include/generated_cudaVDPAU_meta.h", + %{comment}"include/generated_cuda_gl_interop_meta.h", + %{comment}"include/generated_cuda_meta.h", + %{comment}"include/generated_cuda_runtime_api_meta.h", + %{comment}"include/generated_cuda_vdpau_interop_meta.h", + %{comment}"include/generated_cudart_removed_meta.h", + %{comment}"include/generated_nvtx_meta.h", + %{comment}"include/nvperf_common.h", + %{comment}"include/nvperf_cuda_host.h", + %{comment}"include/nvperf_host.h", + %{comment}"include/nvperf_target.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/extras/CUPTI/include", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl new file mode 100644 index 00000000000000..50e5a8f18a96fd --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -0,0 +1,26 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "curand_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcurand.so.%{libcurand_version}", +) +%{multiline_comment} +cc_library( + name = "curand", + %{comment}deps = [":curand_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob(["include/curand*.h"]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl new file mode 100644 index 00000000000000..943a08ebeb96e1 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -0,0 +1,34 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusolver_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusolver.so.%{libcusolver_version}", + deps = [ + "@cuda_nvjitlink//:nvjitlink", + "@cuda_cusparse//:cusparse", + "@cuda_cublas//:cublas", + "@cuda_cublas//:cublasLt", + ], +) +%{multiline_comment} +cc_library( + name = "cusolver", + %{comment}deps = [":cusolver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cusolver*.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl new file mode 100644 index 00000000000000..46b24366ce1c04 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -0,0 +1,27 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusparse_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusparse.so.%{libcusparse_version}", + deps = ["@cuda_nvjitlink//:nvjitlink"], +) +%{multiline_comment} +cc_library( + name = "cusparse", + %{comment}deps = [":cusparse_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = ["include/cusparse.h"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl b/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl new file mode 100644 index 00000000000000..fdda3aaf92cea5 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl @@ -0,0 +1,125 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistributions JSON repository initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_JSON_DICT", + "CUDNN_REDIST_JSON_DICT", +) + +def _get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_json_file_content(repository_ctx, url_to_sha256, json_file_name): + if len(url_to_sha256) > 1: + (url, sha256) = url_to_sha256 + else: + url = url_to_sha256[0] + sha256 = "" + repository_ctx.download( + url = tf_mirror_urls(url), + sha256 = sha256, + output = json_file_name, + ) + return repository_ctx.read(repository_ctx.path(json_file_name)) + +def _cuda_redist_json_impl(repository_ctx): + cuda_version = (_get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + _get_env_var(repository_ctx, "TF_CUDA_VERSION")) + local_cuda_path = _get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + cudnn_version = (_get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + _get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + local_cudnn_path = _get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + supported_cuda_versions = repository_ctx.attr.cuda_json_dict.keys() + if (cuda_version and not local_cuda_path and + (cuda_version not in supported_cuda_versions)): + fail( + ("The supported CUDA versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add JSON URL for" + + " CUDA version={version}.") + .format( + supported_versions = supported_cuda_versions, + version = cuda_version, + ), + ) + supported_cudnn_versions = repository_ctx.attr.cudnn_json_dict.keys() + if cudnn_version and not local_cudnn_path and (cudnn_version not in supported_cudnn_versions): + fail( + ("The supported CUDNN versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDNN_VERSION" + + " environment variable or add JSON URL for" + + " CUDNN version={version}.") + .format( + supported_versions = supported_cudnn_versions, + version = cudnn_version, + ), + ) + cuda_redistributions = "{}" + cudnn_redistributions = "{}" + if cuda_version and not local_cuda_path: + cuda_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cuda_json_dict[cuda_version], + "redistrib_cuda_%s.json" % cuda_version, + ) + if cudnn_version and not local_cudnn_path: + cudnn_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cudnn_json_dict[cudnn_version], + "redistrib_cudnn_%s.json" % cudnn_version, + ) + + repository_ctx.file( + "distributions.bzl", + """CUDA_REDISTRIBUTIONS = {cuda_redistributions} + +CUDNN_REDISTRIBUTIONS = {cudnn_redistributions} +""".format( + cuda_redistributions = cuda_redistributions, + cudnn_redistributions = cudnn_redistributions, + ), + ) + repository_ctx.file( + "BUILD", + "", + ) + +cuda_redist_json = repository_rule( + implementation = _cuda_redist_json_impl, + attrs = { + "cuda_json_dict": attr.string_list_dict(mandatory = True), + "cudnn_json_dict": attr.string_list_dict(mandatory = True), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "HERMETIC_CUDNN_VERSION", + "TF_CUDA_VERSION", + "TF_CUDNN_VERSION", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", + ], +) + +def cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT): + cuda_redist_json( + name = "cuda_redist_json", + cuda_json_dict = cuda_json_dict, + cudnn_json_dict = cudnn_json_dict, + ) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl new file mode 100644 index 00000000000000..7757a92a90b795 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl @@ -0,0 +1,75 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "bin/nvcc", +]) + +filegroup( + name = "nvvm", + srcs = [ + "nvvm/libdevice/libdevice.10.bc", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "nvlink", + srcs = [ + "bin/nvlink", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "fatbinary", + srcs = [ + "bin/fatbinary", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin2c", + srcs = [ + "bin/bin2c", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "ptxas", + srcs = [ + "bin/ptxas", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin", + srcs = glob([ + "bin/**", + "nvvm/bin/**", + ]), + visibility = ["//visibility:public"], +) + +filegroup( + name = "link_stub", + srcs = [ + "bin/crt/link.stub", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/crt/**", + %{comment}"include/fatbinary_section.h", + %{comment}"include/nvPTXCompiler.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl new file mode 100644 index 00000000000000..9784a84471f1a7 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -0,0 +1,17 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nvjitlink_shared_library", + shared_library = "lib/libnvJitLink.so.%{libnvjitlink_version}", +) +%{multiline_comment} +cc_library( + name = "nvjitlink", + %{comment}deps = [":nvjitlink_shared_library"], + visibility = ["//visibility:public"], +) + diff --git a/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl new file mode 100644 index 00000000000000..23ee30f09f8ff3 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl @@ -0,0 +1,10 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = ["include/nvml.h"], + include_prefix = "third_party/gpus/cuda/nvml/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl new file mode 100644 index 00000000000000..986ef0c8f76166 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl @@ -0,0 +1,9 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +filegroup( + name = "nvprune", + srcs = [ + "bin/nvprune", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl new file mode 100644 index 00000000000000..de18489b455b79 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl @@ -0,0 +1,20 @@ +licenses(["restricted"]) # NVIDIA proprietary license +%{multiline_comment} +cc_import( + name = "nvrtc_main", + shared_library = "lib/libnvrtc.so.%{libnvrtc_version}", +) + +cc_import( + name = "nvrtc_builtins", + shared_library = "lib/libnvrtc-builtins.so.%{libnvrtc-builtins_version}", +) +%{multiline_comment} +cc_library( + name = "nvrtc", + %{comment}deps = [ + %{comment}":nvrtc_main", + %{comment}":nvrtc_builtins", + %{comment}], + visibility = ["//visibility:public"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl b/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl new file mode 100644 index 00000000000000..3457f41a502dee --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl @@ -0,0 +1,13 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nvToolsExt*.h", + %{comment}"include/nvtx3/**", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl new file mode 100644 index 00000000000000..d2015e737540c3 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -0,0 +1,491 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_PATH_PREFIX", + "CUDNN_REDIST_PATH_PREFIX", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +OS_ARCH_DICT = { + "amd64": "x86_64-unknown-linux-gnu", + "aarch64": "aarch64-unknown-linux-gnu", +} +_REDIST_ARCH_DICT = { + "linux-x86_64": "x86_64-unknown-linux-gnu", + "linux-sbsa": "aarch64-unknown-linux-gnu", +} + +SUPPORTED_ARCHIVE_EXTENSIONS = [ + ".zip", + ".jar", + ".war", + ".aar", + ".tar", + ".tar.gz", + ".tgz", + ".tar.xz", + ".txz", + ".tar.zst", + ".tzst", + ".tar.bz2", + ".tbz", + ".ar", + ".deb", + ".whl", +] + +def get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def get_archive_name(url): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the archive name without extension.""" + filename = _get_file_name(url) + for extension in SUPPORTED_ARCHIVE_EXTENSIONS: + if filename.endswith(extension): + return filename[:-len(extension)] + return filename + +LIB_EXTENSION = ".so." + +def _get_lib_name_and_version(path): + extension_index = path.rfind(LIB_EXTENSION) + last_slash_index = path.rfind("/") + lib_name = path[last_slash_index + 1:extension_index] + lib_version = path[extension_index + len(LIB_EXTENSION):] + return (lib_name, lib_version) + +def _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_dir_path = repository_ctx.path("lib") + if not lib_dir_path.exists: + return [] + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]).lower() + lib_dir_content = lib_dir_path.readdir() + return [ + str(f) + for f in lib_dir_content + if (LIB_EXTENSION in str(f) and + main_lib_name in str(f).lower()) + ] + +def get_lib_name_to_version_dict(repository_ctx): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns a dict of library names and major versions.""" + lib_name_to_version_dict = {} + for path in _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_name, lib_version = _get_lib_name_and_version(path) + key = "%%{%s_version}" % lib_name.lower() + + # We need to find either major or major.minor version if there is no + # file with major version. E.g. if we have the following files: + # libcudart.so + # libcudart.so.12 + # libcudart.so.12.3.2, + # we will save save {"%{libcudart_version}": "12"}. + if len(lib_version.split(".")) == 1: + lib_name_to_version_dict[key] = lib_version + if (len(lib_version.split(".")) == 2 and + key not in lib_name_to_version_dict): + lib_name_to_version_dict[key] = lib_version + return lib_name_to_version_dict + +def create_dummy_build_file(repository_ctx, use_comment_symbols = True): + repository_ctx.template( + "BUILD", + repository_ctx.attr.build_templates[0], + { + "%{multiline_comment}": "'''" if use_comment_symbols else "", + "%{comment}": "#" if use_comment_symbols else "", + }, + ) + +def _get_build_template(repository_ctx, major_lib_version): + template = None + for i in range(0, len(repository_ctx.attr.versions)): + for dist_version in repository_ctx.attr.versions[i].split(","): + if dist_version == major_lib_version: + template = repository_ctx.attr.build_templates[i] + break + if not template: + fail("No build template found for {} version {}".format( + repository_ctx.name, + major_lib_version, + )) + return template + +def get_major_library_version(repository_ctx, lib_name_to_version_dict): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the major library version provided the versions dict.""" + major_version = "" + if len(lib_name_to_version_dict) == 0: + return major_version + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]) + key = "%%{%s_version}" % main_lib_name + major_version = lib_name_to_version_dict[key] + return major_version + +def create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_lib_version): + # buildifier: disable=function-docstring-args + """Creates a BUILD file for the repository.""" + if len(major_lib_version) == 0: + build_template_content = repository_ctx.read( + repository_ctx.attr.build_templates[0], + ) + if "_version}" not in build_template_content: + create_dummy_build_file(repository_ctx, use_comment_symbols = False) + else: + create_dummy_build_file(repository_ctx) + return + build_template = _get_build_template( + repository_ctx, + major_lib_version.split(".")[0], + ) + repository_ctx.template( + "BUILD", + build_template, + lib_name_to_version_dict | { + "%{multiline_comment}": "", + "%{comment}": "", + }, + ) + +def _create_symlinks(repository_ctx, local_path, dirs): + for dir in dirs: + repository_ctx.symlink( + "{path}/{dir}".format( + path = local_path, + dir = dir, + ), + dir, + ) + +def use_local_path(repository_ctx, local_path, dirs): + # buildifier: disable=function-docstring-args + """Creates repository using local redistribution paths.""" + _create_symlinks( + repository_ctx, + local_path, + dirs, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _use_local_cuda_path(repository_ctx, local_cuda_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDA repository.""" + use_local_path( + repository_ctx, + local_cuda_path, + ["include", "lib", "bin", "nvvm"], + ) + +def _use_local_cudnn_path(repository_ctx, local_cudnn_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDNN repository.""" + use_local_path(repository_ctx, local_cudnn_path, ["include", "lib"]) + +def _download_redistribution(repository_ctx, arch_key, path_prefix): + (url, sha256) = repository_ctx.attr.url_dict[arch_key] + + # If url is not relative, then appending prefix is not needed. + if not (url.startswith("http") or url.startswith("file:///")): + url = path_prefix + url + archive_name = get_archive_name(url) + file_name = _get_file_name(url) + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + if repository_ctx.attr.override_strip_prefix: + strip_prefix = repository_ctx.attr.override_strip_prefix + else: + strip_prefix = archive_name + repository_ctx.extract( + archive = file_name, + stripPrefix = strip_prefix, + ) + repository_ctx.delete(file_name) + +def _use_downloaded_cuda_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDA redistribution and initializes hermetic CUDA repository.""" + major_version = "" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cuda_version: + # If no CUDA version is found, comment out all cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cuda_redist_path_prefix, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version(repository_ctx, lib_name_to_version_dict) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _cuda_repo_impl(repository_ctx): + local_cuda_path = get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + if local_cuda_path: + _use_local_cuda_path(repository_ctx, local_cuda_path) + else: + _use_downloaded_cuda_redistribution(repository_ctx) + +cuda_repo = repository_rule( + implementation = _cuda_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cuda_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDA_PATH", + ], +) + +def _use_downloaded_cudnn_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDNN redistribution and initializes hermetic CUDNN repository.""" + cudnn_version = None + major_version = "" + cudnn_version = (get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cudnn_version: + # If no CUDNN version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + arch_key = "cuda{version}_{arch}".format( + version = cuda_version.split(".")[0], + arch = arch_key, + ) + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cudnn_redist_path_prefix, + ) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _cudnn_repo_impl(repository_ctx): + local_cudnn_path = get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + if local_cudnn_path: + _use_local_cudnn_path(repository_ctx, local_cudnn_path) + else: + _use_downloaded_cudnn_redistribution(repository_ctx) + +cudnn_repo = repository_rule( + implementation = _cudnn_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cudnn_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDNN_VERSION", + "TF_CUDNN_VERSION", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDNN_PATH", + ], +) + +def _get_redistribution_urls(dist_info): + url_dict = {} + for arch in _REDIST_ARCH_DICT.keys(): + if "relative_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["relative_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + if "full_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["full_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + for cuda_version, data in dist_info[arch].items(): + # CUDNN JSON might contain paths for each CUDA version. + path_key = "relative_path" + if path_key not in data.keys(): + path_key = "full_path" + url_dict["{cuda_version}_{arch}".format( + cuda_version = cuda_version, + arch = _REDIST_ARCH_DICT[arch], + )] = [data[path_key], data.get("sha256", "")] + return url_dict + +def get_version_and_template_lists(version_to_template): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns lists of versions and templates provided in the dict.""" + template_to_version_map = {} + for version, template in version_to_template.items(): + if template not in template_to_version_map.keys(): + template_to_version_map[template] = [version] + else: + template_to_version_map[template].append(version) + version_list = [] + template_list = [] + for template, versions in template_to_version_map.items(): + version_list.append(",".join(versions)) + template_list.append(Label(template)) + return (version_list, template_list) + +def cudnn_redist_init_repository( + cudnn_redistributions, + cudnn_redist_path_prefix = CUDNN_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDNN repository.""" + if "cudnn" in cudnn_redistributions.keys(): + url_dict = _get_redistribution_urls(cudnn_redistributions["cudnn"]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates["cudnn"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cudnn_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cudnn_redist_path_prefix = cudnn_redist_path_prefix, + ) + +def cuda_redist_init_repositories( + cuda_redistributions, + cuda_redist_path_prefix = CUDA_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDA repositories.""" + for redist_name, _ in redist_versions_to_build_templates.items(): + if redist_name in ["cudnn", "cuda_nccl"]: + continue + if redist_name in cuda_redistributions.keys(): + url_dict = _get_redistribution_urls(cuda_redistributions[redist_name]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates[redist_name] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cuda_redist_path_prefix = cuda_redist_path_prefix, + ) diff --git a/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl new file mode 100644 index 00000000000000..d7ccff736a4801 --- /dev/null +++ b/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -0,0 +1,243 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistribution versions.""" + +CUDA_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cuda/redist/" +CUDNN_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redist/" + +CUDA_REDIST_JSON_DICT = { + "11.8": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_11.8.0.json", + "941a950a4ab3b95311c50df7b3c8bca973e0cdda76fc2f4b456d2d5e4dac0281", + ], + "12.1.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.1.1.json", + "bafea3cb83a4cf5c764eeedcaac0040d0d3c5db3f9a74550da0e7b6ac24d378c", + ], + "12.2.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.2.0.json", + "d883762c6339c8ebb3ffb072facc8f7265cd257d2db16a475fff9a9306ecea89", + ], + "12.3.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.1.json", + "b3cc4181d711cf9b6e3718f323b23813c24f9478119911d7b4bceec9b437dbc3", + ], + "12.3.2": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.2.json", + "1b6eacf335dd49803633fed53ef261d62c193e5a56eee5019e7d2f634e39e7ef", + ], + "12.4.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.4.0.json", + "a4f496b8d5299939b34c9ef88dc4274821f8c9451b2d7c9bcee53166932da067", + ], + "12.4.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.4.1.json", + "9cd815f3b71c2e3686ef2219b7794b81044f9dcefaa8e21dacfcb5bc4d931892", + ], + "12.5.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.5.0.json", + "166664b520bfe51f27abcc8c7a934f4cb6ea287f8c399b5f8255f6f4d214569a", + ], + "12.5.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.5.1.json", + "7ab9c76014ae4907fa1b51738af599607a5fd8ca3a5c4bb4c3b31338cc642a93", + ], + "12.6.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json", + "87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab", + ], +} + +CUDNN_REDIST_JSON_DICT = { + "8.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.6.0.json", + "7f6f50bed4fd8216dc10d6ef505771dc0ecc99cce813993ab405cb507a21d51d", + ], + "8.9.4.25": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.4.25.json", + "02258dba8384860c9230fe3c78522e7bd8e350e461ccd37a8d932cb64127ba57", + ], + "8.9.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.6.json", + "6069ef92a2b9bb18cebfbc944964bd2b024b76f2c2c35a43812982e0bc45cf0c", + ], + "8.9.7.29": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.7.29.json", + "a0734f26f068522464fa09b2f2c186dfbe6ad7407a88ea0c50dd331f0c3389ec", + ], + "9.1.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.1.1.json", + "d22d569405e5683ff8e563d00d6e8c27e5e6a902c564c23d752b22a8b8b3fe20", + ], + "9.2.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.2.0.json", + "6852eb279b95d2b5775f7a7737ec133bed059107f863cdd8588f3ae6f13eadd7", + ], + "9.2.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.2.1.json", + "9a4198c59b2e66b2b115a736ebe4dc8f3dc6d78161bb494702f824da8fc77b99", + ], + "9.3.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.3.0.json", + "d17d9a7878365736758550294f03e633a0b023bec879bf173349bfb34781972e", + ], +} + +# The versions are different for x86 and aarch64 architectures because only +# NCCL release versions 2.20.3 and 2.20.5 have the wheels for aarch64. +CUDA_12_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", + }, + "aarch64-unknown-linux-gnu": { + "version": "2.20.5", + "url": "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", + "sha256": "1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", + }, +} + +CUDA_11_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/ac/9a/8b6a28b3b87d5fddab0e92cd835339eb8fbddaa71ae67518c8c1b3d05bae/nvidia_nccl_cu11-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "49d8350629c7888701d1fd200934942671cb5c728f49acc5a0b3a768820bed29", + }, +} + +CUDA_NCCL_WHEELS = { + "11.8": CUDA_11_NCCL_WHEEL_DICT, + "12.1.1": CUDA_12_NCCL_WHEEL_DICT, + "12.2.0": CUDA_12_NCCL_WHEEL_DICT, + "12.3.1": CUDA_12_NCCL_WHEEL_DICT, + "12.3.2": CUDA_12_NCCL_WHEEL_DICT, + "12.4.0": CUDA_12_NCCL_WHEEL_DICT, + "12.1.0": CUDA_12_NCCL_WHEEL_DICT, + "12.5.0": CUDA_12_NCCL_WHEEL_DICT, + "12.5.1": CUDA_12_NCCL_WHEEL_DICT, + "12.6.0": CUDA_12_NCCL_WHEEL_DICT, +} + +REDIST_VERSIONS_TO_BUILD_TEMPLATES = { + "cuda_nccl": { + "repo_name": "cuda_nccl", + "version_to_template": { + "2": "//third_party/nccl/hermetic:cuda_nccl.BUILD.tpl", + }, + }, + "cudnn": { + "repo_name": "cuda_cudnn", + "version_to_template": { + "9": "//third_party/gpus/cuda/hermetic:cuda_cudnn9.BUILD.tpl", + "8": "//third_party/gpus/cuda/hermetic:cuda_cudnn.BUILD.tpl", + }, + }, + "libcublas": { + "repo_name": "cuda_cublas", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + }, + }, + "cuda_cudart": { + "repo_name": "cuda_cudart", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + }, + }, + "libcufft": { + "repo_name": "cuda_cufft", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + "10": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + }, + }, + "cuda_cupti": { + "repo_name": "cuda_cupti", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + }, + }, + "libcurand": { + "repo_name": "cuda_curand", + "version_to_template": { + "10": "//third_party/gpus/cuda/hermetic:cuda_curand.BUILD.tpl", + }, + }, + "libcusolver": { + "repo_name": "cuda_cusolver", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cusolver.BUILD.tpl", + }, + }, + "libcusparse": { + "repo_name": "cuda_cusparse", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + }, + }, + "libnvjitlink": { + "repo_name": "cuda_nvjitlink", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl", + }, + }, + "cuda_nvrtc": { + "repo_name": "cuda_nvrtc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + }, + }, + "cuda_cccl": { + "repo_name": "cuda_cccl", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + }, + }, + "cuda_nvcc": { + "repo_name": "cuda_nvcc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + }, + }, + "cuda_nvml_dev": { + "repo_name": "cuda_nvml", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + }, + }, + "cuda_nvprune": { + "repo_name": "cuda_nvprune", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + }, + }, + "cuda_nvtx": { + "repo_name": "cuda_nvtx", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + }, + }, +} diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index fefbf081c87e1c..8bf1db2b0f8f9f 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for CUDA autoconfiguration. +NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. + `cuda_configure` depends on the following environment variables: * `TF_NEED_CUDA`: Whether to enable building with CUDA. @@ -53,6 +55,11 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" @@ -67,20 +74,6 @@ _TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" -def to_list_of_strings(elements): - """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. - - This is to be used to put a list of strings into the bzl file templates - so it gets interpreted as list of strings in Starlark. - - Args: - elements: list of string elements - - Returns: - single string of elements wrapped in quotes separated by a comma.""" - quoted_strings = ["\"" + element + "\"" for element in elements] - return ", ".join(quoted_strings) - def verify_build_defines(params): """Verify all variables that crosstool/BUILD.tpl expects are substituted. @@ -238,156 +231,6 @@ def find_cc(repository_ctx, use_cuda_clang): " environment variable").format(target_cc_name, cc_path_envvar)) return cc -_INC_DIR_MARKER_BEGIN = "#include <...>" - -# OSX add " (framework directory)" at the end of line, strip it. -_OSX_FRAMEWORK_SUFFIX = " (framework directory)" -_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) - -def _cxx_inc_convert(path): - """Convert path returned by cc -E xc++ in a complete path.""" - path = path.strip() - if path.endswith(_OSX_FRAMEWORK_SUFFIX): - path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() - return path - -def _normalize_include_path(repository_ctx, path): - """Normalizes include paths before writing them to the crosstool. - - If path points inside the 'crosstool' folder of the repository, a relative - path is returned. - If path points outside the 'crosstool' folder, an absolute path is returned. - """ - path = str(repository_ctx.path(path)) - crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) - - if path.startswith(crosstool_folder): - # We drop the path to "$REPO/crosstool" and a trailing path separator. - return path[len(crosstool_folder) + 1:] - return path - -def _is_compiler_option_supported(repository_ctx, cc, option): - """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" - result = repository_ctx.execute([ - cc, - option, - "-o", - "/dev/null", - "-c", - str(repository_ctx.path("tools/cpp/empty.cc")), - ]) - return result.stderr.find(option) == -1 - -def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot): - """Compute the list of default C or C++ include directories.""" - if lang_is_cpp: - lang = "c++" - else: - lang = "c" - sysroot = [] - if tf_sysroot: - sysroot += ["--sysroot", tf_sysroot] - result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + - sysroot) - stderr = err_out(result) - index1 = stderr.find(_INC_DIR_MARKER_BEGIN) - if index1 == -1: - return [] - index1 = stderr.find("\n", index1) - if index1 == -1: - return [] - index2 = stderr.rfind("\n ") - if index2 == -1 or index2 < index1: - return [] - index2 = stderr.find("\n", index2 + 1) - if index2 == -1: - inc_dirs = stderr[index1 + 1:] - else: - inc_dirs = stderr[index1 + 1:index2].strip() - - print_resource_dir_supported = _is_compiler_option_supported( - repository_ctx, - cc, - "-print-resource-dir", - ) - - if print_resource_dir_supported: - resource_dir = repository_ctx.execute( - [cc, "-print-resource-dir"], - ).stdout.strip() + "/share" - inc_dirs += "\n" + resource_dir - - compiler_includes = [ - _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) - for p in inc_dirs.split("\n") - ] - - # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc - # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) - # but Bazel might encounter either (usually reported by the compiler) - # especially when a compiler wrapper (e.g. ccache) is used. - # So we need to also include paths where symlinks are not resolved. - - # Try to find real path to CC installation to "see through" compiler wrappers - # GCC has the path to g++ - index1 = result.stderr.find("COLLECT_GCC=") - if index1 != -1: - index1 = result.stderr.find("=", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname - else: - # Clang has the directory - index1 = result.stderr.find("InstalledDir: ") - if index1 != -1: - index1 = result.stderr.find(" ", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname - else: - # Fallback to the CC path - cc_topdir = repository_ctx.path(cc).dirname.dirname - - # We now have the compiler installation prefix, e.g. /symlink/gcc - # And the resolved installation prefix, e.g. /opt/gcc - cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() - cc_topdir = str(cc_topdir).strip() - - # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. - # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] - # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] - if cc_topdir_resolved != cc_topdir: - unresolved_compiler_includes = [ - cc_topdir + inc[len(cc_topdir_resolved):] - for inc in compiler_includes - if inc.startswith(cc_topdir_resolved) - ] - compiler_includes = compiler_includes + unresolved_compiler_includes - return compiler_includes - -def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot): - """Compute the list of default C and C++ include directories.""" - - # For some reason `clang -xc` sometimes returns include paths that are - # different from the ones from `clang -xc++`. (Symlink and a dir) - # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists - includes_cpp = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - True, - tf_sysroot, - ) - includes_c = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - False, - tf_sysroot, - ) - - return includes_cpp + [ - inc - for inc in includes_c - if inc not in includes_cpp - ] - def auto_configure_fail(msg): """Output failure message when cuda configuration fails.""" red = "\033[0;31m" @@ -1293,6 +1136,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cuda_nvcc_files}"] = "[]" if is_cuda_clang and not is_nvcc_and_clang: cuda_defines["%{host_compiler_path}"] = str(cc) cuda_defines["%{host_compiler_warnings}"] = """ diff --git a/third_party/gpus/find_cuda_config.py b/third_party/gpus/find_cuda_config.py index b88694af5c014d..68623bf671da71 100644 --- a/third_party/gpus/find_cuda_config.py +++ b/third_party/gpus/find_cuda_config.py @@ -14,6 +14,9 @@ # ============================================================================== """Prints CUDA library and header directories and versions found on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + The script searches for CUDA library and header files on the system, inspects them to determine their version and prints the configuration to stdout. The paths to inspect and the required versions are specified through environment diff --git a/third_party/gpus/rocm_configure.bzl b/third_party/gpus/rocm_configure.bzl index 4ddd23ceda1690..4eb3c2eb77155b 100644 --- a/third_party/gpus/rocm_configure.bzl +++ b/third_party/gpus/rocm_configure.bzl @@ -22,12 +22,15 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "enable_cuda", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) load( ":sycl_configure.bzl", diff --git a/third_party/gpus/sycl_configure.bzl b/third_party/gpus/sycl_configure.bzl index 05330b2fe53195..dd80694e7274f5 100644 --- a/third_party/gpus/sycl_configure.bzl +++ b/third_party/gpus/sycl_configure.bzl @@ -16,11 +16,14 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index d5a5b4677de410..821b0b238614b3 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,25 +1,30 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp ---- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp -+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp -@@ -15,9 +15,7 @@ - #include "mlir/Dialect/Linalg/IR/Linalg.h" - #include "mlir/Dialect/Linalg/Utils/Utils.h" - #include "mlir/Dialect/Tensor/IR/Tensor.h" --#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" - #include "mlir/Dialect/Utils/StaticValueUtils.h" --#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - #include "llvm/Support/MathExtras.h" - - namespace mlir { diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel -@@ -11350,7 +11350,6 @@ - ":TensorTransforms", - ":TensorUtils", - ":TilingInterface", -- ":TosaDialect", - ":TransformUtils", - ":ValueBoundsOpInterface", - ":VectorDialect", +@@ -342,6 +342,7 @@ + "include/mlir/IR/PDLPatternMatch.h.inc", + "include/mlir/Interfaces/CallInterfaces.h", + "include/mlir/Interfaces/DataLayoutInterfaces.h", ++ "include/mlir/Interfaces/InferIntRangeInterface.h", + "include/mlir/Interfaces/SideEffectInterfaces.h", + ], + hdrs = glob([ +@@ -362,6 +363,7 @@ + ":BytecodeOpInterfaceIncGen", + ":CallOpInterfacesIncGen", + ":DataLayoutInterfacesIncGen", ++ ":InferIntRangeInterfaceIncGen", + ":OpAsmInterfaceIncGen", + ":RegionKindInterfaceIncGen", + ":SideEffectInterfacesIncGen", +@@ -5422,7 +5424,9 @@ + hdrs = glob(["include/mlir/Dialect/LLVMIR/Transforms/*.h"]), + includes = ["include"], + deps = [ ++ ":DataLayoutInterfaces", + ":FuncDialect", ++ ":InliningUtils", + ":IR", + ":LLVMDialect", + ":LLVMPassIncGen", diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index c588af29e52bd2..5f8535bcee878a 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "9ddfe62f5c11e3f65f444209f514029ded2d58b9" - LLVM_SHA256 = "cb59f31fd0060e9d6f1142c702cad742b52a294ef9dbed87a864213fdcc007cd" + LLVM_COMMIT = "1115dee248e68a155001ac3712a189299d104863" + LLVM_SHA256 = "cbfe9694c137ed4489b1667dd01429b7595b40aa47b8d3ae4cafa8a6cff2ef8f" tf_http_archive( name = name, diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD index 4b3ad84d836933..8c730960bc3ed3 100644 --- a/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -12,7 +12,7 @@ _CMAKE_COMMON_LIST = { "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA", "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", - "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", + "#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH", "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE", "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", @@ -109,6 +109,7 @@ _COPTS_LIST = select({ "-UUSE_CBLAS", "-DDNNL_ENABLE_MAX_CPU_ISA", "-DDNNL_ENABLE_ITT_TASKS", + "-DDNNL_ENABLE_GRAPH_DUMP", ] + tf_openmp_copts() _INCLUDES_LIST = [ @@ -119,6 +120,7 @@ _INCLUDES_LIST = [ "src/cpu", "src/cpu/gemm", "src/cpu/x64/xbyak", + "src/graph", ] _TEXTUAL_HDRS_LIST = glob([ @@ -129,6 +131,15 @@ _TEXTUAL_HDRS_LIST = glob([ "src/cpu/**/*.hpp", "src/cpu/jit_utils/**/*.hpp", "src/cpu/x64/xbyak/*.h", + "src/graph/interface/*.hpp", + "src/graph/backend/*.hpp", + "src/graph/backend/dnnl/*.hpp", + "src/graph/backend/fake/*.hpp", + "src/graph/backend/dnnl/passes/*.hpp", + "src/graph/backend/dnnl/patterns/*.hpp", + "src/graph/backend/dnnl/kernels/*.hpp", + "src/graph/utils/*.hpp", + "src/graph/utils/pm/*.hpp", ]) + [ ":dnnl_config_h", ":dnnl_version_h", @@ -160,6 +171,16 @@ cc_library( "src/cpu/**/*.cpp", "src/common/ittnotify/*.c", "src/cpu/jit_utils/**/*.cpp", + "src/cpu/x64/**/*.cpp", + "src/graph/interface/*.cpp", + "src/graph/backend/*.cpp", + "src/graph/backend/dnnl/*.cpp", + "src/graph/backend/fake/*.cpp", + "src/graph/backend/dnnl/passes/*.cpp", + "src/graph/backend/dnnl/patterns/*.cpp", + "src/graph/backend/dnnl/kernels/*.cpp", + "src/graph/utils/*.cpp", + "src/graph/utils/pm/*.cpp", ], exclude = [ "src/cpu/aarch64/**", diff --git a/third_party/nanobind/nanobind.BUILD b/third_party/nanobind/nanobind.BUILD index c9f307b75ef0ca..72b47585b5e5d0 100644 --- a/third_party/nanobind/nanobind.BUILD +++ b/third_party/nanobind/nanobind.BUILD @@ -4,9 +4,12 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "nanobind", - srcs = glob([ - "src/*.cpp", - ]), + srcs = glob( + [ + "src/*.cpp", + ], + exclude = ["src/nb_combined.cpp"], + ), copts = ["-fexceptions"], defines = [ "NB_BUILD=1", diff --git a/third_party/nanobind/pr438.patch b/third_party/nanobind/pr438.patch deleted file mode 100644 index edb7d61700e03b..00000000000000 --- a/third_party/nanobind/pr438.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/src/nb_enum.cpp b/src/nb_enum.cpp -index 86f64d1..91f3932 100644 ---- a/src/nb_enum.cpp -+++ b/src/nb_enum.cpp -@@ -73,6 +73,13 @@ static PyObject *nb_enum_get_doc(PyObject *self, void *) { - return result; - } - -+static PyObject *nb_enum_get_value(PyObject *self, void *) { -+ enum_supplement &supp = nb_enum_supplement(Py_TYPE(self)); -+ return supp.is_signed ? nb_enum_int_signed(self) -+ : nb_enum_int_unsigned(self); -+} -+ -+ - NB_NOINLINE static PyObject *nb_enum_int_signed(PyObject *o) { - type_data *t = nb_type_data(Py_TYPE(o)); - const void *p = inst_ptr((nb_inst *) o); -@@ -141,6 +148,8 @@ error: - static PyGetSetDef nb_enum_getset[] = { - { "__doc__", nb_enum_get_doc, nullptr, nullptr, nullptr }, - { "__name__", nb_enum_get_name, nullptr, nullptr, nullptr }, -+ { "name", nb_enum_get_name, nullptr, nullptr, nullptr }, -+ { "value", nb_enum_get_value, nullptr, nullptr, nullptr }, - { nullptr, nullptr, nullptr, nullptr, nullptr } - }; - -diff --git a/tests/test_enum.py b/tests/test_enum.py -index 2a6e9ff..1063eef 100644 ---- a/tests/test_enum.py -+++ b/tests/test_enum.py -@@ -14,6 +14,9 @@ def test01_unsigned_enum(): - assert int(t.Enum.A) == 0 - assert int(t.Enum.B) == 1 - assert int(t.Enum.C) == 0xffffffff -+ assert t.Enum.A.value == 0 -+ assert t.Enum.B.value == 1 -+ assert t.Enum.C.value == 0xffffffff - assert t.Enum(0) is t.Enum.A - assert t.Enum(1) is t.Enum.B - assert t.Enum(0xffffffff) is t.Enum.C -@@ -48,6 +51,9 @@ def test02_signed_enum(): - assert int(t.SEnum.A) == 0 - assert int(t.SEnum.B) == 1 - assert int(t.SEnum.C) == -1 -+ assert t.SEnum.A.value == 0 -+ assert t.SEnum.B.value == 1 -+ assert t.SEnum.C.value == -1 - assert t.SEnum(0) is t.SEnum.A - assert t.SEnum(1) is t.SEnum.B - assert t.SEnum(-1) is t.SEnum.C \ No newline at end of file diff --git a/third_party/nanobind/pr461.patch b/third_party/nanobind/pr461.patch deleted file mode 100644 index aa0a51b68175a3..00000000000000 --- a/third_party/nanobind/pr461.patch +++ /dev/null @@ -1,39 +0,0 @@ -diff --git a/src/nb_type.cpp b/src/nb_type.cpp ---- a/src/nb_type.cpp -+++ b/src/nb_type.cpp -@@ -36,6 +36,11 @@ static PyObject **nb_weaklist_ptr(PyObje - return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr; - } - -+static PyGetSetDef inst_getset[] = { -+ { "__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr }, -+ { nullptr, nullptr, nullptr, nullptr, nullptr } -+}; -+ - static int inst_clear(PyObject *self) { - PyObject **dict = nb_dict_ptr(self); - if (dict) -@@ -923,8 +928,11 @@ PyObject *nb_type_new(const type_init_da - } - - bool has_traverse = false; -- for (PyType_Slot *ts = slots; ts != s; ++ts) -+ bool has_getset = false; -+ for (PyType_Slot *ts = slots; ts != s; ++ts) { - has_traverse |= ts->slot == Py_tp_traverse; -+ has_getset |= ts->slot == Py_tp_getset; -+ } - - Py_ssize_t dictoffset = 0, weaklistoffset = 0; - int num_members = 0; -@@ -948,6 +956,10 @@ PyObject *nb_type_new(const type_init_da - has_traverse = true; - } - spec.basicsize = (int) basicsize; -+ -+ if (!has_getset) { -+ *s++ = { Py_tp_getset, (void *) inst_getset }; -+ } - } - - if (is_weak_referenceable) { diff --git a/third_party/nanobind/workspace.bzl b/third_party/nanobind/workspace.bzl index 9f9022dbaa8d12..1c692d396e9b98 100644 --- a/third_party/nanobind/workspace.bzl +++ b/third_party/nanobind/workspace.bzl @@ -5,12 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "nanobind", - strip_prefix = "nanobind-1.9.2", - sha256 = "149a3da40b0a988513d8cf5e71db3037373823505a3c92f87b988c92d7e0ab34", - urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v1.9.2.tar.gz"), + strip_prefix = "nanobind-2.1.0", + sha256 = "c37c53c60ada5fe1c956e24bd4b83af669a2309bf952bd251f36a7d2fa3bacf0", + urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v2.1.0.tar.gz"), build_file = "//third_party/nanobind:nanobind.BUILD", - patch_file = [ - "//third_party/nanobind:pr438.patch", # Remove when updating to nanobind 2.0.0. - "//third_party/nanobind:pr461.patch", # Remove when updating to nanobind 2.0.0. - ], ) diff --git a/third_party/nccl/build_defs.bzl.tpl b/third_party/nccl/build_defs.bzl.tpl index 53a6d4e1e41890..a0930df34ecec8 100644 --- a/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/nccl/build_defs.bzl.tpl @@ -5,7 +5,6 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") # CUDA toolkit version as tuple (e.g. '(11, 1)'). _cuda_version = %{cuda_version} -_cuda_clang = %{cuda_clang} def _rdc_copts(): """Returns copts for compiling relocatable device code.""" @@ -121,25 +120,25 @@ _device_link = rule( "gpu_archs": attr.string_list(), "nvlink_args": attr.string_list(), "_nvlink": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvlink"), + default = Label("%{nvlink_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_fatbinary": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/fatbinary"), + default = Label("%{fatbinary_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_bin2c": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/bin2c"), + default = Label("%{bin2c_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_link_stub": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/crt/link.stub"), + default = Label("%{link_stub_label}"), allow_single_file = True, ), }, @@ -189,7 +188,7 @@ _prune_relocatable_code = rule( "input": attr.label(mandatory = True, allow_files = True), "gpu_archs": attr.string_list(), "_nvprune": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvprune"), + default = Label("%{nvprune_label}"), allow_single_file = True, executable = True, cfg = "host", diff --git a/third_party/nccl/hermetic/BUILD b/third_party/nccl/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl b/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl new file mode 100644 index 00000000000000..61d7809bcdaad1 --- /dev/null +++ b/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl @@ -0,0 +1,30 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nccl_shared_library", + shared_library = "lib/libnccl.so.%{libnccl_version}", + hdrs = [":headers"], + deps = ["@local_config_cuda//cuda:cuda_headers", ":headers"], +) +%{multiline_comment} +cc_library( + name = "nccl", + %{comment}deps = [":nccl_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nccl*.h", + %{comment}]), + include_prefix = "third_party/nccl", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["//visibility:public"], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) diff --git a/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/nccl/hermetic/nccl_configure.bzl new file mode 100644 index 00000000000000..75f5a10b6fe24e --- /dev/null +++ b/third_party/nccl/hermetic/nccl_configure.bzl @@ -0,0 +1,183 @@ +"""Repository rule for hermetic NCCL configuration. + +`nccl_configure` depends on the following environment variables: + + * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should + be used, "0" if NCCL should be linked in statically. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + +""" + +load( + "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "TF_NEED_CUDA", + "enable_cuda", + "get_cuda_version", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", +) + +_TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" + +_NCCL_DUMMY_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_STUB_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl_via_stub", + }), + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +def _create_local_nccl_repository(repository_ctx): + cuda_version = get_cuda_version(repository_ctx).split(".")[:2] + nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + + if get_host_environ(repository_ctx, _TF_NCCL_USE_STUB, "0") == "0": + repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) + else: + repository_ctx.file("BUILD", _NCCL_ARCHIVE_STUB_BUILD_CONTENT) + + repository_ctx.template("generated_names.bzl", repository_ctx.attr.generated_names_tpl, {}) + repository_ctx.template( + "build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), + "%{nvlink_label}": "@cuda_nvcc//:nvlink", + "%{fatbinary_label}": "@cuda_nvcc//:fatbinary", + "%{bin2c_label}": "@cuda_nvcc//:bin2c", + "%{link_stub_label}": "@cuda_nvcc//:link_stub", + "%{nvprune_label}": "@cuda_nvprune//:nvprune", + }, + ) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"%s\"" % nccl_version) + +def _nccl_autoconf_impl(repository_ctx): + if (not enable_cuda(repository_ctx) or + get_cpu_value(repository_ctx) != "Linux"): + # Add a dummy build file to make bazel query happy. + repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") + else: + _create_local_nccl_repository(repository_ctx) + +_ENVIRONS = [ + TF_NEED_CUDA, + TF_CUDA_VERSION, + _TF_NCCL_USE_STUB, + HERMETIC_CUDA_VERSION, + "LOCAL_NCCL_PATH", +] + +nccl_configure = repository_rule( + environ = _ENVIRONS, + implementation = _nccl_autoconf_impl, + attrs = { + "environ": attr.string_dict(), + "nccl_version": attr.label(default = Label("@cuda_nccl//:version.txt")), + "generated_names_tpl": attr.label(default = Label("//third_party/nccl:generated_names.bzl.tpl")), + "build_defs_tpl": attr.label(default = Label("//third_party/nccl:build_defs.bzl.tpl")), + }, +) +"""Downloads and configures the hermetic NCCL configuration. + +Add the following to your WORKSPACE file: + +```python +nccl_configure(name = "local_config_nccl") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/nccl/hermetic/nccl_redist_init_repository.bzl b/third_party/nccl/hermetic/nccl_redist_init_repository.bzl new file mode 100644 index 00000000000000..244cb851ddf591 --- /dev/null +++ b/third_party/nccl/hermetic/nccl_redist_init_repository.bzl @@ -0,0 +1,145 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic NCCL repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "OS_ARCH_DICT", + "create_build_file", + "create_dummy_build_file", + "get_archive_name", + "get_env_var", + "get_lib_name_to_version_dict", + "get_major_library_version", + "get_version_and_template_lists", + "use_local_path", +) +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_NCCL_WHEELS", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +def _use_downloaded_nccl_wheel(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads NCCL wheel and inits hermetic NCCL repository.""" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + major_version = "" + if not cuda_version: + # If no CUDA version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch = OS_ARCH_DICT[repository_ctx.os.arch] + dict_key = "{cuda_version}-{arch}".format( + cuda_version = cuda_version, + arch = arch, + ) + supported_versions = repository_ctx.attr.url_dict.keys() + if dict_key not in supported_versions: + fail( + ("The supported NCCL versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add NCCL distribution for" + + " CUDA version={version}, OS={arch}.") + .format( + supported_versions = supported_versions, + version = cuda_version, + arch = arch, + ), + ) + sha256 = repository_ctx.attr.sha256_dict[dict_key] + url = repository_ctx.attr.url_dict[dict_key] + + archive_name = get_archive_name(url) + file_name = archive_name + ".zip" + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + repository_ctx.extract( + archive = file_name, + stripPrefix = repository_ctx.attr.strip_prefix, + ) + repository_ctx.delete(file_name) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _use_local_nccl_path(repository_ctx, local_nccl_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic NCCL repository.""" + use_local_path(repository_ctx, local_nccl_path, ["include", "lib"]) + +def _cuda_nccl_repo_impl(repository_ctx): + local_nccl_path = get_env_var(repository_ctx, "LOCAL_NCCL_PATH") + if local_nccl_path: + _use_local_nccl_path(repository_ctx, local_nccl_path) + else: + _use_downloaded_nccl_wheel(repository_ctx) + +cuda_nccl_repo = repository_rule( + implementation = _cuda_nccl_repo_impl, + attrs = { + "sha256_dict": attr.string_dict(mandatory = True), + "url_dict": attr.string_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "strip_prefix": attr.string(), + }, + environ = ["HERMETIC_CUDA_VERSION", "TF_CUDA_VERSION", "LOCAL_NCCL_PATH"], +) + +def nccl_redist_init_repository( + cuda_nccl_wheels = CUDA_NCCL_WHEELS, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes NCCL repository.""" + nccl_artifacts_dict = {"sha256_dict": {}, "url_dict": {}} + for cuda_version, nccl_wheel_info in cuda_nccl_wheels.items(): + for arch in OS_ARCH_DICT.values(): + if arch in nccl_wheel_info.keys(): + cuda_version_to_arch_key = "%s-%s" % (cuda_version, arch) + nccl_artifacts_dict["sha256_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch].get("sha256", "") + nccl_artifacts_dict["url_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch]["url"] + repo_data = redist_versions_to_build_templates["cuda_nccl"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_nccl_repo( + name = repo_data["repo_name"], + sha256_dict = nccl_artifacts_dict["sha256_dict"], + url_dict = nccl_artifacts_dict["url_dict"], + versions = versions, + build_templates = templates, + strip_prefix = "nvidia/nccl", + ) diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl index 22cf64d4771062..59f8b5c08ef0ee 100644 --- a/third_party/nccl/nccl_configure.bzl +++ b/third_party/nccl/nccl_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for NCCL configuration. +NB: DEPRECATED! Use `hermetic/nccl_configure` rule instead. + `nccl_configure` depends on the following environment variables: * `TF_NCCL_VERSION`: Installed NCCL version or empty to build from source. @@ -8,7 +10,6 @@ files. * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is `/usr/local/cuda,usr/`. - * `TF_CUDA_CLANG`: "1" if using Clang, "0" if using NVCC. * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should be used, "0" if NCCL should be linked in statically. @@ -33,7 +34,6 @@ _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" _TF_NCCL_VERSION = "TF_NCCL_VERSION" _TF_NEED_CUDA = "TF_NEED_CUDA" _TF_CUDA_PATHS = "TF_CUDA_PATHS" -_TF_CUDA_CLANG = "TF_CUDA_CLANG" _TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" _DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR" @@ -129,7 +129,11 @@ def _create_local_nccl_repository(repository_ctx): _label("build_defs.bzl.tpl"), { "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), - "%{cuda_clang}": repr(get_host_environ(repository_ctx, _TF_CUDA_CLANG)), + "%{nvlink_label}": "@local_config_cuda//cuda:cuda/bin/nvlink", + "%{fatbinary_label}": "@local_config_cuda//cuda:cuda/bin/fatbinary", + "%{bin2c_label}": "@local_config_cuda//cuda:cuda/bin/bin2c", + "%{link_stub_label}": "@local_config_cuda//cuda:cuda/bin/crt/link.stub", + "%{nvprune_label}": "@local_config_cuda//cuda:cuda/bin/nvprune", }, ) else: @@ -181,7 +185,6 @@ _ENVIRONS = [ _TF_CUDA_COMPUTE_CAPABILITIES, _TF_NEED_CUDA, _TF_CUDA_PATHS, - _TF_CUDA_CLANG, ] remote_nccl_configure = repository_rule( diff --git a/third_party/py/python_repo.bzl b/third_party/py/python_repo.bzl index f8fdd1033b5e2f..13aed2b687129f 100644 --- a/third_party/py/python_repo.bzl +++ b/third_party/py/python_repo.bzl @@ -255,8 +255,12 @@ def _basic_wildcard_match(name, patterns, expected_match_result, match_all): def _custom_python_interpreter_impl(ctx): version = ctx.attr.version - strip_prefix = ctx.attr.strip_prefix.format(version = version) - urls = [url.format(version = version) for url in ctx.attr.urls] + version_variant = ctx.attr.version_variant + strip_prefix = ctx.attr.strip_prefix.format( + version = version, + version_variant = version_variant, + ) + urls = [url.format(version = version, version_variant = version_variant) for url in ctx.attr.urls] binary_name = ctx.attr.binary_name if not binary_name: ver_chunks = version.split(".") @@ -272,13 +276,12 @@ def _custom_python_interpreter_impl(ctx): output = srcs_dir, ) - configure_params = [] + configure_params = list(ctx.attr.configure_params) if "CC" in ctx.os.environ: configure_params.append("CC={}".format(ctx.os.environ["CC"])) if "CXX" in ctx.os.environ: configure_params.append("CXX={}".format(ctx.os.environ["CXX"])) - configure_params.append("--enable-optimizations") configure_params.append("--prefix=%s" % install_path.realpath) _exec_and_check( ctx, @@ -361,6 +364,11 @@ custom_python_interpreter = repository_rule( "strip_prefix": attr.string(), "binary_name": attr.string(mandatory = False), "version": attr.string(), + "version_variant": attr.string(), + "configure_params": attr.string_list( + mandatory = False, + default = ["--enable-optimizations"], + ), }, ) diff --git a/third_party/shardy/BUILD b/third_party/shardy/BUILD index ea1ecdb548c1f4..bf3ae84c142f65 100644 --- a/third_party/shardy/BUILD +++ b/third_party/shardy/BUILD @@ -2,4 +2,7 @@ # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) -exports_files(srcs = ["workspace.bzl"]) +exports_files(srcs = [ + "temporary.patch", + "workspace.bzl", +]) diff --git a/third_party/shardy/workspace.bzl b/third_party/shardy/workspace.bzl index c82f3275766f90..6d91def025b34a 100644 --- a/third_party/shardy/workspace.bzl +++ b/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "8f92b38a2400ce5dc72f97067b02c635ed4f3d00" - SHARDY_SHA256 = "3d91370627e81ce5285e5a6ec0d6dbefc786ae32f6d1ebcb4aa61fd247378b91" + SHARDY_COMMIT = "7e3ddfb532b3b53cb0b108014c24a86ac147e9f6" + SHARDY_SHA256 = "1d304e1e6f1132fe3ccb969d28798bc6ee90db84d10c85113ef8573eae350325" tf_http_archive( name = "shardy", diff --git a/third_party/spirv_llvm_translator/BUILD b/third_party/spirv_llvm_translator/BUILD new file mode 100644 index 00000000000000..8d626dc7635d1a --- /dev/null +++ b/third_party/spirv_llvm_translator/BUILD @@ -0,0 +1,7 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +# spirv_llvm_translator license placeholder diff --git a/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD new file mode 100644 index 00000000000000..557e2e8f50edd2 --- /dev/null +++ b/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD @@ -0,0 +1,34 @@ +cc_library( + name = "spirv_llvm_translator", + srcs = glob([ + "lib/SPIRV/libSPIRV/*.cpp", + "lib/SPIRV/libSPIRV/*.hpp", + "lib/SPIRV/libSPIRV/*.h", + "lib/SPIRV/Mangler/*.cpp", + "lib/SPIRV/Mangler/*.h", + "lib/SPIRV/*.cpp", + "lib/SPIRV/*.hpp", + "lib/SPIRV/*.h", + ]), + hdrs = glob(["include/*"]), + includes = [ + "include/", + "lib/SPIRV/", + "lib/SPIRV/Mangler/", + "lib/SPIRV/libSPIRV/", + ], + visibility = ["//visibility:public"], + deps = [ + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@spirv_headers//:spirv_cpp_headers", + ], +) diff --git a/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/spirv_llvm_translator/spirv_llvm_translator.patch new file mode 100644 index 00000000000000..fc843b1b039b09 --- /dev/null +++ b/third_party/spirv_llvm_translator/spirv_llvm_translator.patch @@ -0,0 +1,25 @@ +diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h +index a828add8..924e13b4 100644 + +Spir backend uses different addrspace representations link with nvptx backend link. +We reorder the enum value here so that we can make XLA LLVM codegen simple(avoiding +changing addrspace based on device backend everywhere) + +--- a/lib/SPIRV/SPIRVInternal.h ++++ b/lib/SPIRV/SPIRVInternal.h +@@ -179,11 +179,12 @@ typedef SPIRVMap IntBoolOpMap; + "-v512:512:512-v1024:1024:1024" + + enum SPIRAddressSpace { +- SPIRAS_Private, ++ SPIRAS_Generic, + SPIRAS_Global, +- SPIRAS_Constant, ++ SPIRAS_Internal, + SPIRAS_Local, +- SPIRAS_Generic, ++ SPIRAS_Constant, ++ SPIRAS_Private, + SPIRAS_GlobalDevice, + SPIRAS_GlobalHost, + SPIRAS_Input, \ No newline at end of file diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe9..77fefee2b13b6d 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,28 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -1283,6 +1283,7 @@ + "@llvm-project//mlir:AllExtensions", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", ++ "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:TosaDialect", + ], + ) +diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py b/stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py +--- stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py ++++ stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py +@@ -32,9 +32,9 @@ + + # Make LLVM and StableHLO tools available in RUN directives + tools = [ +- 'stablehlo-opt', +- 'FileCheck', +- 'stablehlo-translate', ++ 'stablehlo-opt', ++ 'FileCheck', ++ 'stablehlo-translate', + ] + tool_dirs = [ + config.llvm_tools_dir, diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index f9c14a65d4dbb3..6c0cea3e8f16f5 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "8555db77763fadbd6be83df0a5532828bc419cba" - STABLEHLO_SHA256 = "666a88d94e0f1b36e9e5b25411521b878320c61983214859b4e419f36acbf332" + STABLEHLO_COMMIT = "23d3e1414b0be1c1b5256f0949520dc4f0a0705c" + STABLEHLO_SHA256 = "ad694a3da43a2a432c8c5f1c60be39fc211e28834cca07cf663ce8dc85d920fe" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 4de9536cbddbab..3466def95fd60d 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "60277ba976739502e45ad26585e071568fa44af1" - TFRT_SHA256 = "7634f696ad57f0ec914c4092cd8a2d19371f024abeb23d06c8eb5c18be660405" + TFRT_COMMIT = "07992d7c1ead60f610c17b7c1f9e50b6898adc87" + TFRT_SHA256 = "e1de8d371248d3dfc6e9ebd0e4094b57ce04d9545ae3756b5a84c33482614d5f" tf_http_archive( name = "tf_runtime", diff --git a/third_party/triton/llvm_integration/cl656020169.patch b/third_party/triton/llvm_integration/cl656020169.patch deleted file mode 100644 index 7586a90b14ccf6..00000000000000 --- a/third_party/triton/llvm_integration/cl656020169.patch +++ /dev/null @@ -1,12 +0,0 @@ -diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp ---- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp -+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp -@@ -117,7 +117,7 @@ private: - auto operands = callOp.getOperands(); - auto result = callOp.getResult(); - -- LLVM::LLVMFunctionType calleeType = callOp.getCalleeType().value(); -+ LLVM::LLVMFunctionType calleeType = callOp.getVarCalleeType().value(); - Type returnType = calleeType.getReturnType(); - - auto loc = callOp.getLoc(); diff --git a/third_party/triton/workspace.bzl b/third_party/triton/workspace.bzl index 70656397d11b99..29287cc59f3210 100644 --- a/third_party/triton/workspace.bzl +++ b/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl650171855" - TRITON_SHA256 = "db4073455d0b86de6b71f7ee9472588c8e7c73a181f262b6231f3fdff1ece685" + TRITON_COMMIT = "cl664783844" + TRITON_SHA256 = "d5779d331008dd3a4941dd59e61385ec964987da74454248446ac3e36b874007" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/triton/xla_extensions/series.bzl b/third_party/triton/xla_extensions/series.bzl index 757c2b95a1be4a..19ba85b57b3672 100644 --- a/third_party/triton/xla_extensions/series.bzl +++ b/third_party/triton/xla_extensions/series.bzl @@ -8,5 +8,6 @@ IMPORTANT: This is a temporary hack while we are figuring out the proper way to extensions_files_patch_list = [ "//third_party/triton/xla_extensions:sparse_dot.patch", # Sparsity internal patch + "//third_party/triton/xla_extensions:sparsity_layout.patch", # Sparsity internal patch # Add new patches just above this line ] diff --git a/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/triton/xla_extensions/sparse_dot.patch index 21ed97b5afb822..dadc7732a4f280 100644 --- a/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/triton/xla_extensions/sparse_dot.patch @@ -57,7 +57,7 @@ diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dia index 012786dae..6043b764a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp -@@ -498,6 +498,119 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, +@@ -498,6 +498,123 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, return encoding; } @@ -173,6 +173,10 @@ index 012786dae..6043b764a 100644 + ArrayRef tensorShape) const { + return ::getShapePerCTATile(getParent(), tensorShape); +} ++std::optional SparseDotMetaEncodingAttr::toLinearLayout( ++ ArrayRef shape) const { ++ return ::toLinearLayout(shape, getParent()); ++} + } // namespace gpu } // namespace triton @@ -273,9 +277,9 @@ index d74e0a224..4e45f7c4c 100644 + return op->hasTrait() || isa(op); +} + - // Replace the ForOp's yield with a new one with the given operands appended. - static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { - // Fix up the yield op. + static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + tt::CoarseSchedule &schedule, @@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) @@ -344,52 +348,6 @@ index d74e0a224..4e45f7c4c 100644 if (auto dotEnc = dyn_cast( dot.getResult().getType().getEncoding())) { auto loadTy = cast(op->getResultTypes()[0]); -diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -index 8c1f18e45..c39110d12 100644 ---- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -@@ -38,6 +38,10 @@ public: - auto srcEncoding = srcType.getEncoding(); - if (isa(srcEncoding)) - return; -+ if (isa(dstType.getEncoding())) { -+ replaceSparseMetaEncoding(cvtOp); -+ return; -+ } - auto dstDotOp = - dyn_cast(dstType.getEncoding()); - if (!dstDotOp) -@@ -86,6 +90,30 @@ public: - cvtOp.erase(); - }); - } -+ -+ private: -+ void replaceSparseMetaEncoding(triton::gpu::ConvertLayoutOp cvtOp) { -+ auto srcType = cast(cvtOp.getOperand().getType()); -+ auto srcEncoding = srcType.getEncoding(); -+ auto sharedLayout = triton::gpu::SharedEncodingAttr::get( -+ cvtOp.getContext(), 8, 1, 1, triton::gpu::getOrder(srcEncoding), -+ triton::gpu::getCTALayout(srcEncoding)); -+ -+ auto dstType = cast(cvtOp.getType()); -+ auto sharedMemorySpace = -+ triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); -+ auto tmpType = triton::MemDescType::get( -+ dstType.getShape(), dstType.getElementType(), sharedLayout, -+ sharedMemorySpace); -+ -+ OpBuilder builder(cvtOp); -+ auto tmp = builder.create( -+ cvtOp.getLoc(), tmpType, cvtOp.getSrc()); -+ auto newConvert = builder.create( -+ cvtOp.getLoc(), dstType, tmp); -+ cvtOp.replaceAllUsesWith(newConvert.getResult()); -+ cvtOp.erase(); -+ } - }; - - } // namespace gpu diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fd..37795c20c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp diff --git a/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/triton/xla_extensions/sparsity_layout.patch index b64ddbdbdab683..4daf4f2856069c 100644 --- a/third_party/triton/xla_extensions/sparsity_layout.patch +++ b/third_party/triton/xla_extensions/sparsity_layout.patch @@ -2,19 +2,20 @@ diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conv index 34fb89954..a0172e107 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp -@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, +@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> std::optional { -- llvm_unreachable("Argument rematerialization should not happen in Triton " -- "-> TritonGPU conversion"); -+ // TODO(b/354860562): reenable or remove. -+ // llvm_unreachable("Argument rematerialization should not happen in Triton " -+ // "-> TritonGPU conversion"); ++ // Allows partial TTIR to TTGIR conversion by materializing a conversion for ++ // remaining arguments that have been converted to a new type. ++ // We use this to rewrite triton_gpu.sparse_dot in a separate pass after ++ // 'convert-triton-to-tritongpu'. ++ return builder.create(loc, tensorType, ++ inputs); + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonGPU conversion"); return std::nullopt; - }); - -@@ -67,6 +68,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, +@@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> std::optional { @@ -31,7 +32,7 @@ diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dia index df3d3b042..e38c184f6 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp -@@ -2867,13 +2867,13 @@ struct CanonicalizeConvertFromConvert +@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert // heuristic to accommodate fused attention. auto srcType = op.getSrc().getType(); auto dstType = op.getType(); diff --git a/third_party/uv/BUILD b/third_party/uv/BUILD new file mode 100644 index 00000000000000..3c413807167aeb --- /dev/null +++ b/third_party/uv/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/uv/uv.BUILD b/third_party/uv/uv.BUILD new file mode 100644 index 00000000000000..43c194a53ea516 --- /dev/null +++ b/third_party/uv/uv.BUILD @@ -0,0 +1,82 @@ +# Description: +# libuv is a cross-platform asynchronous I/O library. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +cc_library( + name = "uv", + srcs = [ + "src/fs-poll.c", + "src/idna.c", + "src/inet.c", + "src/random.c", + "src/strscpy.c", + "src/threadpool.c", + "src/timer.c", + "src/uv-common.c", + "src/uv-data-getter-setters.c", + "src/version.c", + ] + [ + "src/unix/async.c", + "src/unix/core.c", + "src/unix/dl.c", + "src/unix/fs.c", + "src/unix/getaddrinfo.c", + "src/unix/getnameinfo.c", + "src/unix/loop.c", + "src/unix/loop-watcher.c", + "src/unix/pipe.c", + "src/unix/poll.c", + "src/unix/process.c", + "src/unix/random-devurandom.c", + "src/unix/signal.c", + "src/unix/stream.c", + "src/unix/tcp.c", + "src/unix/thread.c", + "src/unix/tty.c", + "src/unix/udp.c", + ] + select({ + "@platforms//os:osx": [ + "src/unix/bsd-ifaddrs.c", + "src/unix/darwin.c", + "src/unix/darwin-proctitle.c", + "src/unix/fsevents.c", + "src/unix/kqueue.c", + "src/unix/proctitle.c", + "src/unix/random-getentropy.c", + ], + }), + # TODO: Add Linux, etc. as in https://github.com/libuv/libuv/blob/v1.38.0/CMakeLists.txt. + hdrs = [ + "include/uv.h", + "src/heap-inl.h", + "src/idna.h", + "src/queue.h", + "src/strscpy.h", + "src/unix/atomic-ops.h", + "src/unix/internal.h", + "src/unix/spinlock.h", + "src/uv-common.h", + ] + select({ + "@platforms//os:osx": [ + "src/unix/darwin-stub.h", + ], + }) + glob(["include/uv/*.h"]), + copts = [ + "-fexceptions", + "-Wno-unused-variable", + ], + includes = [ + "include", + "src", + ], + textual_hdrs = [ + "include/uv.h", + ], +) diff --git a/third_party/uv/workspace.bzl b/third_party/uv/workspace.bzl new file mode 100644 index 00000000000000..8d26ab4dcd41b5 --- /dev/null +++ b/third_party/uv/workspace.bzl @@ -0,0 +1,17 @@ +"""Provides the repository macro to import libuv.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports libuv.""" + + UV_VERSION = "v1.38.0" + UV_SHA256 = "71344f62c5020ed3643ad0bcba98ae4d7d6037285923c5416844d7c141a3ff93" + + tf_http_archive( + name = "uv", + sha256 = UV_SHA256, + strip_prefix = "libuv-{version}".format(version = UV_VERSION), + urls = tf_mirror_urls("https://dist.libuv.org/dist/{version}/libuv-{version}.tar.gz".format(version = UV_VERSION)), + build_file = "//third_party/uv:uv.BUILD", + ) diff --git a/third_party/xla/.bazelrc b/third_party/xla/.bazelrc index b94693e05efab8..9e565e91a1b903 100644 --- a/third_party/xla/.bazelrc +++ b/third_party/xla/.bazelrc @@ -219,13 +219,16 @@ build:mkl_aarch64_threadpool -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda +# Default CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # CUDA: This config refers to building CUDA op kernels with clang. build:cuda_clang --config=cuda -# Enable TensorRT optimizations https://developer.nvidia.com/tensorrt -build:cuda_clang --config=tensorrt -build:cuda_clang --action_env=TF_CUDA_CLANG="1" build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --copt=-Qunused-arguments # Select supported compute capabilities (supported graphics cards). # This is the same as the official TensorFlow builds. # See https://developer.nvidia.com/cuda-gpus#compute @@ -234,22 +237,22 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Set lld as the linker. +build:cuda_clang --host_linkopt="-fuse-ld=lld" +build:cuda_clang --host_linkopt="-lm" +build:cuda_clang --linkopt="-fuse-ld=lld" +build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" -build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" -build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc @@ -351,6 +354,13 @@ build:windows --features=archive_param_file build:windows --copt=/d2ReducedOptimizeHugeFunctions build:windows --host_copt=/d2ReducedOptimizeHugeFunctions +# Before VS 2017 15.8, the member "type" would non-conformingly have an +# alignment of only alignof(max_align_t). VS 2017 15.8 was fixed to handle this +# correctly, but the fix inherently changes layout and breaks binary +# compatibility (*only* for uses of aligned_storage with extended alignments). +build:windows --copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE +build:windows --host_copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE + # Enable the runfiles symlink tree on Windows. This makes it possible to build # the pip package on Windows without an intermediate data-file archive, as the # build_pip_package script in its current form (as of Aug 2023) uses the @@ -538,10 +548,6 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.17-clang_config_tensorrt" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" -test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --config=nvcc_clang @@ -566,6 +572,9 @@ build:rbe_win_clang --compiler=clang-cl build:rbe_win_clang --linkopt=/FORCE:MULTIPLE build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE +# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits. +build:rbe_windows_x86_cpu --config=rbe_win_clang + # END TF REMOTE BUILD EXECUTION OPTIONS # TFLite build configs for generic embedded Linux @@ -623,7 +632,6 @@ build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/cla # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS -test:release_linux_base --test_env=LD_LIBRARY_PATH # Give only the list of failed tests at the end of the log test:release_linux_base --test_summary=short @@ -637,7 +645,6 @@ build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. # Note that linux cpu and cuda builds share the same toolchain now. build:release_gpu_linux --config=cuda_clang_official -test:release_gpu_linux --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" # Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think test:release_gpu_linux --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute @@ -668,9 +675,8 @@ build:unsupported_gpu_linux --config=unsupported_cpu_linux build:unsupported_gpu_linux --action_env=TF_CUDA_VERSION="11" build:unsupported_gpu_linux --action_env=TF_CUDNN_VERSION="8" build:unsupported_gpu_linux --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" -build:unsupported_gpu_linux --config=tensorrt build:unsupported_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2" -build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64:/usr/local/tensorrt/lib" +build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64" build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain @@ -780,17 +786,19 @@ test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/ # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. -# CPU PYCPP: +# LINUX CPU PYCPP: test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -# CUDA PYCPP: + +# LINUX CUDA PYCPP: test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -# ARM64 PYCPP + +# LINUX ARM64 PYCPP # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on # Linux x86 so that we can use RBE. Since tests still need to run on the single # host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. @@ -823,6 +831,13 @@ build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow # CROSS-COMPILE MACOS X86 PYCPP build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test +# WINDOWS X86-64 CPU PYCPP +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" +test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only +test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... + # END TF TEST SUITE OPTIONS # START CROSS-COMPILE CONFIGS diff --git a/third_party/xla/.github/workflows/bazel_query.yml b/third_party/xla/.github/workflows/bazel_query.yml new file mode 100644 index 00000000000000..969383fb09062f --- /dev/null +++ b/third_party/xla/.github/workflows/bazel_query.yml @@ -0,0 +1,40 @@ +# Copyright 2024 The OpenXLA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +name: Bazel Query +permissions: + contents: read +on: + pull_request: + +env: + # Have `go install` place binaries in $PATH + GOBIN: "/usr/local/bin" + +jobs: + bazel-query: + runs-on: ubuntu-22.04 + defaults: + run: + shell: bash + timeout-minutes: 2 + steps: + - name: "Checking out repository" + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + - name: "Install bazelisk" + run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 + - name: "Run bazel query //xla/..." + run: bazelisk query //xla/... > /dev/null + - name: "Run bazel query deps(//xla/...)" + run: bazelisk query "deps(//xla/...)" > /dev/null diff --git a/third_party/xla/.github/workflows/buildifier.yml b/third_party/xla/.github/workflows/buildifier.yml index 55140675aa28c3..797b88484a0860 100644 --- a/third_party/xla/.github/workflows/buildifier.yml +++ b/third_party/xla/.github/workflows/buildifier.yml @@ -13,7 +13,8 @@ # limitations under the License. # ============================================================================ name: Buildifier -permissions: read-all +permissions: + contents: read on: pull_request: diff --git a/third_party/xla/.github/workflows/check_contents.yml b/third_party/xla/.github/workflows/check_contents.yml index fd38adfd0adda3..1756b36275051d 100644 --- a/third_party/xla/.github/workflows/check_contents.yml +++ b/third_party/xla/.github/workflows/check_contents.yml @@ -19,7 +19,8 @@ # files once XLA moves out of Tensorflow internally. # TODO(ddunleavy): Update this after METADATA files are consolidated. name: Check Contents -permissions: read-all +permissions: + contents: read on: pull_request: diff --git a/third_party/xla/.github/workflows/clang_format.yml b/third_party/xla/.github/workflows/clang_format.yml index 2701311d047371..e22b67eae0ec86 100644 --- a/third_party/xla/.github/workflows/clang_format.yml +++ b/third_party/xla/.github/workflows/clang_format.yml @@ -14,7 +14,8 @@ # ============================================================================ name: Clang Format -permissions: read-all +permissions: + contents: read on: pull_request: diff --git a/third_party/xla/.kokoro/macos/build.sh b/third_party/xla/.kokoro/macos/build.sh index c3e0c126560afb..1aedf1badf55d2 100644 --- a/third_party/xla/.kokoro/macos/build.sh +++ b/third_party/xla/.kokoro/macos/build.sh @@ -37,32 +37,6 @@ function install_build_env_tools(){ sudo wget --no-verbose -O "/usr/local/bin/bazel" \ "https://github.com/bazelbuild/bazelisk/releases/download/v1.11.0/bazelisk-darwin-amd64" \ && chmod +x "/usr/local/bin/bazel" - - echo "===== Installing Pyenv =====" - # Install pyenv; Set up a virtual environment to control dependencies and their - # versions - git clone --branch v2.3.17 https://github.com/pyenv/pyenv.git /Users/kbuilder/.tf_pyenv - export PYENV_ROOT=/Users/kbuilder/.tf_pyenv - export PATH="$PYENV_ROOT/bin:$PATH" # if `pyenv` is not already on PATH - eval "$(pyenv init --path)" - eval "$(pyenv init -)" - - echo "===== Installing Python =====" - # Install Python and set the local python version - pyenv install -s "${TF_PYENV_VERSION}" - pyenv rehash - pyenv local "${TF_PYENV_VERSION}" - # Do a sanity check to make sure that we using the correct Python version - echo "===== Python version =====" - python --version - # Set up virtual environment and activate it - python -m venv /Users/kbuilder/.tf-venv && source /Users/kbuilder/.tf-venv/bin/activate - - # Setup links to Python. Referenced in ./macos.bazelrc - ln -s /Users/kbuilder/.tf-venv/lib/python* /Users/kbuilder/.tf-venv/lib/python - - echo "===== Upgrading to latest pip =====" - python -m pip install --upgrade pip } # Run the tests under /Volumes/BuildData/ so that we don't run into VM @@ -72,8 +46,6 @@ export TEST_TMPDIR=/Volumes/BuildData/bazel_output install_build_env_tools -python -m pip install numpy==1.21.4 - TARGET_FILTER="-//xla/hlo/experimental/... -//xla/python_api/... -//xla/python/... -//xla/service/gpu/..." TAGS_FILTER="-no_oss,-oss_excluded,-gpu,-no_mac,-nomac,-mac_excluded,-requires-gpu-nvidia,-requires-gpu-amd" diff --git a/third_party/xla/README.md b/third_party/xla/README.md index be0325eefc03ba..1a6d70a29cde25 100644 --- a/third_party/xla/README.md +++ b/third_party/xla/README.md @@ -7,6 +7,11 @@ The XLA compiler takes models from popular ML frameworks such as PyTorch, TensorFlow, and JAX, and optimizes them for high-performance execution across different hardware platforms including GPUs, CPUs, and ML accelerators. + + + OpenXLA Ecosystem + + ## Get started If you want to use XLA to compile your ML project, refer to the corresponding diff --git a/third_party/xla/WORKSPACE b/third_party/xla/WORKSPACE index 9d046e22949091..a18ebde79da786 100644 --- a/third_party/xla/WORKSPACE +++ b/third_party/xla/WORKSPACE @@ -52,3 +52,50 @@ xla_workspace1() load(":workspace0.bzl", "xla_workspace0") xla_workspace0() + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "@local_tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "@local_tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/third_party/xla/build_tools/build.py b/third_party/xla/build_tools/build.py index bbaab695d8eb9d..ec989e21737e0e 100755 --- a/third_party/xla/build_tools/build.py +++ b/third_party/xla/build_tools/build.py @@ -23,7 +23,6 @@ The script also assumes that the working directory never changes modulo `cd`ing into the repo that should be built (mostly `github/xla`, but also JAX and TF). """ -import contextlib import dataclasses import enum import logging @@ -33,8 +32,8 @@ import time from typing import Any, Dict, List, Tuple -_KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} +_CONTAINER_NAME = "xla_ci" # TODO(ddunleavy): move this to the bazelrc _DEFAULT_BAZEL_OPTIONS = dict( test_output="errors", @@ -54,7 +53,7 @@ tty=True, volume="./github:/github", ) - +_KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} _XLA_DEFAULT_TARGET_PATTERNS = ( "//xla/...", "//build_tools/...", @@ -91,10 +90,37 @@ class BuildType(enum.Enum): @dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) -class DockerImage: - """Class representing a docker image.""" +class Build: + """Class representing a build of XLA.""" + type_: BuildType + repo: str image_url: str + target_patterns: Tuple[str, ...] + configs: Tuple[str, ...] = () + build_tag_filters: Tuple[str, ...] = () + test_tag_filters: Tuple[str, ...] = () + action_env: Dict[str, Any] = dataclasses.field(default_factory=dict) + test_env: Dict[str, Any] = dataclasses.field(default_factory=dict) + options: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def bazel_test_command(self) -> List[str]: + """Returns a bazel test command for this build. + + Returns: List of command line arguments + """ + options = _dict_to_cli_options(self.options) + configs = [f"--config={config}" for config in self.configs] + build_tag_filters = ( + f"--build_tag_filters={','.join(self.build_tag_filters)}" + ) + test_tag_filters = f"--test_tag_filters={','.join(self.test_tag_filters)}" + action_env = [f"--action_env={k}={v}" for k, v in self.action_env.items()] + test_env = [f"--test_env={k}={v}" for k, v in self.test_env.items()] + + tag_filters = [build_tag_filters, test_tag_filters] + all_options = tag_filters + configs + action_env + test_env + options + return ["bazel", "test", *all_options, "--", *self.target_patterns] def _pull_docker_image_with_retries(self, retries=3) -> None: """Pulls docker image with retries to avoid transient rate limit errors.""" @@ -112,10 +138,9 @@ def _pull_docker_image_with_retries(self, retries=3) -> None: # TODO(ddunleavy): get sha # _write_to_sponge_config("TF_INFO_DOCKER_SHA", sha) - @contextlib.contextmanager - def pull_and_run( + def pull_and_run_docker_image( self, - name: str = "xla_ci", + name: str, command: Tuple[str, ...] = ("bash",), **kwargs: Any, ): @@ -126,67 +151,19 @@ def pull_and_run( command: Command given to `docker run`, e.g. `bash` **kwargs: Extra options passed to `docker run`. - Yields: - A function that accepts a command as a list of args, and runs those on the - corresponding docker container. It shouldn't be used outside the `with` - block, as the container will be stopped after the end of the block. - - This manages pulling, starting, and stopping the container. Example usage: - ``` - with image.pull_and_run() as docker_exec: - docker_exec(["command", "--with", "--flags"]) - ``` + Returns: + None. """ - try: - self._pull_docker_image_with_retries() - options = _dict_to_cli_options(kwargs) - sh([ - "docker", - "run", - "--name", - name, - *options, - self.image_url, - *command, - ]) - docker_exec = lambda args: sh(["docker", "exec", name, *args]) - yield docker_exec - finally: - sh(["docker", "stop", name]) - - -@dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) -class Build: - """Class representing a build of XLA.""" + self._pull_docker_image_with_retries() - type_: BuildType - repo: str - docker_image: DockerImage - target_patterns: Tuple[str, ...] - configs: Tuple[str, ...] = () - build_tag_filters: Tuple[str, ...] = () - test_tag_filters: Tuple[str, ...] = () - action_env: Dict[str, Any] = dataclasses.field(default_factory=dict) - test_env: Dict[str, Any] = dataclasses.field(default_factory=dict) - options: Dict[str, Any] = dataclasses.field(default_factory=dict) + assert "workdir" not in kwargs + _, repo_name = self.repo.split("/") + workdir = f"/github/{repo_name}" - def bazel_test_command(self) -> List[str]: - """Returns a bazel test command for this build. + options = ["--name", name, "--workdir", workdir] + options += _dict_to_cli_options(kwargs) - Returns: List of command line arguments - """ - options = _dict_to_cli_options(self.options) - configs = [f"--config={config}" for config in self.configs] - build_tag_filters = ( - f"--build_tag_filters={','.join(self.build_tag_filters)}" - ) - test_tag_filters = f"--test_tag_filters={','.join(self.test_tag_filters)}" - action_env = [f"--action_env={k}={v}" for k, v in self.action_env.items()] - test_env = [f"--test_env={k}={v}" for k, v in self.test_env.items()] - - tag_filters = [build_tag_filters, test_tag_filters] - all_options = tag_filters + configs + action_env + test_env + options - return ["bazel", "test", *all_options, "--", *self.target_patterns] + sh(["docker", "run", *options, self.image_url, *command]) def _tag_filters_for_compute_capability( @@ -202,18 +179,12 @@ def _tag_filters_for_compute_capability( return tag_filters -_DEFAULT_IMAGE = DockerImage( - image_url="gcr.io/tensorflow-sigs/build:latest-python3.11", -) +_DEFAULT_IMAGE = "gcr.io/tensorflow-sigs/build:latest-python3.11" # TODO(b/338885148): Remove this once the TF containers have cuDNN 9 -_CUDNN_9_IMAGE = DockerImage( - image_url="gcr.io/tensorflow-sigs/build@sha256:0a9728e258d7e0e5830d1960a65968ffdc1d138af5441e30948918e0d50ab2c7", -) +_CUDNN_9_IMAGE = "gcr.io/tensorflow-sigs/build@sha256:0a9728e258d7e0e5830d1960a65968ffdc1d138af5441e30948918e0d50ab2c7" -_ARM64_JAX_MULTI_PYTHON_IMAGE = DockerImage( - image_url="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-latest-multi-python", -) +_ARM64_JAX_MULTI_PYTHON_IMAGE = "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build-arm64:jax-latest-multi-python" def nvidia_gpu_build_with_compute_capability( @@ -223,7 +194,7 @@ def nvidia_gpu_build_with_compute_capability( return Build( type_=type_, repo="openxla/xla", - docker_image=_CUDNN_9_IMAGE, + image_url=_CUDNN_9_IMAGE, target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, configs=configs, test_tag_filters=("-no_oss", "requires-gpu-nvidia") + extra_gpu_tags, @@ -245,7 +216,7 @@ def nvidia_gpu_build_with_compute_capability( _CPU_X86_BUILD = Build( type_=BuildType.CPU_X86, repo="openxla/xla", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=("warnings", "nonccl", "rbe_linux_cpu"), target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, build_tag_filters=cpu_x86_tag_filter, @@ -263,7 +234,7 @@ def nvidia_gpu_build_with_compute_capability( _CPU_ARM64_BUILD = Build( type_=BuildType.CPU_ARM64, repo="openxla/xla", - docker_image=_ARM64_JAX_MULTI_PYTHON_IMAGE, + image_url=_ARM64_JAX_MULTI_PYTHON_IMAGE, configs=("warnings", "rbe_cross_compile_linux_arm64_xla", "nonccl"), target_patterns=_XLA_DEFAULT_TARGET_PATTERNS, options={**_DEFAULT_BAZEL_OPTIONS, "build_tests_only": True}, @@ -280,7 +251,7 @@ def nvidia_gpu_build_with_compute_capability( _JAX_CPU_BUILD = Build( type_=BuildType.JAX_CPU, repo="google/jax", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=( "avx_posix", "mkl_open_source_only", @@ -300,7 +271,7 @@ def nvidia_gpu_build_with_compute_capability( _JAX_GPU_BUILD = Build( type_=BuildType.JAX_GPU, repo="google/jax", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=( "avx_posix", "mkl_open_source_only", @@ -323,7 +294,7 @@ def nvidia_gpu_build_with_compute_capability( _TENSORFLOW_CPU_BUILD = Build( type_=BuildType.TENSORFLOW_CPU, repo="tensorflow/tensorflow", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=( "release_cpu_linux", "rbe_linux_cpu", @@ -347,7 +318,7 @@ def nvidia_gpu_build_with_compute_capability( _TENSORFLOW_GPU_BUILD = Build( type_=BuildType.TENSORFLOW_GPU, repo="tensorflow/tensorflow", - docker_image=_DEFAULT_IMAGE, + image_url=_DEFAULT_IMAGE, configs=( "release_gpu_linux", "rbe_linux_cuda", @@ -412,13 +383,24 @@ def main(): "github/xla/.bazelrc", ], ) + sh( + [ + "sed", + "-i", + r"s/8\.9\.7\.29/9.1.1/g", + "github/xla/.bazelrc", + ], + ) sh(["nvidia-smi"]) - with build.docker_image.pull_and_run( - workdir=f"/github/{repo_name}", **_DEFAULT_DOCKER_OPTIONS - ) as docker_exec: - docker_exec(build.bazel_test_command()) - docker_exec(["bazel", "analyze-profile", "profile.json.gz"]) + build.pull_and_run_docker_image( + _CONTAINER_NAME, + **_DEFAULT_DOCKER_OPTIONS, + ) + docker_exec = lambda cmd: sh(["docker", "exec", _CONTAINER_NAME, *cmd]) + docker_exec(build.bazel_test_command()) + docker_exec(["bazel", "analyze-profile", "profile.json.gz"]) + sh(["docker", "stop", _CONTAINER_NAME]) if __name__ == "__main__": diff --git a/third_party/xla/build_tools/configure/BUILD b/third_party/xla/build_tools/configure/BUILD index 6b84ba404c9043..ed518510f5eae3 100644 --- a/third_party/xla/build_tools/configure/BUILD +++ b/third_party/xla/build_tools/configure/BUILD @@ -33,6 +33,7 @@ py_test( data = [ "testdata/clang.bazelrc", "testdata/cuda_clang.bazelrc", + "testdata/default_cuda_clang.bazelrc", "testdata/gcc.bazelrc", "testdata/nvcc_clang.bazelrc", "testdata/nvcc_gcc.bazelrc", diff --git a/third_party/xla/build_tools/configure/configure.py b/third_party/xla/build_tools/configure/configure.py index 39cfd7a01ecbf0..43e0f234d49cfd 100755 --- a/third_party/xla/build_tools/configure/configure.py +++ b/third_party/xla/build_tools/configure/configure.py @@ -27,11 +27,6 @@ the clang in your path. If that isn't the correct clang, you can override like `./configure.py --backend=cpu --clang_path=`. -NOTE(ddunleavy): Lots of these things should probably be outside of configure.py -but are here because of complexity in `cuda_configure.bzl` and the TF bazelrc. -Once XLA has it's own bazelrc, and cuda_configure.bzl is replaced or refactored, -we can probably make this file smaller. - TODO(ddunleavy): add more thorough validation. """ import argparse @@ -45,18 +40,9 @@ import sys from typing import Optional -_REQUIRED_CUDA_LIBRARIES = ["cublas", "cuda", "cudnn"] _DEFAULT_BUILD_AND_TEST_TAG_FILTERS = ("-no_oss",) # Assume we are being invoked from the symlink at the root of the repo _XLA_SRC_ROOT = pathlib.Path(__file__).absolute().parent -_FIND_CUDA_CONFIG = str( - _XLA_SRC_ROOT - / "third_party" - / "tsl" - / "third_party" - / "gpus" - / "find_cuda_config.py" -) _XLA_BAZELRC_NAME = "xla_configure.bazelrc" _KW_ONLY_IF_PYTHON310 = {"kw_only": True} if sys.version_info >= (3, 10) else {} @@ -224,11 +210,12 @@ class DiscoverablePathsAndVersions: ld_library_path: Optional[str] = None # CUDA specific - cublas_version: Optional[str] = None - cuda_toolkit_path: Optional[str] = None + cuda_version: Optional[str] = None cuda_compute_capabilities: Optional[list[str]] = None cudnn_version: Optional[str] = None - nccl_version: Optional[str] = None + local_cuda_path: Optional[str] = None + local_cudnn_path: Optional[str] = None + local_nccl_path: Optional[str] = None def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): """Gets paths and versions as needed by the config. @@ -247,7 +234,7 @@ def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): ) # Notably, we don't use `_find_executable_or_die` for lld, as it changes - # which commands it accepts based on it's name! ld.lld is symlinked to a + # which commands it accepts based on its name! ld.lld is symlinked to a # different executable just called lld, which should not be invoked # directly. self.lld_path = self.lld_path or shutil.which("ld.lld") @@ -261,64 +248,6 @@ def get_relevant_paths_and_versions(self, config: "XLAConfigOptions"): if not self.cuda_compute_capabilities: self.cuda_compute_capabilities = _get_cuda_compute_capabilities_or_die() - self._get_cuda_libraries_paths_and_versions_if_needed(config) - - def _get_cuda_libraries_paths_and_versions_if_needed( - self, config: "XLAConfigOptions" - ): - """Gets cuda paths and versions if user left any unspecified. - - This uses `find_cuda_config.py` to find versions for all libraries in - `_REQUIRED_CUDA_LIBRARIES`. - - Args: - config: config that determines which libraries should be found. - """ - should_find_nccl = config.using_nccl and self.nccl_version is None - any_cuda_config_unset = any([ - self.cublas_version is None, - self.cuda_toolkit_path is None, - self.cudnn_version is None, - should_find_nccl, - ]) - - maybe_nccl = ["nccl"] if should_find_nccl else [] - - if any_cuda_config_unset: - logging.info( - "Some CUDA config versions and paths were not provided, " - "so trying to find them using find_cuda_config.py" - ) - try: - find_cuda_config_proc = subprocess.run( - [ - sys.executable, - _FIND_CUDA_CONFIG, - *_REQUIRED_CUDA_LIBRARIES, - *maybe_nccl, - ], - capture_output=True, - check=True, - text=True, - ) - except subprocess.CalledProcessError as e: - logging.info("Command %s failed. Is CUDA installed?", e.cmd) - logging.info("Dumping %s ouptut:\n %s", e.cmd, e.output) - raise e - - cuda_config = dict( - tuple(line.split(": ")) - for line in find_cuda_config_proc.stdout.strip().split("\n") - ) - - self.cublas_version = self.cublas_version or cuda_config["cublas_version"] - self.cuda_toolkit_path = ( - self.cuda_toolkit_path or cuda_config["cuda_toolkit_path"] - ) - self.cudnn_version = self.cudnn_version or cuda_config["cudnn_version"] - if should_find_nccl: - self.nccl_version = self.nccl_version or cuda_config["nccl_version"] - @dataclasses.dataclass(frozen=True, **_KW_ONLY_IF_PYTHON310) class XLAConfigOptions: @@ -333,7 +262,6 @@ class XLAConfigOptions: # CUDA specific cuda_compiler: CudaCompiler using_nccl: bool - using_tensorrt: bool def to_bazelrc_lines( self, @@ -392,19 +320,31 @@ def to_bazelrc_lines( ) # Lines needed for CUDA backend regardless of CUDA/host compiler + if dpav.cuda_version: + rc.append( + f"build:cuda --repo_env HERMETIC_CUDA_VERSION={dpav.cuda_version}" + ) rc.append( - f"build --action_env CUDA_TOOLKIT_PATH={dpav.cuda_toolkit_path}" - ) - rc.append(f"build --action_env TF_CUBLAS_VERSION={dpav.cublas_version}") - rc.append( - "build --action_env" - f" TF_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" + "build:cuda --repo_env" + f" HERMETIC_CUDA_COMPUTE_CAPABILITIES={','.join(dpav.cuda_compute_capabilities)}" ) - rc.append(f"build --action_env TF_CUDNN_VERSION={dpav.cudnn_version}") - rc.append(f"build --repo_env TF_NEED_TENSORRT={int(self.using_tensorrt)}") - if self.using_nccl: - rc.append(f"build --action_env TF_NCCL_VERSION={dpav.nccl_version}") - else: + if dpav.cudnn_version: + rc.append( + f"build:cuda --repo_env HERMETIC_CUDNN_VERSION={dpav.cudnn_version}" + ) + if dpav.local_cuda_path: + rc.append( + f"build:cuda --repo_env LOCAL_CUDA_PATH={dpav.local_cuda_path}" + ) + if dpav.local_cudnn_path: + rc.append( + f"build:cuda --repo_env LOCAL_CUDNN_PATH={dpav.local_cudnn_path}" + ) + if dpav.local_nccl_path: + rc.append( + f"build:cuda --repo_env LOCAL_NCCL_PATH={dpav.local_nccl_path}" + ) + if not self.using_nccl: rc.append("build --config nonccl") elif self.backend == Backend.ROCM: pass @@ -476,7 +416,6 @@ def _parse_args(): default="-Wno-sign-compare", ) parser.add_argument("--nccl", action="store_true") - parser.add_argument("--tensorrt", action="store_true") # Path and version overrides path_help = "Optional: will be found on PATH if possible." @@ -492,13 +431,35 @@ def _parse_args(): parser.add_argument("--lld_path", help=path_help) # CUDA specific - find_cuda_config_help = ( - "Optional: will be found using `find_cuda_config.py` if flag is not set." + parser.add_argument( + "--cuda_version", + help="Optional: CUDA will be downloaded by Bazel if the flag is set", + ) + parser.add_argument( + "--cudnn_version", + help="Optional: CUDNN will be downloaded by Bazel if the flag is set", + ) + parser.add_argument( + "--local_cuda_path", + help=( + "Optional: Local CUDA dir will be used in dependencies if the flag" + " is set" + ), + ) + parser.add_argument( + "--local_cudnn_path", + help=( + "Optional: Local CUDNN dir will be used in dependencies if the flag" + " is set" + ), + ) + parser.add_argument( + "--local_nccl_path", + help=( + "Optional: Local NCCL dir will be used in dependencies if the flag" + " is set" + ), ) - parser.add_argument("--cublas_version", help=find_cuda_config_help) - parser.add_argument("--cuda_toolkit_path", help=find_cuda_config_help) - parser.add_argument("--cudnn_version", help=find_cuda_config_help) - parser.add_argument("--nccl_version", help=find_cuda_config_help) return parser.parse_args() @@ -518,7 +479,6 @@ def main(): python_bin_path=args.python_bin_path, compiler_options=args.compiler_options, using_nccl=args.nccl, - using_tensorrt=args.tensorrt, ) bazelrc_lines = config.to_bazelrc_lines( @@ -527,11 +487,12 @@ def main(): gcc_path=args.gcc_path, lld_path=args.lld_path, ld_library_path=args.ld_library_path, - cublas_version=args.cublas_version, - cuda_compute_capabilities=args.cuda_compute_capabilities, - cuda_toolkit_path=args.cuda_toolkit_path, + cuda_version=args.cuda_version, cudnn_version=args.cudnn_version, - nccl_version=args.nccl_version, + cuda_compute_capabilities=args.cuda_compute_capabilities, + local_cuda_path=args.local_cuda_path, + local_cudnn_path=args.local_cudnn_path, + local_nccl_path=args.local_nccl_path, ) ) diff --git a/third_party/xla/build_tools/configure/configure_test.py b/third_party/xla/build_tools/configure/configure_test.py index e29e718b78547d..8457ff40aea3ee 100644 --- a/third_party/xla/build_tools/configure/configure_test.py +++ b/third_party/xla/build_tools/configure/configure_test.py @@ -34,12 +34,20 @@ # CUDA specific paths and versions _CUDA_SPECIFIC_PATHS_AND_VERSIONS = { - "cublas_version": "12.3", - "cuda_toolkit_path": "/usr/local/cuda-12.2", + "cuda_version": '"12.1.1"', "cuda_compute_capabilities": ["7.5"], - "cudnn_version": "8", + "cudnn_version": '"8.6"', + "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", +} +_CUDA_COMPUTE_CAPABILITIES_AND_LD_LIBRARY_PATH = { + "cuda_compute_capabilities": [ + "sm_50", + "sm_60", + "sm_70", + "sm_80", + "compute_90", + ], "ld_library_path": "/usr/local/nvidia/lib:/usr/local/nvidia/lib64", - "nccl_version": "2", } @@ -66,6 +74,11 @@ def setUpClass(cls): with (testdata / "cuda_clang.bazelrc").open() as f: cls.cuda_clang_bazelrc_lines = [line.strip() for line in f.readlines()] + with (testdata / "default_cuda_clang.bazelrc").open() as f: + cls.default_cuda_clang_bazelrc_lines = [ + line.strip() for line in f.readlines() + ] + with (testdata / "nvcc_clang.bazelrc").open() as f: cls.nvcc_clang_bazelrc_lines = [line.strip() for line in f.readlines()] @@ -85,7 +98,6 @@ def test_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( @@ -107,7 +119,6 @@ def test_gcc_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( @@ -128,7 +139,6 @@ def test_cuda_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.CLANG, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( @@ -141,6 +151,27 @@ def test_cuda_clang_bazelrc(self): self.assertEqual(bazelrc_lines, self.cuda_clang_bazelrc_lines) + def test_default_cuda_clang_bazelrc(self): + config = XLAConfigOptions( + backend=Backend.CUDA, + os=OS.LINUX, + python_bin_path=_PYTHON_BIN_PATH, + host_compiler=HostCompiler.CLANG, + compiler_options=list(_COMPILER_OPTIONS), + cuda_compiler=CudaCompiler.CLANG, + using_nccl=False, + ) + + bazelrc_lines = config.to_bazelrc_lines( + DiscoverablePathsAndVersions( + clang_path=_CLANG_PATH, + clang_major_version=17, + **_CUDA_COMPUTE_CAPABILITIES_AND_LD_LIBRARY_PATH, + ) + ) + + self.assertEqual(bazelrc_lines, self.default_cuda_clang_bazelrc_lines) + def test_nvcc_clang_bazelrc(self): config = XLAConfigOptions( backend=Backend.CUDA, @@ -150,7 +181,6 @@ def test_nvcc_clang_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( @@ -172,7 +202,6 @@ def test_nvcc_gcc_bazelrc(self): compiler_options=list(_COMPILER_OPTIONS), cuda_compiler=CudaCompiler.NVCC, using_nccl=False, - using_tensorrt=False, ) bazelrc_lines = config.to_bazelrc_lines( diff --git a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc index a6e7a423bfc490..502bc8541c1285 100644 --- a/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/cuda_clang.bazelrc @@ -3,11 +3,9 @@ build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang build --config cuda_clang build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 -build --repo_env TF_NEED_TENSORRT=0 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc new file mode 100644 index 00000000000000..4623f6f52073fa --- /dev/null +++ b/third_party/xla/build_tools/configure/testdata/default_cuda_clang.bazelrc @@ -0,0 +1,19 @@ +build --action_env CLANG_COMPILER_PATH=/usr/lib/llvm-18/bin/clang +build --repo_env CC=/usr/lib/llvm-18/bin/clang +build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang +build --config cuda_clang +build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 +build --config nonccl +build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 +build --action_env PYTHON_BIN_PATH=/usr/bin/python3 +build --python_path /usr/bin/python3 +test --test_env LD_LIBRARY_PATH +test --test_size_filters small,medium +build --copt -Wno-sign-compare +build --copt -Wno-error=unused-command-line-argument +build --copt -Wno-gnu-offsetof-extensions +build --build_tag_filters -no_oss +build --test_tag_filters -no_oss +test --build_tag_filters -no_oss +test --test_tag_filters -no_oss diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc index e147dbd687b118..8cd19224698311 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_clang.bazelrc @@ -3,11 +3,9 @@ build --repo_env CC=/usr/lib/llvm-18/bin/clang build --repo_env BAZEL_COMPILER=/usr/lib/llvm-18/bin/clang build --config nvcc_clang build --action_env CLANG_CUDA_COMPILER_PATH=/usr/lib/llvm-18/bin/clang -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 -build --repo_env TF_NEED_TENSORRT=0 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc index 863209697362de..be90a87545368b 100644 --- a/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc +++ b/third_party/xla/build_tools/configure/testdata/nvcc_gcc.bazelrc @@ -1,10 +1,8 @@ build --action_env GCC_HOST_COMPILER_PATH=/usr/bin/gcc build --config cuda -build --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-12.2 -build --action_env TF_CUBLAS_VERSION=12.3 -build --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 -build --action_env TF_CUDNN_VERSION=8 -build --repo_env TF_NEED_TENSORRT=0 +build:cuda --repo_env HERMETIC_CUDA_VERSION="12.1.1" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES=7.5 +build:cuda --repo_env HERMETIC_CUDNN_VERSION="8.6" build --config nonccl build --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 build --action_env PYTHON_BIN_PATH=/usr/bin/python3 diff --git a/third_party/xla/build_tools/rocm/run_xla.sh b/third_party/xla/build_tools/rocm/run_xla.sh index 22c3f6551dce36..d7eee422ec01db 100755 --- a/third_party/xla/build_tools/rocm/run_xla.sh +++ b/third_party/xla/build_tools/rocm/run_xla.sh @@ -41,7 +41,7 @@ if [[ -n $1 ]]; then ROCM_INSTALL_DIR=$1 else if [[ -z "${ROCM_PATH}" ]]; then - ROCM_INSTALL_DIR=/opt/rocm-6.0.2 + ROCM_INSTALL_DIR=/opt/rocm-6.2.0 else ROCM_INSTALL_DIR=$ROCM_PATH fi diff --git a/third_party/xla/docs/build_from_source.md b/third_party/xla/docs/build_from_source.md index c273f7f3cdf8c0..8b30f9995d08e3 100644 --- a/third_party/xla/docs/build_from_source.md +++ b/third_party/xla/docs/build_from_source.md @@ -65,12 +65,11 @@ docker exec xla_gpu ./configure.py --backend=CUDA docker exec xla_gpu bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` -If you want to build XLA targets with GPU support without Docker you need to -install the following additional dependencies: -[`cuda-12.3`](https://developer.nvidia.com/cuda-downloads), -[`cuDNN-8.9`](https://developer.nvidia.com/cudnn). +For more details regarding +[TensorFlow's GPU docker images you can check out this document.](https://www.tensorflow.org/install/source#gpu_support_3) -Then configure and build targets using the following commands: +You can build XLA targets with GPU support without Docker as well. Configure and +build targets using the following commands: ``` ./configure.py --backend=CUDA @@ -79,4 +78,4 @@ bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` For more details regarding -[TensorFlow's GPU docker images you can check out this document.](https://www.tensorflow.org/install/source#gpu_support_3) +[hermetic CUDA you can check out this document.](docs/hermetic_cuda.md) diff --git a/third_party/xla/docs/custom_call.md b/third_party/xla/docs/custom_call.md index eb97ad7da79de8..2471df68331057 100644 --- a/third_party/xla/docs/custom_call.md +++ b/third_party/xla/docs/custom_call.md @@ -1,4 +1,4 @@ -# XLA custom calls +# XLA Custom Calls This document describes how to write and use XLA custom calls using XLA FFI library. Custom call is a mechanism to describe an external "operation" in the @@ -23,6 +23,269 @@ hides all the low level details of underlying C APIs from the end user. > and custom call target references or to use C-style namespacing directly in > the function name. +## JAX + XLA Custom Calls + +See [JAX documentation](https://jax.readthedocs.io/en/latest/ffi.html) for +end to end examples of integrating custom calls and XLA FFI with JAX. + +## XLA FFI Binding + +XLA FFI binding is a compile-time specification of the custom call signature: +custom call arguments, attributes and their types, and additional parameters +passed via the execution context (i.e., gpu stream for GPU backend). XLA FFI +finding can be bound to any C++ callable (function pointer, lambda, etc.) with +compatible `operator()` signature. Constructed handler decodes XLA FFI call +frame (defined by the stable C API), type check all parameters, and forward +decoded results to the user-defined callback. + +XLA FFI binding heavily relies on template metaprogramming to be be able to +compile constructed handler to the most efficient machine code. Run time +overheads are in order of a couple of nanoseconds for each custom call +parameter. + +XLA FFI customization points implemented as template specializations, and +users can define how to decode their custom types, i.e., it is possible +to define custom decoding for user-defined `enum class` types. + +### Returning Errors From Custom Calls + +Custom call implementations must return `xla::ffi::Error` value to signal +success or error to XLA runtime. It is similar to `absl::Status`, and has +the same set of error codes. We do not use `absl::Status` because it does +not have a stable ABI and it would be unsafe to pass it between dynamically +loaded custom call library, and XLA itself. + +```c++ +// Handler that always returns an error. +auto always_error = Ffi::Bind().To( + []() { return Error(ErrorCode::kInternal, "Oops!"); }); + +// Handler that always returns a success. +auto always_success = Ffi::Bind().To( + []() { return Error::Success(); }); + +``` + +### Buffer Arguments And Results + +XLA uses destination passing style for results: custom calls (or any other XLA +operations for that matter) do not allocate memory for results, and instead +write into destinations passed by XLA runtime. XLA uses static buffer +assignment, and allocates buffers for all values based on their live ranges at +compile time. + +Results passed to FFI handlers wrapped into a `Result` template, that +has a pointer-like semantics: `operator->` gives access to the underlying +parameter. + +`AnyBuffer` arguments and results gives access to custom call buffer parameters +of any data type. This is useful when custom call has a generic implementation +that works for multiple data types, and custom call implementation does run time +dispatching based on data type. `AnyBuffer` gives access to the buffer data +type, dimensions, and a pointer to the buffer itself. + +```mlir +%0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + api_version = 4 : i32 +} : (tensor<2x2xf32>) -> tensor<2x2xf32> +``` + + +```c++ +// Buffers of any rank and data type. +auto handler = Ffi::Bind().Arg().Ret().To( + [](AnyBuffer arg, Result res) -> Error { + void* arg_data = arg.untyped_data(); + void* res_data = res->untyped_data(); + return Error::Success(); + }); +``` + +### Constrained Buffer Arguments And Results + +`Buffer` allows to add constraints on the buffer data type and rank, and they +will be automatically checked by the handler and return an error to XLA runtime, +if run time arguments do not match the FFI handler signature. + +```c++ +// Buffers of any rank and F32 data type. +auto handler = Ffi::Bind().Arg>().Ret>().To( + [](Buffer arg, Result> res) -> Error { + float* arg_data = arg.typed_data(); + float* res_data = res->typed_data(); + return Error::Success(); + }); +``` + +```c++ +// Buffers of rank 2 and F32 data type. +auto handler = Ffi::Bind().Arg>().Ret>().To( + [](BufferR2 arg, Result> res) -> Error { + float* arg_data = arg.typed_data(); + float* res_data = res->typed_data(); + return Error::Success(); + }); +``` + +### Variadic Arguments And Results + +If the number of arguments and result can be different in different instances of +a custom call, they can be decoded at run time using `RemainingArgs` and +`RemainingRets`. + +``` +auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To( + [](RemainingArgs args, RemainingRets results) -> Error { + ErrorOr arg = args.get(0); + ErrorOr> res = results.get(0); + + if (!arg.has_value()) { + return Error(ErrorCode::kInternal, arg.error()); + } + + if (!res.has_value()) { + return Error(ErrorCode::kInternal, res.error()); + } + + return Error::Success(); + }); +``` + +Variadic arguments and results can be declared after regular arguments and +results, however binding regular arguments and results after variadic one is +illegal. + +```c++ +auto handler = + Ffi::Bind() + .Arg() + .RemainingArgs() + .Ret() + .RemainingRets() + .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret, + RemainingRets results) -> Error { return Error::Success(); }); +``` + +### Attributes + +XLA FFI supports automatic decoding of `mlir::DictionaryAttr` passed as a +`custom_call` `backend_config` into FFI handler arguments. + +Note: See [stablehlo RFC](https://github.com/openxla/stablehlo/blob/main/rfcs/20240312-standardize-customcallop.md) +for details, and `stablehlo.custom_call` operation specification. + +```mlir +%0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + backend_config= { + i32 = 42 : i32, + str = "string" + }, + api_version = 4 : i32 +} : (tensor) -> tensor +``` + +In this example custom call has a single buffer argument and two attributes, and +XLA FFI can automatically decode them and pass to the user-defined callable. + +```c++ +auto handler = Ffi::Bind() + .Arg>() + .Attr("i32") + .Attr("str") + .To([](BufferR0 buffer, int32_t i32, std::string_view str) { + return Error::Success(); + }); +``` + +### User-Defined Enum Attributes + +XLA FFI can automatically decode integral MLIR attributes into user-defined +enums. Enum class must have the same underlying integral type, and decoding +has to be explicitly registered with XLA FFI. + + +```mlir +%0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + backend_config= { + command = 0 : i32 + }, + api_version = 4 : i32 +} : (tensor) -> tensor +``` + +```c++ +enum class Command : int32_t { + kAdd = 0, + kMul = 1, +}; + +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command); + +auto handler = Ffi::Bind().Attr("command").To( + [](Command command) -> Error { return Error::Success(); }); +``` + +### Binding All Custom Call Attributes + +It is possible to get access to all custom call attributes as a dictionary +and lazily decode only the attributes that are needed at run time. + +```c++ +auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error { + ErrorOr i32 = attrs.get("i32"); + return Error::Success(); +}); +``` + +### User-defined Struct Attributes + +XLA FFI can decode dictionary attributes into user-defined structs. + +```mlir +%0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + backend_config= { + range = { lo = 0 : i64, hi = 42 : i64 } + }, + api_version = 4 : i32 +} : (tensor) -> tensor +``` + +In example above `range` is an `mlir::DictionaryAttr` attribute, and instead +of accessing dictionary fields by name, it can be automatically decoded as +a C++ struct. Decoding has to be explicitly registered with a +`XLA_FFI_REGISTER_STRUCT_ATTR_DECODING` macro (behind the scene it defines +a template specialization in `::xla::ffi` namespace, thus macro must be added to +the global namespace). + +```c++ +struct Range { + int64_t lo; + int64_t hi; +}; + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember("i64"), + StructMember("i64")); + +auto handler = Ffi::Bind().Attr("range").To([](Range range) -> Error{ + return Error::Success(); +}); +``` + +Custom attributes can be loaded from a dictionary, just like any other +attribute. In example below, all custom call attributes decoded as a +`Dictionary`, and a `range` can be accessed by name. + +```c++ +auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error { + ErrorOr range = attrs.get("range"); + return Error::Success(); +}); +``` + ## Create a custom call on CPU You can create an HLO instruction that represents a custom call via XLA's client diff --git a/third_party/xla/docs/determinism.md b/third_party/xla/docs/determinism.md index d8cd934e5cb1fc..09a1e4fba2241a 100644 --- a/third_party/xla/docs/determinism.md +++ b/third_party/xla/docs/determinism.md @@ -8,6 +8,10 @@ once and avoid it in subsequent compilations. Otherwise due to fluctuations in measurements different kernels can be picked as the fastest ones in different compilation runs. +`--xla_gpu_require_complete_aot_autotune_results` can be used to ensure that no +autotuning happens on repeated compilations - they either reuse compatible +results of previous runs or fail. + ## Execution Programs compiled by XLA can be non-deterministic on operations like scatter, diff --git a/third_party/xla/docs/developer_guide.md b/third_party/xla/docs/developer_guide.md index 53b3efcd8cab5c..b736309b7fbc59 100644 --- a/third_party/xla/docs/developer_guide.md +++ b/third_party/xla/docs/developer_guide.md @@ -64,6 +64,16 @@ docker exec xla ./configure.py --backend=CUDA docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/... ``` +**NB:** please note that with hermetic CUDA rules, you don't have to build XLA +in Docker. You can build XLA for GPU on your machine without GPUs and without +NVIDIA driver installed: + +```sh +./configure.py --backend=CUDA + +bazel build --test_output=all --spawn_strategy=sandboxed //xla/... +``` + Your first build will take quite a while because it has to build the entire stack, including XLA, MLIR, and StableHLO. diff --git a/third_party/xla/docs/hermetic_cuda.md b/third_party/xla/docs/hermetic_cuda.md new file mode 100644 index 00000000000000..18cc228d743461 --- /dev/null +++ b/third_party/xla/docs/hermetic_cuda.md @@ -0,0 +1,544 @@ +# Hermetic CUDA overview + +Hermetic CUDA uses a specific downloadable version of CUDA instead of the user’s +locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL distributions, +and then use CUDA libraries and tools as dependencies in various Bazel targets. +This enables more reproducible builds for Google ML projects and supported CUDA +versions. + +## Supported hermetic CUDA, CUDNN versions + +The supported CUDA versions are specified in `CUDA_REDIST_JSON_DICT` +dictionary, +[third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + +The supported CUDNN versions are specified in `CUDNN_REDIST_JSON_DICT` +dictionary, +[third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + +The `.bazelrc` files of individual projects have `HERMETIC_CUDA_VERSION`, +`HERMETIC_CUDNN_VERSION` environment variables set to the versions used by +default when `--config=cuda` is specified in Bazel command options. + +## Environment variables controlling the hermetic CUDA/CUDNN versions + +`HERMETIC_CUDA_VERSION` environment variable should consist of major, minor and +patch CUDA version, e.g. `12.3.2`. +`HERMETIC_CUDNN_VERSION` environment variable should consist of major, minor and +patch CUDNN version, e.g. `9.1.1`. + +Three ways to set the environment variables for Bazel commands: + +``` +# Add an entry to your `.bazelrc` file +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + +# OR pass it directly to your specific build command +bazel build --config=cuda \ +--repo_env=HERMETIC_CUDA_VERSION="12.3.2" \ +--repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + +# OR set the environment variable globally in your shell: +export HERMETIC_CUDA_VERSION="12.3.2" +export LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +export HERMETIC_CUDNN_VERSION="9.1.1" +``` + +If `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` are not present, the +hermetic CUDA/CUDNN repository rules will look up `TF_CUDA_VERSION` and +`TF_CUDNN_VERSION` environment variables values. This is made for the backward +compatibility with non-hermetic CUDA/CUDNN repository rules. + +The mapping between CUDA version and NCCL distribution version to be downloaded +is specified in [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + +## Upgrade hermetic CUDA/CUDNN version +1. Create and submit a pull request with updated `CUDA_REDIST_JSON_DICT`, + `CUDA_REDIST_JSON_DICT` dictionaries in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl). + + Update `CUDA_NCCL_WHEELS` in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + if needed. + + Update `REDIST_VERSIONS_TO_BUILD_TEMPLATES` in + [third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl) + if needed. + +2. For RBE executions: update `TF_CUDA_VERSION` and/or `TF_CUDNN_VERSION` in + [toolchains/remote_config/rbe_config.bzl](https://github.com/openxla/xla/blob/main/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl). + +3. For RBE executions: update `cuda_version`, `cudnn_version`, `TF_CUDA_VERSION` + and `TF_CUDNN_VERSION` in + [toolchains/remote_config/configs.bzl](https://github.com/openxla/xla/blob/main/tools/toolchains/remote_config/configs.bzl). + +4. For each Google ML project create a separate pull request with updated + `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` in `.bazelrc` file. + + The PR presubmit job executions will launch bazel tests and download hermetic + CUDA/CUDNN distributions. Verify that the presubmit jobs passed before + submitting the PR. + +## Pointing to CUDA/CUDNN/NCCL redistributions on local file system + +You can use the local CUDA/CUDNN/NCCL dirs as a source of redistributions. The following additional environment variables are required: + +``` +LOCAL_CUDA_PATH +LOCAL_CUDNN_PATH +LOCAL_NCCL_PATH +``` + +Example: + +``` +# Add an entry to your `.bazelrc` file +build:cuda --repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" +build:cuda --repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +build:cuda --repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" + +# OR pass it directly to your specific build command +bazel build --config=cuda \ +--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ +--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ +--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" + +# OR set the environment variable globally in your shell: +export LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" +export LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" +export LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" +``` + +The structure of the folders inside CUDA dir should be the following (as if the archived redistributions were unpacked into one place): + +``` +/ + include/ + bin/ + lib/ + nvvm/ +``` + +The structure of the folders inside CUDNN dir should be the following: + +``` + + include/ + lib/ +``` + +The structure of the folders inside NCCL dir should be the following: + +``` + + include/ + lib/ +``` + +## Custom CUDA/CUDNN archives and NCCL wheels + +There are three options that allow usage of custom CUDA/CUDNN distributions. + +### Custom CUDA/CUDNN redistribution JSON files + +This option allows to use custom distributions for all CUDA/CUDNN dependencies +in Google ML projects. + +1. Create `cuda_redist.json` and/or `cudnn_redist.json` files. + + `cuda_redist.json` show follow the format below: + + ``` + { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + } + }, + } + ``` + + `cudnn_redist.json` show follow the format below: + + ``` + { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + } + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + } + } + } + } + ``` + + The `relative_path` field can be replaced with `full_path` for the full URLs + and absolute local paths starting with `file:///`. + +2. In the downstream project dependent on XLA, update the hermetic cuda JSON + repository call in `WORKSPACE` file. Both web links and local file paths are + allowed. Example: + + ``` + _CUDA_JSON_DICT = { + "12.4.0": [ + "file:///home/user/Downloads/redistrib_12.4.0_updated.json", + ], + } + + _CUDNN_JSON_DICT = { + "9.0.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.0.0.json", + ], + } + + cuda_json_init_repository( + cuda_json_dict = _CUDA_JSON_DICT, + cudnn_json_dict = _CUDNN_JSON_DICT, + ) + ``` + + If JSON files contain relative paths to distributions, the path prefix should + be updated in `cuda_redist_init_repositories()` and + `cudnn_redist_init_repository()` calls. Example + + ``` + cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, + cuda_redist_path_prefix = "file:///usr/Downloads/dists/", + ) + ``` + +### Custom CUDA/CUDNN distributions + +This option allows to use custom distributions for some CUDA/CUDNN dependencies +in Google ML projects. + +1. In the downstream project dependent on XLA, remove the lines below: + + ``` + <...> + "CUDA_REDIST_JSON_DICT", + <...> + "CUDNN_REDIST_JSON_DICT", + <...> + + cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT, + ) + + load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", + ) + ``` + +2. In the same `WORKSPACE` file, create dictionaries with distribution paths. + + The dictionary with CUDA distributions show follow the format below: + + ``` + _CUSTOM_CUDA_REDISTRIBUTIONS = { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + } + }, + } + ``` + + The dictionary with CUDNN distributions show follow the format below: + + ``` + _CUSTOM_CUDNN_REDISTRIBUTIONS = { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn/linux-x86_64/cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + } + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn/linux-sbsa/cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + } + } + } + } + ``` + + The `relative_path` field can be replaced with `full_path` for the full URLs + and absolute local paths starting with `file:///`. + +2. In the same `WORKSPACE` file, pass the created dictionaries to the repository + rule. If the dictionaries contain relative paths to distributions, the path + prefix should be updated in `cuda_redist_init_repositories()` and + `cudnn_redist_init_repository()` calls. + + ``` + cuda_redist_init_repositories( + cuda_redistributions = _CUSTOM_CUDA_REDISTRIBUTIONS, + cuda_redist_path_prefix = "file:///home/usr/Downloads/dists/", + ) + + cudnn_redist_init_repository( + cudnn_redistributions = _CUSTOM_CUDNN_REDISTRIBUTIONS, + cudnn_redist_path_prefix = "file:///home/usr/Downloads/dists/cudnn/" + ) + ``` +### Combination of the options above + +In the example below, `CUDA_REDIST_JSON_DICT` is merged with custom JSON data in +`_CUDA_JSON_DICT`, and `CUDNN_REDIST_JSON_DICT` is merged with +`_CUDNN_JSON_DICT`. + +The distributions data in `_CUDA_DIST_DICT` overrides the content of resulting +CUDA JSON file, and the distributions data in `_CUDNN_DIST_DICT` overrides the +content of resulting CUDNN JSON file. The NCCL wheels data is merged from +`CUDA_NCCL_WHEELS` and `_NCCL_WHEEL_DICT`. + +``` +load( + //third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_PATH_PREFIX", + "CUDA_NCCL_WHEELS", + "CUDA_REDIST_JSON_DICT", + "CUDNN_REDIST_PATH_PREFIX", + "CUDNN_REDIST_JSON_DICT", +) + +_CUDA_JSON_DICT = { + "12.4.0": [ + "file:///usr/Downloads/redistrib_12.4.0_updated.json", + ], +} + +_CUDNN_JSON_DICT = { + "9.0.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.0.0.json", + ], +} + +cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT | _CUDA_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT | _CUDNN_JSON_DICT, +) + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) + +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +_CUDA_DIST_DICT = { + "cuda_cccl": { + "linux-x86_64": { + "relative_path": "cuda_cccl-linux-x86_64-12.4.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "cuda_cccl-linux-sbsa-12.4.99-archive.tar.xz", + }, + }, + "libcusolver": { + "linux-x86_64": { + "full_path": "file:///usr/Downloads/dists/libcusolver-linux-x86_64-11.6.0.99-archive.tar.xz", + }, + "linux-sbsa": { + "relative_path": "libcusolver-linux-sbsa-11.6.0.99-archive.tar.xz", + }, + }, +} + +_CUDNN_DIST_DICT = { + "cudnn": { + "linux-x86_64": { + "cuda12": { + "relative_path": "cudnn-linux-x86_64-9.0.0.312_cuda12-archive.tar.xz", + }, + }, + "linux-sbsa": { + "cuda12": { + "relative_path": "cudnn-linux-sbsa-9.0.0.312_cuda12-archive.tar.xz", + }, + }, + }, +} + +cudnn_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS | _CUDA_DIST_DICT, + cuda_redist_path_prefix = "file:///usr/Downloads/dists/", +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS | _CUDNN_DIST_DICT, + cudnn_redist_path_prefix = "file:///usr/Downloads/dists/cudnn/" +) + +load( + "//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +_NCCL_WHEEL_DICT = { + "12.4.0": { + "x86_64-unknown-linux-gnu": { + "url": "https://files.pythonhosted.org/packages/38/00/d0d4e48aef772ad5aebcf70b73028f88db6e5640b36c38e90445b7a57c45/nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", + }, + }, +} + +nccl_redist_init_repository( + cuda_nccl_wheels = CUDA_NCCL_WHEELS | _NCCL_WHEEL_DICT, +) +``` + +## DEPRECATED: Non-hermetic CUDA/CUDNN usage +Though non-hermetic CUDA/CUDNN usage is deprecated, it might be used for +some experiments currently unsupported officially (for example, building wheels +on Windows with CUDA). + +Here are the steps to use non-hermetic CUDA installed locally in Google ML +projects: + +1. Delete calls to hermetic CUDA repository rules from the `WORKSPACE` + file of the project dependent on XLA. + +2. Add the calls to non-hermetic CUDA repository rules to the bottom of the + `WORKSPACE` file. + + For XLA and JAX: + ``` + load("@local_tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure") + cuda_configure(name = "local_config_cuda") + load("@local_tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure") + nccl_configure(name = "local_config_nccl") + ``` + + For Tensorflow: + ``` + load("@local_tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure") + cuda_configure(name = "local_config_cuda") + load("@local_tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure") + nccl_configure(name = "local_config_nccl") + ``` + +3. Set the following environment variables directly in your shell or in + `.bazelrc` file as shown below: + ``` + build:cuda --action_env=TF_CUDA_VERSION= + build:cuda --action_env=TF_CUDNN_VERSION= + build:cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES= + build:cuda --action_env=LD_LIBRARY_PATH= + build:cuda --action_env=CUDA_TOOLKIT_PATH= + build:cuda --action_env=TF_CUDA_PATHS= + build:cuda --action_env=NCCL_INSTALL_PATH= + ``` + + Note that `TF_CUDA_VERSION` and `TF_CUDNN_VERSION` should consist of major and + minor versions only (e.g. `12.3` for CUDA and `9.1` for CUDNN). + +4. Now you can run `bazel` command to use locally installed CUDA and CUDNN. + + For XLA, no changes in the command options are needed. + + For JAX, use `--override_repository=tsl=` flag in the Bazel command + options. + + For Tensorflow, use `--override_repository=local_tsl=` flag in the + Bazel command options. + +## Configure hermetic CUDA + +1. In the downstream project dependent on XLA, add the following lines to the + bottom of the `WORKSPACE` file: + + Note: use @local_tsl instead of @tsl in Tensorflow project. + + ``` + load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", + ) + + cuda_json_init_repository() + + load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", + ) + load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", + ) + + cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, + ) + + cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, + ) + + load( + "@local_tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", + ) + + cuda_configure(name = "local_config_cuda") + + load( + "@local_tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", + ) + + nccl_redist_init_repository() + + load( + "@local_tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", + ) + + nccl_configure(name = "local_config_nccl") + ``` + +2. To select specific versions of hermetic CUDA and CUDNN, set the + `HERMETIC_CUDA_VERSION` and `HERMETIC_CUDNN_VERSION` environment variables + respectively. Use only supported versions. You may set the environment + variables directly in your shell or in `.bazelrc` file as shown below: + ``` + build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" + build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" + build:cuda --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" + ``` + +3. To enable Hermetic CUDA during test execution, or when running a binary via + bazel, make sure to add `--@local_config_cuda//cuda:include_hermetic_cuda_libs=true` + flag to your bazel command. You can provide it either directly in a shell or + in `.bazelrc`: + ``` + test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true + ``` + The flag is needed to make sure that CUDA dependencies are properly provided + to test executables. The flag is false by default to avoid unwanted coupling + of Google-released Python wheels to CUDA binaries. diff --git a/third_party/xla/docs/images/openxla.svg b/third_party/xla/docs/images/openxla.svg new file mode 100644 index 00000000000000..bb97db4af1c268 --- /dev/null +++ b/third_party/xla/docs/images/openxla.svg @@ -0,0 +1,266 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/third_party/xla/docs/images/openxla_dark.svg b/third_party/xla/docs/images/openxla_dark.svg new file mode 100644 index 00000000000000..ae2dc4c874c13f --- /dev/null +++ b/third_party/xla/docs/images/openxla_dark.svg @@ -0,0 +1,255 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/third_party/xla/docs/operation_semantics.md b/third_party/xla/docs/operation_semantics.md index 55ed575f90f2de..55849974726628 100644 --- a/third_party/xla/docs/operation_semantics.md +++ b/third_party/xla/docs/operation_semantics.md @@ -1214,8 +1214,8 @@ A set of element-wise binary arithmetic operations is supported. Where `Op` is one of `Add` (addition), `Sub`(subtraction), `Mul` (multiplication), `Div` (division), `Pow` (power), `Rem` (remainder), `Max` -(maximum), `Min` (minimum), `LogicalAnd` (logical AND), `LogicalOr` (logical -OR), `LogicalXor` (logical XOR), `ShiftLeft` (Left Shift), +(maximum), `Min` (minimum), `And` (logical AND), `Or` (logical +OR), `Xor` (logical XOR), `ShiftLeft` (Left Shift), `ShiftRightArithmetic` (arithmetic Right Shift), `ShiftRightLogical` (logical Right Shift), `Atan2` (2-argument arctangent), or `Complex` (combines real and imaginary parts into a complex number) @@ -1305,12 +1305,22 @@ XlaBuilder supports these element-wise unary functions: `Abs(operand)` Element-wise abs `x -> |x|`. +`Cbrt(operand)` Element-wise cubic root operation `x -> cbrt(x)`. + `Ceil(operand)` Element-wise ceil `x -> ⌈x⌉`. +`Clz(operand)` Element-wise count leading zeros. + `Cos(operand)` Element-wise cosine `x -> cos(x)`. +`Erf(operand)` Element-wise error function `x -> erf(x)` where + +$$\text{erf}(x) = \frac{2}{\sqrt{\pi}}\int_0^x e^{-t^2} \, dt$$. + `Exp(operand)` Element-wise natural exponential `x -> e^x`. +`Expm1(operand)` Element-wise natural exponential minus one `x -> e^x - 1`. + `Floor(operand)` Element-wise floor `x -> ⌊x⌋`. `Imag(operand)` Element-wise imaginary part of a complex (or real) @@ -1323,19 +1333,25 @@ if and only if the corresponding input element is finite. `Log(operand)` Element-wise natural logarithm `x -> ln(x)`. -`LogicalNot(operand)` Element-wise logical not `x -> !(x)`. +`Log1p(operand)` Element-wise shifted natural logarithm `x -> ln(1+x)`. `Logistic(operand)` Element-wise logistic function computation `x -> logistic(x)`. +`Neg(operand)` Element-wise negation `x -> -x`. + +`Not(operand)` Element-wise logical not `x -> !(x)`. + `PopulationCount(operand)` Computes the number of bits set in each element of `operand`. -`Neg(operand)` Element-wise negation `x -> -x`. - `Real(operand)` Element-wise real part of a complex (or real) shape. `x -> real(x)`. If the operand is a floating point type, returns the same value. +`Round(operand)` Element-wise rounding, ties away from zero. + +`RoundNearestEven(operand)` Element-wise rounding, ties to nearest even. + `Rsqrt(operand)` Element-wise reciprocal of square root operation `x -> 1.0 / sqrt(x)`. @@ -1345,16 +1361,14 @@ $$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & using the comparison operator of the element type of `operand`. +`Sin(operand)` Element-wise sine `x -> sin(x)`. + `Sqrt(operand)` Element-wise square root operation `x -> sqrt(x)`. -`Cbrt(operand)` Element-wise cubic root operation `x -> cbrt(x)`. +`Tan(operand)` Element-wise tangent `x -> tan(x)`. `Tanh(operand)` Element-wise hyperbolic tangent `x -> tanh(x)`. -`Round(operand)` Element-wise rounding, ties away from zero. - -`RoundNearestEven(operand)` Element-wise rounding, ties to nearest even. - Arguments | Type | Semantics --------- | ------- | --------------------------- `operand` | `XlaOp` | The operand to the function diff --git a/third_party/xla/opensource_only.files b/third_party/xla/opensource_only.files index baafd35265caaf..5759a24c5d6d54 100644 --- a/third_party/xla/opensource_only.files +++ b/third_party/xla/opensource_only.files @@ -34,6 +34,7 @@ third_party/py/python_init_toolchains.bzl: third_party/py/python_repo.bzl: third_party/python_runtime/BUILD: third_party/repo.bzl: +third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD: third_party/stablehlo/BUILD: tools/toolchains/BUILD: tools/toolchains/clang6/BUILD: diff --git a/third_party/xla/third_party/gloo/gloo.BUILD b/third_party/xla/third_party/gloo/gloo.BUILD index 99a8e32c69c8f6..2de0c852ebf007 100644 --- a/third_party/xla/third_party/gloo/gloo.BUILD +++ b/third_party/xla/third_party/gloo/gloo.BUILD @@ -22,7 +22,7 @@ substitions = { "#cmakedefine01 GLOO_USE_REDIS": "#define GLOO_USE_REDIS 0", "#cmakedefine01 GLOO_USE_IBVERBS": "#define GLOO_USE_IBVERBS 0", "#cmakedefine01 GLOO_USE_MPI": "#define GLOO_USE_MPI 0", - "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV 0", + "#cmakedefine01 GLOO_USE_LIBUV": "#define GLOO_USE_LIBUV (__APPLE__ ? 1 : 0)", "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP": "#define GLOO_HAVE_TRANSPORT_TCP 1", "#cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS": "#define GLOO_HAVE_TRANSPORT_TCP_TLS 0", "#cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS": "#define GLOO_HAVE_TRANSPORT_IBVERBS 0", @@ -95,3 +95,14 @@ cc_library( copts = ["-fexceptions"], deps = [":gloo"], ) + +cc_library( + name = "transport_uv", + srcs = glob(["gloo/transport/uv/*.cc"]), + hdrs = glob(["gloo/transport/uv/*.h"]), + copts = ["-fexceptions"], + deps = [ + ":gloo", + "@uv", + ], +) diff --git a/third_party/xla/third_party/nanobind/nanobind.BUILD b/third_party/xla/third_party/nanobind/nanobind.BUILD index c9f307b75ef0ca..72b47585b5e5d0 100644 --- a/third_party/xla/third_party/nanobind/nanobind.BUILD +++ b/third_party/xla/third_party/nanobind/nanobind.BUILD @@ -4,9 +4,12 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "nanobind", - srcs = glob([ - "src/*.cpp", - ]), + srcs = glob( + [ + "src/*.cpp", + ], + exclude = ["src/nb_combined.cpp"], + ), copts = ["-fexceptions"], defines = [ "NB_BUILD=1", diff --git a/third_party/xla/third_party/nanobind/pr438.patch b/third_party/xla/third_party/nanobind/pr438.patch deleted file mode 100644 index edb7d61700e03b..00000000000000 --- a/third_party/xla/third_party/nanobind/pr438.patch +++ /dev/null @@ -1,51 +0,0 @@ -diff --git a/src/nb_enum.cpp b/src/nb_enum.cpp -index 86f64d1..91f3932 100644 ---- a/src/nb_enum.cpp -+++ b/src/nb_enum.cpp -@@ -73,6 +73,13 @@ static PyObject *nb_enum_get_doc(PyObject *self, void *) { - return result; - } - -+static PyObject *nb_enum_get_value(PyObject *self, void *) { -+ enum_supplement &supp = nb_enum_supplement(Py_TYPE(self)); -+ return supp.is_signed ? nb_enum_int_signed(self) -+ : nb_enum_int_unsigned(self); -+} -+ -+ - NB_NOINLINE static PyObject *nb_enum_int_signed(PyObject *o) { - type_data *t = nb_type_data(Py_TYPE(o)); - const void *p = inst_ptr((nb_inst *) o); -@@ -141,6 +148,8 @@ error: - static PyGetSetDef nb_enum_getset[] = { - { "__doc__", nb_enum_get_doc, nullptr, nullptr, nullptr }, - { "__name__", nb_enum_get_name, nullptr, nullptr, nullptr }, -+ { "name", nb_enum_get_name, nullptr, nullptr, nullptr }, -+ { "value", nb_enum_get_value, nullptr, nullptr, nullptr }, - { nullptr, nullptr, nullptr, nullptr, nullptr } - }; - -diff --git a/tests/test_enum.py b/tests/test_enum.py -index 2a6e9ff..1063eef 100644 ---- a/tests/test_enum.py -+++ b/tests/test_enum.py -@@ -14,6 +14,9 @@ def test01_unsigned_enum(): - assert int(t.Enum.A) == 0 - assert int(t.Enum.B) == 1 - assert int(t.Enum.C) == 0xffffffff -+ assert t.Enum.A.value == 0 -+ assert t.Enum.B.value == 1 -+ assert t.Enum.C.value == 0xffffffff - assert t.Enum(0) is t.Enum.A - assert t.Enum(1) is t.Enum.B - assert t.Enum(0xffffffff) is t.Enum.C -@@ -48,6 +51,9 @@ def test02_signed_enum(): - assert int(t.SEnum.A) == 0 - assert int(t.SEnum.B) == 1 - assert int(t.SEnum.C) == -1 -+ assert t.SEnum.A.value == 0 -+ assert t.SEnum.B.value == 1 -+ assert t.SEnum.C.value == -1 - assert t.SEnum(0) is t.SEnum.A - assert t.SEnum(1) is t.SEnum.B - assert t.SEnum(-1) is t.SEnum.C \ No newline at end of file diff --git a/third_party/xla/third_party/nanobind/pr461.patch b/third_party/xla/third_party/nanobind/pr461.patch deleted file mode 100644 index aa0a51b68175a3..00000000000000 --- a/third_party/xla/third_party/nanobind/pr461.patch +++ /dev/null @@ -1,39 +0,0 @@ -diff --git a/src/nb_type.cpp b/src/nb_type.cpp ---- a/src/nb_type.cpp -+++ b/src/nb_type.cpp -@@ -36,6 +36,11 @@ static PyObject **nb_weaklist_ptr(PyObje - return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr; - } - -+static PyGetSetDef inst_getset[] = { -+ { "__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr }, -+ { nullptr, nullptr, nullptr, nullptr, nullptr } -+}; -+ - static int inst_clear(PyObject *self) { - PyObject **dict = nb_dict_ptr(self); - if (dict) -@@ -923,8 +928,11 @@ PyObject *nb_type_new(const type_init_da - } - - bool has_traverse = false; -- for (PyType_Slot *ts = slots; ts != s; ++ts) -+ bool has_getset = false; -+ for (PyType_Slot *ts = slots; ts != s; ++ts) { - has_traverse |= ts->slot == Py_tp_traverse; -+ has_getset |= ts->slot == Py_tp_getset; -+ } - - Py_ssize_t dictoffset = 0, weaklistoffset = 0; - int num_members = 0; -@@ -948,6 +956,10 @@ PyObject *nb_type_new(const type_init_da - has_traverse = true; - } - spec.basicsize = (int) basicsize; -+ -+ if (!has_getset) { -+ *s++ = { Py_tp_getset, (void *) inst_getset }; -+ } - } - - if (is_weak_referenceable) { diff --git a/third_party/xla/third_party/nanobind/workspace.bzl b/third_party/xla/third_party/nanobind/workspace.bzl index 9f9022dbaa8d12..1c692d396e9b98 100644 --- a/third_party/xla/third_party/nanobind/workspace.bzl +++ b/third_party/xla/third_party/nanobind/workspace.bzl @@ -5,12 +5,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): tf_http_archive( name = "nanobind", - strip_prefix = "nanobind-1.9.2", - sha256 = "149a3da40b0a988513d8cf5e71db3037373823505a3c92f87b988c92d7e0ab34", - urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v1.9.2.tar.gz"), + strip_prefix = "nanobind-2.1.0", + sha256 = "c37c53c60ada5fe1c956e24bd4b83af669a2309bf952bd251f36a7d2fa3bacf0", + urls = tf_mirror_urls("https://github.com/wjakob/nanobind/archive/refs/tags/v2.1.0.tar.gz"), build_file = "//third_party/nanobind:nanobind.BUILD", - patch_file = [ - "//third_party/nanobind:pr438.patch", # Remove when updating to nanobind 2.0.0. - "//third_party/nanobind:pr461.patch", # Remove when updating to nanobind 2.0.0. - ], ) diff --git a/third_party/xla/third_party/py/python_repo.bzl b/third_party/xla/third_party/py/python_repo.bzl index 85dbda9c62f11e..0c58e3077712c6 100644 --- a/third_party/xla/third_party/py/python_repo.bzl +++ b/third_party/xla/third_party/py/python_repo.bzl @@ -259,8 +259,12 @@ def _basic_wildcard_match(name, patterns, expected_match_result, match_all): def _custom_python_interpreter_impl(ctx): version = ctx.attr.version - strip_prefix = ctx.attr.strip_prefix.format(version = version) - urls = [url.format(version = version) for url in ctx.attr.urls] + version_variant = ctx.attr.version_variant + strip_prefix = ctx.attr.strip_prefix.format( + version = version, + version_variant = version_variant, + ) + urls = [url.format(version = version, version_variant = version_variant) for url in ctx.attr.urls] binary_name = ctx.attr.binary_name if not binary_name: ver_chunks = version.split(".") @@ -276,13 +280,12 @@ def _custom_python_interpreter_impl(ctx): output = srcs_dir, ) - configure_params = [] + configure_params = list(ctx.attr.configure_params) if "CC" in ctx.os.environ: configure_params.append("CC={}".format(ctx.os.environ["CC"])) if "CXX" in ctx.os.environ: configure_params.append("CXX={}".format(ctx.os.environ["CXX"])) - configure_params.append("--enable-optimizations") configure_params.append("--prefix=%s" % install_path.realpath) _exec_and_check( ctx, @@ -365,6 +368,11 @@ custom_python_interpreter = repository_rule( "strip_prefix": attr.string(), "binary_name": attr.string(mandatory = False), "version": attr.string(), + "version_variant": attr.string(), + "configure_params": attr.string_list( + mandatory = False, + default = ["--enable-optimizations"], + ), }, ) diff --git a/third_party/xla/third_party/shardy/BUILD b/third_party/xla/third_party/shardy/BUILD index ea1ecdb548c1f4..bf3ae84c142f65 100644 --- a/third_party/xla/third_party/shardy/BUILD +++ b/third_party/xla/third_party/shardy/BUILD @@ -2,4 +2,7 @@ # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) -exports_files(srcs = ["workspace.bzl"]) +exports_files(srcs = [ + "temporary.patch", + "workspace.bzl", +]) diff --git a/third_party/xla/third_party/shardy/workspace.bzl b/third_party/xla/third_party/shardy/workspace.bzl index c82f3275766f90..6d91def025b34a 100644 --- a/third_party/xla/third_party/shardy/workspace.bzl +++ b/third_party/xla/third_party/shardy/workspace.bzl @@ -3,8 +3,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): - SHARDY_COMMIT = "8f92b38a2400ce5dc72f97067b02c635ed4f3d00" - SHARDY_SHA256 = "3d91370627e81ce5285e5a6ec0d6dbefc786ae32f6d1ebcb4aa61fd247378b91" + SHARDY_COMMIT = "7e3ddfb532b3b53cb0b108014c24a86ac147e9f6" + SHARDY_SHA256 = "1d304e1e6f1132fe3ccb969d28798bc6ee90db84d10c85113ef8573eae350325" tf_http_archive( name = "shardy", diff --git a/third_party/xla/third_party/spirv_llvm_translator/BUILD b/third_party/xla/third_party/spirv_llvm_translator/BUILD new file mode 100644 index 00000000000000..8d626dc7635d1a --- /dev/null +++ b/third_party/xla/third_party/spirv_llvm_translator/BUILD @@ -0,0 +1,7 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +# spirv_llvm_translator license placeholder diff --git a/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD new file mode 100644 index 00000000000000..557e2e8f50edd2 --- /dev/null +++ b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD @@ -0,0 +1,34 @@ +cc_library( + name = "spirv_llvm_translator", + srcs = glob([ + "lib/SPIRV/libSPIRV/*.cpp", + "lib/SPIRV/libSPIRV/*.hpp", + "lib/SPIRV/libSPIRV/*.h", + "lib/SPIRV/Mangler/*.cpp", + "lib/SPIRV/Mangler/*.h", + "lib/SPIRV/*.cpp", + "lib/SPIRV/*.hpp", + "lib/SPIRV/*.h", + ]), + hdrs = glob(["include/*"]), + includes = [ + "include/", + "lib/SPIRV/", + "lib/SPIRV/Mangler/", + "lib/SPIRV/libSPIRV/", + ], + visibility = ["//visibility:public"], + deps = [ + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@spirv_headers//:spirv_cpp_headers", + ], +) diff --git a/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch new file mode 100644 index 00000000000000..fc843b1b039b09 --- /dev/null +++ b/third_party/xla/third_party/spirv_llvm_translator/spirv_llvm_translator.patch @@ -0,0 +1,25 @@ +diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h +index a828add8..924e13b4 100644 + +Spir backend uses different addrspace representations link with nvptx backend link. +We reorder the enum value here so that we can make XLA LLVM codegen simple(avoiding +changing addrspace based on device backend everywhere) + +--- a/lib/SPIRV/SPIRVInternal.h ++++ b/lib/SPIRV/SPIRVInternal.h +@@ -179,11 +179,12 @@ typedef SPIRVMap IntBoolOpMap; + "-v512:512:512-v1024:1024:1024" + + enum SPIRAddressSpace { +- SPIRAS_Private, ++ SPIRAS_Generic, + SPIRAS_Global, +- SPIRAS_Constant, ++ SPIRAS_Internal, + SPIRAS_Local, +- SPIRAS_Generic, ++ SPIRAS_Constant, ++ SPIRAS_Private, + SPIRAS_GlobalDevice, + SPIRAS_GlobalHost, + SPIRAS_Input, \ No newline at end of file diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 8b137891791fe9..77fefee2b13b6d 100755 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -1 +1,28 @@ +diff --ruN a/stablehlo/BUILD.bazel b/stablehlo/BUILD.bazel +--- stablehlo/BUILD.bazel ++++ stablehlo/BUILD.bazel +@@ -1283,6 +1283,7 @@ + "@llvm-project//mlir:AllExtensions", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", ++ "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:TosaDialect", + ], + ) +diff --ruN a/stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py b/stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py +--- stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py ++++ stablehlo/stablehlo/conversions/tosa/tests/lit.cfg.py +@@ -32,9 +32,9 @@ + + # Make LLVM and StableHLO tools available in RUN directives + tools = [ +- 'stablehlo-opt', +- 'FileCheck', +- 'stablehlo-translate', ++ 'stablehlo-opt', ++ 'FileCheck', ++ 'stablehlo-translate', + ] + tool_dirs = [ + config.llvm_tools_dir, diff --git a/third_party/xla/third_party/stablehlo/workspace.bzl b/third_party/xla/third_party/stablehlo/workspace.bzl index f9c14a65d4dbb3..6c0cea3e8f16f5 100644 --- a/third_party/xla/third_party/stablehlo/workspace.bzl +++ b/third_party/xla/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # LINT.IfChange - STABLEHLO_COMMIT = "8555db77763fadbd6be83df0a5532828bc419cba" - STABLEHLO_SHA256 = "666a88d94e0f1b36e9e5b25411521b878320c61983214859b4e419f36acbf332" + STABLEHLO_COMMIT = "23d3e1414b0be1c1b5256f0949520dc4f0a0705c" + STABLEHLO_SHA256 = "ad694a3da43a2a432c8c5f1c60be39fc211e28834cca07cf663ce8dc85d920fe" # LINT.ThenChange(Google-internal path) tf_http_archive( diff --git a/third_party/xla/third_party/triton/llvm_integration/cl656020169.patch b/third_party/xla/third_party/triton/llvm_integration/cl656020169.patch deleted file mode 100644 index 7586a90b14ccf6..00000000000000 --- a/third_party/xla/third_party/triton/llvm_integration/cl656020169.patch +++ /dev/null @@ -1,12 +0,0 @@ -diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp ---- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp -+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp -@@ -117,7 +117,7 @@ private: - auto operands = callOp.getOperands(); - auto result = callOp.getResult(); - -- LLVM::LLVMFunctionType calleeType = callOp.getCalleeType().value(); -+ LLVM::LLVMFunctionType calleeType = callOp.getVarCalleeType().value(); - Type returnType = calleeType.getReturnType(); - - auto loc = callOp.getLoc(); diff --git a/third_party/xla/third_party/triton/workspace.bzl b/third_party/xla/third_party/triton/workspace.bzl index d1a4940f567dd9..29287cc59f3210 100644 --- a/third_party/xla/third_party/triton/workspace.bzl +++ b/third_party/xla/third_party/triton/workspace.bzl @@ -8,8 +8,8 @@ load("//third_party/triton/xla_extensions:series.bzl", "extensions_files_patch_l def repo(): """Imports Triton.""" - TRITON_COMMIT = "cl655158651" - TRITON_SHA256 = "ac136693d2aeae327896d33e1a4de4852f25c1c2cdca49f85a2b9ac8b6d03b44" + TRITON_COMMIT = "cl664783844" + TRITON_SHA256 = "d5779d331008dd3a4941dd59e61385ec964987da74454248446ac3e36b874007" tf_http_archive( name = "triton", sha256 = TRITON_SHA256, diff --git a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch index 21ed97b5afb822..dadc7732a4f280 100644 --- a/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch +++ b/third_party/xla/third_party/triton/xla_extensions/sparse_dot.patch @@ -57,7 +57,7 @@ diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dia index 012786dae..6043b764a 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp -@@ -498,6 +498,119 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, +@@ -498,6 +498,123 @@ getDefaultBlockedEncoding(MLIRContext *context, ArrayRef shape, return encoding; } @@ -173,6 +173,10 @@ index 012786dae..6043b764a 100644 + ArrayRef tensorShape) const { + return ::getShapePerCTATile(getParent(), tensorShape); +} ++std::optional SparseDotMetaEncodingAttr::toLinearLayout( ++ ArrayRef shape) const { ++ return ::toLinearLayout(shape, getParent()); ++} + } // namespace gpu } // namespace triton @@ -273,9 +277,9 @@ index d74e0a224..4e45f7c4c 100644 + return op->hasTrait() || isa(op); +} + - // Replace the ForOp's yield with a new one with the given operands appended. - static void appendToYield(scf::ForOp forOp, ArrayRef newOperands) { - // Fix up the yield op. + static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, + Value insertIdx, Value extractIdx, + tt::CoarseSchedule &schedule, @@ -248,19 +252,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { } else { if (!isa(user)) @@ -344,52 +348,6 @@ index d74e0a224..4e45f7c4c 100644 if (auto dotEnc = dyn_cast( dot.getResult().getType().getEncoding())) { auto loadTy = cast(op->getResultTypes()[0]); -diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -index 8c1f18e45..c39110d12 100644 ---- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp -@@ -38,6 +38,10 @@ public: - auto srcEncoding = srcType.getEncoding(); - if (isa(srcEncoding)) - return; -+ if (isa(dstType.getEncoding())) { -+ replaceSparseMetaEncoding(cvtOp); -+ return; -+ } - auto dstDotOp = - dyn_cast(dstType.getEncoding()); - if (!dstDotOp) -@@ -86,6 +90,30 @@ public: - cvtOp.erase(); - }); - } -+ -+ private: -+ void replaceSparseMetaEncoding(triton::gpu::ConvertLayoutOp cvtOp) { -+ auto srcType = cast(cvtOp.getOperand().getType()); -+ auto srcEncoding = srcType.getEncoding(); -+ auto sharedLayout = triton::gpu::SharedEncodingAttr::get( -+ cvtOp.getContext(), 8, 1, 1, triton::gpu::getOrder(srcEncoding), -+ triton::gpu::getCTALayout(srcEncoding)); -+ -+ auto dstType = cast(cvtOp.getType()); -+ auto sharedMemorySpace = -+ triton::gpu::SharedMemorySpaceAttr::get(srcType.getContext()); -+ auto tmpType = triton::MemDescType::get( -+ dstType.getShape(), dstType.getElementType(), sharedLayout, -+ sharedMemorySpace); -+ -+ OpBuilder builder(cvtOp); -+ auto tmp = builder.create( -+ cvtOp.getLoc(), tmpType, cvtOp.getSrc()); -+ auto newConvert = builder.create( -+ cvtOp.getLoc(), dstType, tmp); -+ cvtOp.replaceAllUsesWith(newConvert.getResult()); -+ cvtOp.erase(); -+ } - }; - - } // namespace gpu diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index fb0e7f6fd..37795c20c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp diff --git a/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch b/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch index b64ddbdbdab683..4daf4f2856069c 100644 --- a/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch +++ b/third_party/xla/third_party/triton/xla_extensions/sparsity_layout.patch @@ -2,19 +2,20 @@ diff --git a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp b/lib/Conv index 34fb89954..a0172e107 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonGPUConversion.cpp -@@ -57,8 +57,9 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, +@@ -57,6 +57,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, addArgumentMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> std::optional { -- llvm_unreachable("Argument rematerialization should not happen in Triton " -- "-> TritonGPU conversion"); -+ // TODO(b/354860562): reenable or remove. -+ // llvm_unreachable("Argument rematerialization should not happen in Triton " -+ // "-> TritonGPU conversion"); ++ // Allows partial TTIR to TTGIR conversion by materializing a conversion for ++ // remaining arguments that have been converted to a new type. ++ // We use this to rewrite triton_gpu.sparse_dot in a separate pass after ++ // 'convert-triton-to-tritongpu'. ++ return builder.create(loc, tensorType, ++ inputs); + llvm_unreachable("Argument rematerialization should not happen in Triton " + "-> TritonGPU conversion"); return std::nullopt; - }); - -@@ -67,6 +68,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, +@@ -67,6 +73,12 @@ TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context, addSourceMaterialization([&](OpBuilder &builder, RankedTensorType tensorType, ValueRange inputs, Location loc) -> std::optional { @@ -31,7 +32,7 @@ diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dia index df3d3b042..e38c184f6 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp -@@ -2867,13 +2867,13 @@ struct CanonicalizeConvertFromConvert +@@ -2867,13 +2879,13 @@ struct CanonicalizeConvertFromConvert // heuristic to accommodate fused attention. auto srcType = op.getSrc().getType(); auto dstType = op.getType(); diff --git a/third_party/xla/third_party/tsl/.bazelrc b/third_party/xla/third_party/tsl/.bazelrc index b94693e05efab8..9e565e91a1b903 100644 --- a/third_party/xla/third_party/tsl/.bazelrc +++ b/third_party/xla/third_party/tsl/.bazelrc @@ -219,13 +219,16 @@ build:mkl_aarch64_threadpool -c opt build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda +# Default CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" +# This flag is needed to include hermetic CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_hermetic_cuda_libs=true # CUDA: This config refers to building CUDA op kernels with clang. build:cuda_clang --config=cuda -# Enable TensorRT optimizations https://developer.nvidia.com/tensorrt -build:cuda_clang --config=tensorrt -build:cuda_clang --action_env=TF_CUDA_CLANG="1" build:cuda_clang --@local_config_cuda//:cuda_compiler=clang +build:cuda_clang --copt=-Qunused-arguments # Select supported compute capabilities (supported graphics cards). # This is the same as the official TensorFlow builds. # See https://developer.nvidia.com/cuda-gpus#compute @@ -234,22 +237,22 @@ build:cuda_clang --@local_config_cuda//:cuda_compiler=clang # release while SASS is only forward compatible inside the current # major release. Example: sm_80 kernels can run on sm_89 GPUs but # not on sm_90 GPUs. compute_80 kernels though can also run on sm_90 GPUs. -build:cuda_clang --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +build:cuda_clang --repo_env=HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,sm_89,compute_90" +# Set lld as the linker. +build:cuda_clang --host_linkopt="-fuse-ld=lld" +build:cuda_clang --host_linkopt="-lm" +build:cuda_clang --linkopt="-fuse-ld=lld" +build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --action_env=TF_CUDA_VERSION="12" -build:cuda_clang_official --action_env=TF_CUDNN_VERSION="8" -build:cuda_clang_official --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12.3" -build:cuda_clang_official --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" +build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="8.9.7.29" build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang_official --action_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:cuda_clang_official --crosstool_top="@sigbuild-r2.17-clang_config_cuda//crosstool:toolchain" # Build with nvcc for CUDA and clang for host build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" build:nvcc_clang --action_env=TF_NVCC_CLANG="1" build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc @@ -351,6 +354,13 @@ build:windows --features=archive_param_file build:windows --copt=/d2ReducedOptimizeHugeFunctions build:windows --host_copt=/d2ReducedOptimizeHugeFunctions +# Before VS 2017 15.8, the member "type" would non-conformingly have an +# alignment of only alignof(max_align_t). VS 2017 15.8 was fixed to handle this +# correctly, but the fix inherently changes layout and breaks binary +# compatibility (*only* for uses of aligned_storage with extended alignments). +build:windows --copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE +build:windows --host_copt=-D_ENABLE_EXTENDED_ALIGNED_STORAGE + # Enable the runfiles symlink tree on Windows. This makes it possible to build # the pip package on Windows without an intermediate data-file archive, as the # build_pip_package script in its current form (as of Aug 2023) uses the @@ -538,10 +548,6 @@ build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 -build:rbe_linux_cuda --repo_env=TF_CUDA_CONFIG_REPO="@sigbuild-r2.17-clang_config_cuda" -build:rbe_linux_cuda --repo_env=TF_TENSORRT_CONFIG_REPO="@sigbuild-r2.17-clang_config_tensorrt" -build:rbe_linux_cuda --repo_env=TF_NCCL_CONFIG_REPO="@sigbuild-r2.17-clang_config_nccl" -test:rbe_linux_cuda --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda build:rbe_linux_cuda_nvcc --config=nvcc_clang @@ -566,6 +572,9 @@ build:rbe_win_clang --compiler=clang-cl build:rbe_win_clang --linkopt=/FORCE:MULTIPLE build:rbe_win_clang --host_linkopt=/FORCE:MULTIPLE +# TODO(belitskiy): Rename `rbe_win_clang` to this, once done switching presubmits. +build:rbe_windows_x86_cpu --config=rbe_win_clang + # END TF REMOTE BUILD EXECUTION OPTIONS # TFLite build configs for generic embedded Linux @@ -623,7 +632,6 @@ build:release_cpu_linux_base --repo_env=BAZEL_COMPILER="/usr/lib/llvm-18/bin/cla # Test-related settings below this point. test:release_linux_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true test:release_linux_base --local_test_jobs=HOST_CPUS -test:release_linux_base --test_env=LD_LIBRARY_PATH # Give only the list of failed tests at the end of the log test:release_linux_base --test_summary=short @@ -637,7 +645,6 @@ build:release_gpu_linux --config=release_cpu_linux # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. # Note that linux cpu and cuda builds share the same toolchain now. build:release_gpu_linux --config=cuda_clang_official -test:release_gpu_linux --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" # Local test jobs has to be 4 because parallel_gpu_execute is fragile, I think test:release_gpu_linux --test_timeout=300,450,1200,3600 --local_test_jobs=4 --run_under=//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute @@ -668,9 +675,8 @@ build:unsupported_gpu_linux --config=unsupported_cpu_linux build:unsupported_gpu_linux --action_env=TF_CUDA_VERSION="11" build:unsupported_gpu_linux --action_env=TF_CUDNN_VERSION="8" build:unsupported_gpu_linux --repo_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_50,sm_60,sm_70,sm_75,compute_80" -build:unsupported_gpu_linux --config=tensorrt build:unsupported_gpu_linux --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-11.2" -build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64:/usr/local/tensorrt/lib" +build:unsupported_gpu_linux --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64" build:unsupported_gpu_linux --action_env=GCC_HOST_COMPILER_PATH="/dt9/usr/bin/gcc" build:unsupported_gpu_linux --crosstool_top=@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain @@ -780,17 +786,19 @@ test:macos_x86_wheel_test --config=macos_x86_wheel_test_filters -- //tensorflow/ # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. These are usually run continuously or upon presubmit. -# CPU PYCPP: +# LINUX CPU PYCPP: test:linux_cpu_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-gpu,-tpu,-benchmark-test,-v1only test:linux_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cpu_pycpp_test --config=linux_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -# CUDA PYCPP: + +# LINUX CUDA PYCPP: test:linux_cuda_pycpp_test_filters --test_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --build_tag_filters=-no_oss,-oss_excluded,-oss_serial,-benchmark-test,-v1only,gpu,-no_gpu,-no_gpu_presubmit,-no_cuda11 test:linux_cuda_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium test:linux_cuda_pycpp_test --config=linux_cuda_pycpp_test_filters -- //tensorflow/... -//tensorflow/compiler/tf2tensorrt/... -//tensorflow/core/tpu/... -//tensorflow/lite/... -//tensorflow/tools/toolchains/... -# ARM64 PYCPP + +# LINUX ARM64 PYCPP # In Linux Arm64 presubmit/continuous build, we cross-compile the binaries on # Linux x86 so that we can use RBE. Since tests still need to run on the single # host Arm64 machine, the build becomes too slow (~30 min) to be a presubmit. @@ -823,6 +831,13 @@ build:macos_x86_pycpp_test --config=macos_x86_pycpp_test_filters -- //tensorflow # CROSS-COMPILE MACOS X86 PYCPP build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test +# WINDOWS X86-64 CPU PYCPP +test:windows_x86_cpu_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test +test:windows_x86_cpu_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" +test:windows_x86_cpu_pycpp_test_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off --build_tests_only +test:windows_x86_cpu_pycpp_test --config=windows_x86_cpu_pycpp_test_opts --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... + # END TF TEST SUITE OPTIONS # START CROSS-COMPILE CONFIGS diff --git a/third_party/xla/third_party/tsl/WORKSPACE b/third_party/xla/third_party/tsl/WORKSPACE index 19350e3dbba762..a83a9e63f4143a 100644 --- a/third_party/xla/third_party/tsl/WORKSPACE +++ b/third_party/xla/third_party/tsl/WORKSPACE @@ -50,3 +50,50 @@ tsl_workspace1() load(":workspace0.bzl", "tsl_workspace0") tsl_workspace0() + +load( + "//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + +load( + "//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/third_party/xla/third_party/tsl/opensource_only.files b/third_party/xla/third_party/tsl/opensource_only.files index 300ae95c10aec2..f93d02d633d3c7 100644 --- a/third_party/xla/third_party/tsl/opensource_only.files +++ b/third_party/xla/third_party/tsl/opensource_only.files @@ -21,6 +21,7 @@ third_party/git/BUILD.tpl: third_party/git/BUILD: third_party/git/git_configure.bzl: third_party/gpus/BUILD: +third_party/gpus/compiler_common_tools.bzl: third_party/gpus/crosstool/BUILD.rocm.tpl: third_party/gpus/crosstool/BUILD.sycl.tpl: third_party/gpus/crosstool/BUILD.tpl: @@ -38,6 +39,27 @@ third_party/gpus/cuda/LICENSE: third_party/gpus/cuda/build_defs.bzl.tpl: third_party/gpus/cuda/cuda_config.h.tpl: third_party/gpus/cuda/cuda_config.py.tpl: +third_party/gpus/cuda/hermetic/BUILD.tpl: +third_party/gpus/cuda/hermetic/BUILD: +third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_configure.bzl: +third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl: +third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl: +third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl: third_party/gpus/cuda_configure.bzl: third_party/gpus/find_cuda_config:.py third_party/gpus/rocm/BUILD.tpl: @@ -67,6 +89,9 @@ third_party/nccl/archive.BUILD: third_party/nccl/archive.patch: third_party/nccl/build_defs.bzl.tpl: third_party/nccl/generated_names.bzl.tpl: +third_party/nccl/hermetic/BUILD: +third_party/nccl/hermetic/cuda_nccl.BUILD.tpl: +third_party/nccl/hermetic/nccl_configure.bzl: third_party/nccl/nccl_configure.bzl: third_party/nccl/system.BUILD.tpl: third_party/nvtx/BUILD: @@ -93,6 +118,7 @@ third_party/remote_config/remote_platform_configure.bzl: third_party/repo.bzl: third_party/six.BUILD: third_party/snappy.BUILD: +third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD: third_party/systemlibs/BUILD.tpl: third_party/systemlibs/BUILD: third_party/systemlibs/absl_py.BUILD: diff --git a/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch b/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch new file mode 100644 index 00000000000000..5328c3a0d605c7 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/absl/nvidia_jetson.patch @@ -0,0 +1,35 @@ +From 372124e6af36a540e74a2ec31d79d7297a831f98 Mon Sep 17 00:00:00 2001 +From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= +Date: Thu, 1 Aug 2024 12:38:52 -0700 +Subject: [PATCH] PR #1732: Fix build on NVIDIA Jetson board. Fix #1665 + +Imported from GitHub PR https://github.com/abseil/abseil-cpp/pull/1732 + +Fix build on NVIDIA Jetson board. Fix #1665 + +This patch is already used by the spark project. +I'm fixing this as this break the build of Tensorflow and JAX on Jetson board. +Merge 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff into 6b8ebb35c0414ef5a2b6fd4a0f59057e41beaff9 + +Merging this change closes #1732 + +COPYBARA_INTEGRATE_REVIEW=https://github.com/abseil/abseil-cpp/pull/1732 from nouiz:fix_neon_on_jetson 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff +PiperOrigin-RevId: 658501520 +Change-Id: If502ede4efc8c877fb3fed227eca6dc7622dd181 +--- + absl/base/config.h | 2 +- + 1 file changed, 1 insertion(+), 1 deletion(-) + +diff --git a/absl/base/config.h b/absl/base/config.h +index 97c9a22a109..ab1e9860a91 100644 +--- a/absl/base/config.h ++++ b/absl/base/config.h +@@ -926,7 +926,7 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' || + // https://llvm.org/docs/CompileCudaWithLLVM.html#detecting-clang-vs-nvcc-from-code + #ifdef ABSL_INTERNAL_HAVE_ARM_NEON + #error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set +-#elif defined(__ARM_NEON) && !defined(__CUDA_ARCH__) ++#elif defined(__ARM_NEON) && !(defined(__NVCC__) && defined(__CUDACC__)) + #define ABSL_INTERNAL_HAVE_ARM_NEON 1 + #endif + diff --git a/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl b/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl index 06f75166ce4bb6..9565a82c331946 100644 --- a/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/absl/workspace.bzl @@ -44,4 +44,5 @@ def repo(): system_link_files = SYS_LINKS, strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT), urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)), + patch_file = ["//third_party/absl:nvidia_jetson.patch"], ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py index afd6380b0ac203..b1a10a86b9aac6 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/check_cuda_libs.py @@ -14,6 +14,9 @@ # ============================================================================== """Verifies that a list of libraries is installed on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + Takes a list of arguments with every two subsequent arguments being a logical tuple of (path, check_soname). The path to the library and either True or False to indicate whether to check the soname field on the shared library. diff --git a/third_party/xla/third_party/tsl/third_party/gpus/compiler_common_tools.bzl b/third_party/xla/third_party/tsl/third_party/gpus/compiler_common_tools.bzl new file mode 100644 index 00000000000000..bd07f49ec457bb --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/compiler_common_tools.bzl @@ -0,0 +1,174 @@ +"""Common compiler functions. """ + +load( + "//third_party/remote_config:common.bzl", + "err_out", + "raw_exec", + "realpath", +) + +def to_list_of_strings(elements): + """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. + + This is to be used to put a list of strings into the bzl file templates + so it gets interpreted as list of strings in Starlark. + + Args: + elements: list of string elements + + Returns: + single string of elements wrapped in quotes separated by a comma.""" + quoted_strings = ["\"" + element + "\"" for element in elements] + return ", ".join(quoted_strings) + +_INC_DIR_MARKER_BEGIN = "#include <...>" + +# OSX add " (framework directory)" at the end of line, strip it. +_OSX_FRAMEWORK_SUFFIX = " (framework directory)" +_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) + +# TODO(dzc): Once these functions have been factored out of Bazel's +# cc_configure.bzl, load them from @bazel_tools instead. +def _cxx_inc_convert(path): + """Convert path returned by cc -E xc++ in a complete path.""" + path = path.strip() + if path.endswith(_OSX_FRAMEWORK_SUFFIX): + path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() + return path + +def _normalize_include_path(repository_ctx, path): + """Normalizes include paths before writing them to the crosstool. + + If path points inside the 'crosstool' folder of the repository, a relative + path is returned. + If path points outside the 'crosstool' folder, an absolute path is returned. + """ + path = str(repository_ctx.path(path)) + crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) + + if path.startswith(crosstool_folder): + # We drop the path to "$REPO/crosstool" and a trailing path separator. + return path[len(crosstool_folder) + 1:] + return path + +def _is_compiler_option_supported(repository_ctx, cc, option): + """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" + result = repository_ctx.execute([ + cc, + option, + "-o", + "/dev/null", + "-c", + str(repository_ctx.path("tools/cpp/empty.cc")), + ]) + return result.stderr.find(option) == -1 + +def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sys_root): + """Compute the list of default C or C++ include directories.""" + if lang_is_cpp: + lang = "c++" + else: + lang = "c" + sysroot = [] + if tf_sys_root: + sysroot += ["--sysroot", tf_sys_root] + result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + + sysroot) + stderr = err_out(result) + index1 = stderr.find(_INC_DIR_MARKER_BEGIN) + if index1 == -1: + return [] + index1 = stderr.find("\n", index1) + if index1 == -1: + return [] + index2 = stderr.rfind("\n ") + if index2 == -1 or index2 < index1: + return [] + index2 = stderr.find("\n", index2 + 1) + if index2 == -1: + inc_dirs = stderr[index1 + 1:] + else: + inc_dirs = stderr[index1 + 1:index2].strip() + + print_resource_dir_supported = _is_compiler_option_supported( + repository_ctx, + cc, + "-print-resource-dir", + ) + + if print_resource_dir_supported: + resource_dir = repository_ctx.execute( + [cc, "-print-resource-dir"], + ).stdout.strip() + "/share" + inc_dirs += "\n" + resource_dir + + compiler_includes = [ + _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) + for p in inc_dirs.split("\n") + ] + + # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc + # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) + # but Bazel might encounter either (usually reported by the compiler) + # especially when a compiler wrapper (e.g. ccache) is used. + # So we need to also include paths where symlinks are not resolved. + + # Try to find real path to CC installation to "see through" compiler wrappers + # GCC has the path to g++ + index1 = result.stderr.find("COLLECT_GCC=") + if index1 != -1: + index1 = result.stderr.find("=", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname + else: + # Clang has the directory + index1 = result.stderr.find("InstalledDir: ") + if index1 != -1: + index1 = result.stderr.find(" ", index1) + index2 = result.stderr.find("\n", index1) + cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname + else: + # Fallback to the CC path + cc_topdir = repository_ctx.path(cc).dirname.dirname + + # We now have the compiler installation prefix, e.g. /symlink/gcc + # And the resolved installation prefix, e.g. /opt/gcc + cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() + cc_topdir = str(cc_topdir).strip() + + # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. + # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] + # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] + if cc_topdir_resolved != cc_topdir: + unresolved_compiler_includes = [ + cc_topdir + inc[len(cc_topdir_resolved):] + for inc in compiler_includes + if inc.startswith(cc_topdir_resolved) + ] + compiler_includes = compiler_includes + unresolved_compiler_includes + return compiler_includes + +def get_cxx_inc_directories(repository_ctx, cc, tf_sys_root): + """Compute the list of default C and C++ include directories.""" + + # For some reason `clang -xc` sometimes returns include paths that are + # different from the ones from `clang -xc++`. (Symlink and a dir) + # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists + includes_cpp = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + True, + tf_sys_root, + ) + includes_c = _get_cxx_inc_directories_impl( + repository_ctx, + cc, + False, + tf_sys_root, + ) + + return includes_cpp + [ + inc + for inc in includes_c + if inc not in includes_cpp + ] diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl index 8eda7a1cf6ac2b..b9553d9b99ecfe 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/BUILD.tpl @@ -2,6 +2,7 @@ # Update cuda_configure.bzl#verify_build_defines when adding new variables. load(":cc_toolchain_config.bzl", "cc_toolchain_config") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") licenses(["restricted"]) @@ -133,9 +134,17 @@ filegroup( srcs = [], ) +filegroup( + name = "cuda_nvcc_files", + srcs = %{cuda_nvcc_files}, +) + filegroup( name = "crosstool_wrapper_driver_is_not_gcc", - srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], + srcs = [ + ":cuda_nvcc_files", + ":clang/bin/crosstool_wrapper_driver_is_not_gcc" + ], ) filegroup( diff --git a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl index c46e09484fdfad..eb3a1d8c8ddf02 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl @@ -181,6 +181,9 @@ def InvokeNvcc(argv, log=False): nvccopts += ['--keep', '--keep-dir', tempdir] # Force C++17 dialect (note, everything in just one string!) nvccopts += ['--std c++17'] + # This is so that nvcc does not complain about MSVC or CLANG. + nvccopts += ['-allow-unsupported-compiler'] + nvccopts += ['--expt-extended-lambda', '--expt-relaxed-constexpr'] if log: Log([NVCC_PATH] + nvccopts) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl index 0b85e59231a374..094431dcedfc12 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.tpl @@ -1,6 +1,10 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Please use `hermetic/cuda_configure` instead. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag", "bool_setting") licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like @@ -144,7 +148,6 @@ cc_library( name = "cusolver", srcs = ["cuda/lib/%{cusolver_lib}"], data = ["cuda/lib/%{cusolver_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -220,7 +223,6 @@ cc_library( name = "cusparse", srcs = ["cuda/lib/%{cusparse_lib}"], data = ["cuda/lib/%{cusparse_lib}"], - linkopts = ["-lgomp"], linkstatic = 1, ) @@ -242,6 +244,41 @@ py_library( srcs = ["cuda/cuda_config.py"], ) +# Build setting that is always true (i.e. it can not be changed on the +# command line). It is used to create the config settings below that are +# always or never satisfied. +bool_setting( + name = "true_setting", + visibility = ["//visibility:private"], + build_setting_default = True, +) + +# Config settings whether TensorFlow is built with hermetic CUDA. +# These configs are never satisfied. +config_setting( + name = "hermetic_cuda_tools", + flag_values = {":true_setting": "False"}, +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":true_setting": "False"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + %{copy_rules} cc_library( @@ -249,3 +286,9 @@ cc_library( # to make bazel query happy. name = "nvptxcompiler", ) + +cc_library( + # This is not yet fully supported, but we need the rule + # to make bazel query happy. + name = "nvjitlink", +) \ No newline at end of file diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl index dee0e898d9ae7a..6b25c8398a7144 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/BUILD.windows.tpl @@ -1,3 +1,7 @@ +# NB: DEPRECATED! This file is a part of the deprecated `cuda_configure` rule. +# Hermetic CUDA repository rule doesn't support Windows. +# Please use `hermetic/cuda_configure`. + load(":build_defs.bzl", "cuda_header_library") load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@bazel_skylib//lib:selects.bzl", "selects") diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl index bc865cecb3240a..d1c50ea6377b9e 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/build_defs.bzl.tpl @@ -104,9 +104,16 @@ def if_cuda_newer_than(wanted_ver, if_true, if_false = []): wanted_major = int(wanted_ver.split('_')[0]) wanted_minor = int(wanted_ver.split('_')[1]) - configured_version = "%{cuda_version}" - configured_major = int(configured_version.split('.')[0]) - configured_minor = int(configured_version.split('.')[1]) + # Strip "64_" which appears in the CUDA version on Windows. + configured_version = "%{cuda_version}".rsplit("_", 1)[-1] + configured_version_parts = configured_version.split('.') + + # On Windows, the major and minor versions are concatenated without a period and the minor only contains one digit. + if len(configured_version_parts) == 1: + configured_version_parts = [configured_version[0:-1], configured_version[-1:]] + + configured_major = int(configured_version_parts[0]) + configured_minor = int(configured_version_parts[1]) if %{cuda_is_configured} and (wanted_major, wanted_minor) <= (configured_major, configured_minor): return select({"//conditions:default": if_true}) @@ -142,9 +149,13 @@ def cuda_header_library( **kwargs ) -def cuda_library(copts = [], **kwargs): +def cuda_library(copts = [], tags = [],**kwargs): """Wrapper over cc_library which adds default CUDA options.""" - native.cc_library(copts = cuda_default_copts() + copts, **kwargs) + native.cc_library( + copts = cuda_default_copts() + copts, + tags = tags + ["gpu"], + **kwargs + ) def cuda_cc_test(copts = [], **kwargs): """Wrapper over cc_test which adds default CUDA options.""" diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl new file mode 100644 index 00000000000000..ccf1b9a030d5ad --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/BUILD.tpl @@ -0,0 +1,266 @@ +load("@bazel_skylib//:bzl_library.bzl", "bzl_library") +load("@bazel_skylib//lib:selects.bzl", "selects") +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +# Config setting whether TensorFlow is built with CUDA support using clang. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_clang. +selects.config_setting_group( + name = "using_clang", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_clang", + ], +) + +# Config setting whether TensorFlow is built with CUDA support using nvcc. +# +# TODO(b/174244321), DEPRECATED: this target will be removed when all users +# have been converted to :is_cuda_enabled (most) or :is_cuda_compiler_nvcc. +selects.config_setting_group( + name = "using_nvcc", + match_all = [ + "@local_config_cuda//:is_cuda_enabled", + "@local_config_cuda//:is_cuda_compiler_nvcc", + ], +) + +# Equivalent to using_clang && -c opt. +selects.config_setting_group( + name = "using_clang_opt", + match_all = [ + ":using_clang", + ":_opt", + ], +) + +config_setting( + name = "_opt", + values = {"compilation_mode": "opt"}, +) + +# Provides CUDA headers for '#include "third_party/gpus/cuda/include/cuda.h"' +# All clients including TensorFlow should use these directives. +cc_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ], + include_prefix = "third_party/gpus", + includes = [ + ".", # required to include cuda/cuda/cuda_config.h as cuda/config.h + ], + deps = [":cudart_headers", + ":cublas_headers", + ":cccl_headers", + ":nvtx_headers", + ":nvcc_headers", + ":cusolver_headers", + ":cufft_headers", + ":cusparse_headers", + ":curand_headers", + ":cupti_headers", + ":nvml_headers"], +) + +cc_library( + name = "cudart_static", + srcs = ["@cuda_cudart//:static"], + linkopts = [ + "-ldl", + "-lpthread", + %{cudart_static_linkopt} + ], +) + +alias( + name = "cuda_driver", + actual = "@cuda_cudart//:cuda_driver", +) + +alias( + name = "cudart_headers", + actual = "@cuda_cudart//:headers", +) + +alias( + name = "cudart", + actual = "@cuda_cudart//:cudart", +) + +alias( + name = "nvtx_headers", + actual = "@cuda_nvtx//:headers", +) + +alias( + name = "nvml_headers", + actual = "@cuda_nvml//:headers", +) + +alias( + name = "nvcc_headers", + actual = "@cuda_nvcc//:headers", +) + +alias( + name = "cccl_headers", + actual = "@cuda_cccl//:headers", +) + +alias( + name = "cublas_headers", + actual = "@cuda_cublas//:headers", +) + +alias( + name = "cusolver_headers", + actual = "@cuda_cusolver//:headers", +) + +alias( + name = "cufft_headers", + actual = "@cuda_cufft//:headers", +) + +alias( + name = "cusparse_headers", + actual = "@cuda_cusparse//:headers", +) + +alias( + name = "curand_headers", + actual = "@cuda_curand//:headers", +) + +alias( + name = "cublas", + actual = "@cuda_cublas//:cublas", +) + +alias( + name = "cublasLt", + actual = "@cuda_cublas//:cublasLt", +) + +alias( + name = "cusolver", + actual = "@cuda_cusolver//:cusolver", +) + +alias( + name = "cudnn", + actual = "@cuda_cudnn//:cudnn", +) + +alias( + name = "cudnn_header", + actual = "@cuda_cudnn//:headers", +) + +alias( + name = "cufft", + actual = "@cuda_cufft//:cufft", +) + +alias( + name = "curand", + actual = "@cuda_curand//:curand", +) + +cc_library( + name = "cuda", + deps = [ + ":cublas", + ":cublasLt", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +alias( + name = "cub_headers", + actual = ":cuda_headers", +) + +alias( + name = "cupti_headers", + actual = "@cuda_cupti//:headers", +) + +alias( + name = "cupti_dsos", + actual = "@cuda_cupti//:cupti", +) + +alias( + name = "cusparse", + actual = "@cuda_cusparse//:cusparse", +) + +alias( + name = "cuda-nvvm", + actual = "@cuda_nvcc//:nvvm", +) + +alias( + name = "nvjitlink", + actual = "@cuda_nvjitlink//:nvjitlink" +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], +) + +bzl_library( + name = "build_defs_bzl", + srcs = ["build_defs.bzl"], + deps = [ + "@bazel_skylib//lib:selects", + ], +) + +py_library( + name = "cuda_config_py", + srcs = ["cuda/cuda_config.py"], +) + +# Config setting whether TensorFlow is built with hermetic CUDA. +alias( + name = "hermetic_cuda_tools", + actual = "@local_config_cuda//:is_cuda_enabled", +) + +# Flag indicating if we should include hermetic CUDA libs. +bool_flag( + name = "include_hermetic_cuda_libs", + build_setting_default = False, +) + +config_setting( + name = "hermetic_cuda_libs", + flag_values = {":include_hermetic_cuda_libs": "True"}, +) + +selects.config_setting_group( + name = "hermetic_cuda_tools_and_libs", + match_all = [ + ":hermetic_cuda_libs", + ":hermetic_cuda_tools" + ], +) + +cc_library( + # This is not yet fully supported, but we need the rule + # to make bazel query happy. + name = "nvptxcompiler", +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl new file mode 100644 index 00000000000000..85c0cbbb196fef --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cccl.BUILD.tpl @@ -0,0 +1,15 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + hdrs = glob([ + %{comment}"include/cub/**", + %{comment}"include/cuda/**", + %{comment}"include/nv/**", + %{comment}"include/thrust/**", + ]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl new file mode 100644 index 00000000000000..270b73c3884855 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_configure.bzl @@ -0,0 +1,521 @@ +"""Repository rule for hermetic CUDA autoconfiguration. + +`cuda_configure` depends on the following environment variables: + + * `TF_NEED_CUDA`: Whether to enable building with CUDA. + * `TF_NVCC_CLANG`: Whether to use clang for C++ and NVCC for Cuda compilation. + * `CLANG_CUDA_COMPILER_PATH`: The clang compiler path that will be used for + both host and device code compilation. + * `TF_SYSROOT`: The sysroot to use when compiling. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + * `HERMETIC_CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default + is `3.5,5.2`. If not specified, the value will be determined by the + `TF_CUDA_COMPUTE_CAPABILITIES`. + * `PYTHON_BIN_PATH`: The python binary path +""" + +load( + "//third_party/gpus:compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", + "which", +) + +def _find_cc(repository_ctx): + """Find the C++ compiler.""" + cc_path_envvar = _CLANG_CUDA_COMPILER_PATH + cc_name = "clang" + + cc_name_from_env = get_host_environ(repository_ctx, cc_path_envvar) + if cc_name_from_env: + cc_name = cc_name_from_env + if cc_name.startswith("/"): + # Return the absolute path. + return cc_name + cc = which(repository_ctx, cc_name) + if cc == None: + fail(("Cannot find {}, either correct your path or set the {}" + + " environment variable").format(cc_name, cc_path_envvar)) + return cc + +def _auto_configure_fail(msg): + """Output failure message when cuda configuration fails.""" + red = "\033[0;31m" + no_color = "\033[0m" + fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg)) + +def _verify_build_defines(params): + """Verify all variables that crosstool/BUILD.tpl expects are substituted. + + Args: + params: dict of variables that will be passed to the BUILD.tpl template. + """ + missing = [] + for param in [ + "cxx_builtin_include_directories", + "extra_no_canonical_prefixes_flags", + "host_compiler_path", + "host_compiler_prefix", + "host_compiler_warnings", + "linker_bin_path", + "compiler_deps", + "msvc_cl_path", + "msvc_env_include", + "msvc_env_lib", + "msvc_env_path", + "msvc_env_tmp", + "msvc_lib_path", + "msvc_link_path", + "msvc_ml_path", + "unfiltered_compile_flags", + "win_compiler_deps", + ]: + if ("%{" + param + "}") not in params: + missing.append(param) + + if missing: + _auto_configure_fail( + "BUILD.tpl template is missing these variables: " + + str(missing) + + ".\nWe only got: " + + str(params) + + ".", + ) + +def get_cuda_version(repository_ctx): + return (get_host_environ(repository_ctx, HERMETIC_CUDA_VERSION) or + get_host_environ(repository_ctx, TF_CUDA_VERSION)) + +def enable_cuda(repository_ctx): + """Returns whether to build with CUDA support.""" + return int(get_host_environ(repository_ctx, TF_NEED_CUDA, False)) + +def _flag_enabled(repository_ctx, flag_name): + return get_host_environ(repository_ctx, flag_name) == "1" + +def _use_nvcc_and_clang(repository_ctx): + # Returns the flag if we need to use clang for C++ and NVCC for Cuda. + return _flag_enabled(repository_ctx, _TF_NVCC_CLANG) + +def _tf_sysroot(repository_ctx): + return get_host_environ(repository_ctx, _TF_SYSROOT, "") + +def _py_tmpl_dict(d): + return {"%{cuda_config}": str(d)} + +def _cudart_static_linkopt(cpu_value): + """Returns additional platform-specific linkopts for cudart.""" + return "\"\"," if cpu_value == "Darwin" else "\"-lrt\"," + +def _compute_capabilities(repository_ctx): + """Returns a list of strings representing cuda compute capabilities. + + Args: + repository_ctx: the repo rule's context. + + Returns: + list of cuda architectures to compile for. 'compute_xy' refers to + both PTX and SASS, 'sm_xy' refers to SASS only. + """ + capabilities = (get_host_environ( + repository_ctx, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + ) or + get_host_environ( + repository_ctx, + _TF_CUDA_COMPUTE_CAPABILITIES, + )) + capabilities = (capabilities or "compute_35,compute_52").split(",") + + # Map old 'x.y' capabilities to 'compute_xy'. + if len(capabilities) > 0 and all([len(x.split(".")) == 2 for x in capabilities]): + # If all capabilities are in 'x.y' format, only include PTX for the + # highest capability. + cc_list = sorted([x.replace(".", "") for x in capabilities]) + capabilities = [ + "sm_%s" % x + for x in cc_list[:-1] + ] + ["compute_%s" % cc_list[-1]] + for i, capability in enumerate(capabilities): + parts = capability.split(".") + if len(parts) != 2: + continue + capabilities[i] = "compute_%s%s" % (parts[0], parts[1]) + + # Make list unique + capabilities = dict(zip(capabilities, capabilities)).keys() + + # Validate capabilities. + for capability in capabilities: + if not capability.startswith(("compute_", "sm_")): + _auto_configure_fail("Invalid compute capability: %s" % capability) + for prefix in ["compute_", "sm_"]: + if not capability.startswith(prefix): + continue + if len(capability) == len(prefix) + 2 and capability[-2:].isdigit(): + continue + if len(capability) == len(prefix) + 3 and capability.endswith("90a"): + continue + _auto_configure_fail("Invalid compute capability: %s" % capability) + + return capabilities + +def _compute_cuda_extra_copts(compute_capabilities): + copts = ["--no-cuda-include-ptx=all"] + for capability in compute_capabilities: + if capability.startswith("compute_"): + capability = capability.replace("compute_", "sm_") + copts.append("--cuda-include-ptx=%s" % capability) + copts.append("--cuda-gpu-arch=%s" % capability) + + return str(copts) + +def _get_cuda_config(repository_ctx): + """Detects and returns information about the CUDA installation on the system. + + Args: + repository_ctx: The repository context. + + Returns: + A struct containing the following fields: + cuda_version: The version of CUDA on the system. + cudart_version: The CUDA runtime version on the system. + cudnn_version: The version of cuDNN on the system. + compute_capabilities: A list of the system's CUDA compute capabilities. + cpu_value: The name of the host operating system. + """ + + return struct( + cuda_version = get_cuda_version(repository_ctx), + cupti_version = repository_ctx.read(repository_ctx.attr.cupti_version), + cudart_version = repository_ctx.read(repository_ctx.attr.cudart_version), + cublas_version = repository_ctx.read(repository_ctx.attr.cublas_version), + cusolver_version = repository_ctx.read(repository_ctx.attr.cusolver_version), + curand_version = repository_ctx.read(repository_ctx.attr.curand_version), + cufft_version = repository_ctx.read(repository_ctx.attr.cufft_version), + cusparse_version = repository_ctx.read(repository_ctx.attr.cusparse_version), + cudnn_version = repository_ctx.read(repository_ctx.attr.cudnn_version), + compute_capabilities = _compute_capabilities(repository_ctx), + cpu_value = get_cpu_value(repository_ctx), + ) + +_DUMMY_CROSSTOOL_BZL_FILE = """ +def error_gpu_disabled(): + fail("ERROR: Building with --config=cuda but TensorFlow is not configured " + + "to build with GPU support. Please re-run ./configure and enter 'Y' " + + "at the prompt to build with GPU support.") + + native.genrule( + name = "error_gen_crosstool", + outs = ["CROSSTOOL"], + cmd = "echo 'Should not be run.' && exit 1", + ) + + native.filegroup( + name = "crosstool", + srcs = [":CROSSTOOL"], + output_licenses = ["unencumbered"], + ) +""" + +_DUMMY_CROSSTOOL_BUILD_FILE = """ +load("//crosstool:error_gpu_disabled.bzl", "error_gpu_disabled") + +error_gpu_disabled() +""" + +def _create_dummy_repository(repository_ctx): + cpu_value = get_cpu_value(repository_ctx) + + # Set up BUILD file for cuda/. + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "False", + "%{cuda_extra_copts}": "[]", + "%{cuda_gpu_architectures}": "[]", + "%{cuda_version}": "0.0", + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value), + }, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": "", + "%{cudart_version}": "", + "%{cupti_version}": "", + "%{cublas_version}": "", + "%{cusolver_version}": "", + "%{curand_version}": "", + "%{cufft_version}": "", + "%{cusparse_version}": "", + "%{cudnn_version}": "", + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": "", + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({}), + ) + + # If cuda_configure is not configured to build with GPU support, and the user + # attempts to build with --config=cuda, add a dummy build rule to intercept + # this and fail with an actionable error message. + repository_ctx.file( + "crosstool/error_gpu_disabled.bzl", + _DUMMY_CROSSTOOL_BZL_FILE, + ) + repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE) + +def _create_local_cuda_repository(repository_ctx): + """Creates the repository containing files set up to build with CUDA.""" + cuda_config = _get_cuda_config(repository_ctx) + + # Set up BUILD file for cuda/ + repository_ctx.template( + "cuda/build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + cuda_config.compute_capabilities, + ), + "%{cuda_gpu_architectures}": str(cuda_config.compute_capabilities), + "%{cuda_version}": cuda_config.cuda_version, + }, + ) + + repository_ctx.template( + "cuda/BUILD", + repository_ctx.attr.cuda_build_tpl, + { + "%{cudart_static_linkopt}": _cudart_static_linkopt( + cuda_config.cpu_value, + ), + }, + ) + + is_nvcc_and_clang = _use_nvcc_and_clang(repository_ctx) + tf_sysroot = _tf_sysroot(repository_ctx) + + # Set up crosstool/ + cc = _find_cc(repository_ctx) + host_compiler_includes = get_cxx_inc_directories( + repository_ctx, + cc, + tf_sysroot, + ) + + cuda_defines = {} + + # We do not support hermetic CUDA on Windows. + # This ensures the CROSSTOOL file parser is happy. + cuda_defines.update({ + "%{msvc_env_tmp}": "msvc_not_used", + "%{msvc_env_path}": "msvc_not_used", + "%{msvc_env_include}": "msvc_not_used", + "%{msvc_env_lib}": "msvc_not_used", + "%{msvc_cl_path}": "msvc_not_used", + "%{msvc_ml_path}": "msvc_not_used", + "%{msvc_link_path}": "msvc_not_used", + "%{msvc_lib_path}": "msvc_not_used", + "%{win_compiler_deps}": ":empty", + }) + + cuda_defines["%{builtin_sysroot}"] = tf_sysroot + cuda_defines["%{cuda_toolkit_path}"] = repository_ctx.attr.nvcc_binary.workspace_root + cuda_defines["%{compiler}"] = "clang" + cuda_defines["%{host_compiler_prefix}"] = "/usr/bin" + cuda_defines["%{linker_bin_path}"] = "" + cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" + cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cxx_builtin_include_directories}"] = to_list_of_strings( + host_compiler_includes, + ) + cuda_defines["%{cuda_nvcc_files}"] = "if_cuda([\"@{nvcc_archive}//:bin\", \"@{nvcc_archive}//:nvvm\"])".format( + nvcc_archive = repository_ctx.attr.nvcc_binary.repo_name, + ) + + if not is_nvcc_and_clang: + cuda_defines["%{host_compiler_path}"] = str(cc) + cuda_defines["%{host_compiler_warnings}"] = """ + # Some parts of the codebase set -Werror and hit this warning, so + # switch it off for now. + "-Wno-invalid-partial-specialization" + """ + cuda_defines["%{compiler_deps}"] = ":cuda_nvcc_files" + repository_ctx.file( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + "", + ) + else: + cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc" + cuda_defines["%{host_compiler_warnings}"] = "" + + nvcc_relative_path = "%s/%s" % ( + repository_ctx.attr.nvcc_binary.workspace_root, + repository_ctx.attr.nvcc_binary.name, + ) + cuda_defines["%{compiler_deps}"] = ":crosstool_wrapper_driver_is_not_gcc" + + wrapper_defines = { + "%{cpu_compiler}": str(cc), + "%{cuda_version}": cuda_config.cuda_version, + "%{nvcc_path}": nvcc_relative_path, + "%{host_compiler_path}": str(cc), + "%{use_clang_compiler}": "True", + } + repository_ctx.template( + "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", + repository_ctx.attr.crosstool_wrapper_driver_is_not_gcc_tpl, + wrapper_defines, + ) + + _verify_build_defines(cuda_defines) + + # Only expand template variables in the BUILD file + repository_ctx.template( + "crosstool/BUILD", + repository_ctx.attr.crosstool_build_tpl, + cuda_defines, + ) + + # No templating of cc_toolchain_config - use attributes and templatize the + # BUILD file. + repository_ctx.template( + "crosstool/cc_toolchain_config.bzl", + repository_ctx.attr.cc_toolchain_config_tpl, + {}, + ) + + # Set up cuda_config.h, which is used by + # tensorflow/compiler/xla/stream_executor/dso_loader.cc. + repository_ctx.template( + "cuda/cuda/cuda_config.h", + repository_ctx.attr.cuda_config_tpl, + { + "%{cuda_version}": cuda_config.cuda_version, + "%{cudart_version}": cuda_config.cudart_version, + "%{cupti_version}": cuda_config.cupti_version, + "%{cublas_version}": cuda_config.cublas_version, + "%{cusolver_version}": cuda_config.cusolver_version, + "%{curand_version}": cuda_config.curand_version, + "%{cufft_version}": cuda_config.cufft_version, + "%{cusparse_version}": cuda_config.cusparse_version, + "%{cudnn_version}": cuda_config.cudnn_version, + "%{cuda_toolkit_path}": "", + "%{cuda_compute_capabilities}": ", ".join([ + cc.split("_")[1] + for cc in cuda_config.compute_capabilities + ]), + }, + ) + + # Set up cuda_config.py, which is used by gen_build_info to provide + # static build environment info to the API + repository_ctx.template( + "cuda/cuda/cuda_config.py", + repository_ctx.attr.cuda_config_py_tpl, + _py_tmpl_dict({ + "cuda_version": cuda_config.cuda_version, + "cudnn_version": cuda_config.cudnn_version, + "cuda_compute_capabilities": cuda_config.compute_capabilities, + "cpu_compiler": str(cc), + }), + ) + +def _cuda_autoconf_impl(repository_ctx): + """Implementation of the cuda_autoconf repository rule.""" + build_file = repository_ctx.attr.local_config_cuda_build_file + + if not enable_cuda(repository_ctx): + _create_dummy_repository(repository_ctx) + else: + _create_local_cuda_repository(repository_ctx) + + repository_ctx.symlink(build_file, "BUILD") + +_CLANG_CUDA_COMPILER_PATH = "CLANG_CUDA_COMPILER_PATH" +_PYTHON_BIN_PATH = "PYTHON_BIN_PATH" +_HERMETIC_CUDA_COMPUTE_CAPABILITIES = "HERMETIC_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +HERMETIC_CUDA_VERSION = "HERMETIC_CUDA_VERSION" +TF_CUDA_VERSION = "TF_CUDA_VERSION" +TF_NEED_CUDA = "TF_NEED_CUDA" +_TF_NVCC_CLANG = "TF_NVCC_CLANG" +_TF_SYSROOT = "TF_SYSROOT" + +_ENVIRONS = [ + _CLANG_CUDA_COMPILER_PATH, + TF_NEED_CUDA, + _TF_NVCC_CLANG, + TF_CUDA_VERSION, + HERMETIC_CUDA_VERSION, + _TF_CUDA_COMPUTE_CAPABILITIES, + _HERMETIC_CUDA_COMPUTE_CAPABILITIES, + _TF_SYSROOT, + _PYTHON_BIN_PATH, + "TMP", + "TMPDIR", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", +] + +cuda_configure = repository_rule( + implementation = _cuda_autoconf_impl, + environ = _ENVIRONS, + attrs = { + "environ": attr.string_dict(), + "cublas_version": attr.label(default = Label("@cuda_cublas//:version.txt")), + "cudart_version": attr.label(default = Label("@cuda_cudart//:version.txt")), + "cudnn_version": attr.label(default = Label("@cuda_cudnn//:version.txt")), + "cufft_version": attr.label(default = Label("@cuda_cufft//:version.txt")), + "cupti_version": attr.label(default = Label("@cuda_cupti//:version.txt")), + "curand_version": attr.label(default = Label("@cuda_curand//:version.txt")), + "cusolver_version": attr.label(default = Label("@cuda_cusolver//:version.txt")), + "cusparse_version": attr.label(default = Label("@cuda_cusparse//:version.txt")), + "nvcc_binary": attr.label(default = Label("@cuda_nvcc//:bin/nvcc")), + "local_config_cuda_build_file": attr.label(default = Label("//third_party/gpus:local_config_cuda.BUILD")), + "build_defs_tpl": attr.label(default = Label("//third_party/gpus/cuda:build_defs.bzl.tpl")), + "cuda_build_tpl": attr.label(default = Label("//third_party/gpus/cuda/hermetic:BUILD.tpl")), + "cuda_config_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.h.tpl")), + "cuda_config_py_tpl": attr.label(default = Label("//third_party/gpus/cuda:cuda_config.py.tpl")), + "crosstool_wrapper_driver_is_not_gcc_tpl": attr.label(default = Label("//third_party/gpus/crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl")), + "crosstool_build_tpl": attr.label(default = Label("//third_party/gpus/crosstool:BUILD.tpl")), + "cc_toolchain_config_tpl": attr.label(default = Label("//third_party/gpus/crosstool:cc_toolchain_config.bzl.tpl")), + }, +) +"""Detects and configures the hermetic CUDA toolchain. + +Add the following to your WORKSPACE file: + +```python +cuda_configure(name = "local_config_cuda") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl new file mode 100644 index 00000000000000..510235d801de4e --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cublas.BUILD.tpl @@ -0,0 +1,44 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cublas_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublas.so.%{libcublas_version}", + deps = [":cublasLt"], +) + +cc_import( + name = "cublasLt_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcublasLt.so.%{libcublaslt_version}", +) +%{multiline_comment} +cc_library( + name = "cublas", + visibility = ["//visibility:public"], + %{comment}deps = [":cublas_shared_library"], +) + +cc_library( + name = "cublasLt", + visibility = ["//visibility:public"], + %{comment}deps = [":cublasLt_shared_library"], +) + +cc_library( + name = "headers", + %{comment}hdrs = [ + %{comment}"include/cublas.h", + %{comment}"include/cublasLt.h", + %{comment}"include/cublas_api.h", + %{comment}"include/cublas_v2.h", + %{comment}], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl new file mode 100644 index 00000000000000..f7ba469b42b76a --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudart.BUILD.tpl @@ -0,0 +1,126 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) + +filegroup( + name = "static", + srcs = ["lib/libcudart_static.a"], + visibility = ["@local_config_cuda//cuda:__pkg__"], +) +%{multiline_comment} +# TODO: Replace system provided library with hermetic NVIDIA driver library. +cc_import( + name = "cuda_driver_shared_library", + interface_library = "lib/stubs/libcuda.so", + system_provided = 1, +) + +cc_import( + name = "cudart_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcudart.so.%{libcudart_version}", +) +%{multiline_comment} +cc_library( + name = "cuda_driver", + %{comment}deps = [":cuda_driver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart", + %{comment}deps = [ + %{comment}":cuda_driver", + %{comment}":cudart_shared_library", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/builtin_types.h", + %{comment}"include/channel_descriptor.h", + %{comment}"include/common_functions.h", + %{comment}"include/cooperative_groups/**", + %{comment}"include/cooperative_groups.h", + %{comment}"include/cuComplex.h", + %{comment}"include/cuda.h", + %{comment}"include/cudaEGL.h", + %{comment}"include/cudaEGLTypedefs.h", + %{comment}"include/cudaGL.h", + %{comment}"include/cudaGLTypedefs.h", + %{comment}"include/cudaProfilerTypedefs.h", + %{comment}"include/cudaTypedefs.h", + %{comment}"include/cudaVDPAU.h", + %{comment}"include/cudaVDPAUTypedefs.h", + %{comment}"include/cuda_awbarrier.h", + %{comment}"include/cuda_awbarrier_helpers.h", + %{comment}"include/cuda_awbarrier_primitives.h", + %{comment}"include/cuda_bf16.h", + %{comment}"include/cuda_bf16.hpp", + %{comment}"include/cuda_device_runtime_api.h", + %{comment}"include/cuda_egl_interop.h", + %{comment}"include/cuda_fp16.h", + %{comment}"include/cuda_fp16.hpp", + %{comment}"include/cuda_fp8.h", + %{comment}"include/cuda_fp8.hpp", + %{comment}"include/cuda_gl_interop.h", + %{comment}"include/cuda_occupancy.h", + %{comment}"include/cuda_pipeline.h", + %{comment}"include/cuda_pipeline_helpers.h", + %{comment}"include/cuda_pipeline_primitives.h", + %{comment}"include/cuda_runtime.h", + %{comment}"include/cuda_runtime_api.h", + %{comment}"include/cuda_surface_types.h", + %{comment}"include/cuda_texture_types.h", + %{comment}"include/cuda_vdpau_interop.h", + %{comment}"include/cudart_platform.h", + %{comment}"include/device_atomic_functions.h", + %{comment}"include/device_atomic_functions.hpp", + %{comment}"include/device_double_functions.h", + %{comment}"include/device_functions.h", + %{comment}"include/device_launch_parameters.h", + %{comment}"include/device_types.h", + %{comment}"include/driver_functions.h", + %{comment}"include/driver_types.h", + %{comment}"include/host_config.h", + %{comment}"include/host_defines.h", + %{comment}"include/library_types.h", + %{comment}"include/math_constants.h", + %{comment}"include/math_functions.h", + %{comment}"include/mma.h", + %{comment}"include/nvfunctional", + %{comment}"include/sm_20_atomic_functions.h", + %{comment}"include/sm_20_atomic_functions.hpp", + %{comment}"include/sm_20_intrinsics.h", + %{comment}"include/sm_20_intrinsics.hpp", + %{comment}"include/sm_30_intrinsics.h", + %{comment}"include/sm_30_intrinsics.hpp", + %{comment}"include/sm_32_atomic_functions.h", + %{comment}"include/sm_32_atomic_functions.hpp", + %{comment}"include/sm_32_intrinsics.h", + %{comment}"include/sm_32_intrinsics.hpp", + %{comment}"include/sm_35_atomic_functions.h", + %{comment}"include/sm_35_intrinsics.h", + %{comment}"include/sm_60_atomic_functions.h", + %{comment}"include/sm_60_atomic_functions.hpp", + %{comment}"include/sm_61_intrinsics.h", + %{comment}"include/sm_61_intrinsics.hpp", + %{comment}"include/surface_functions.h", + %{comment}"include/surface_indirect_functions.h", + %{comment}"include/surface_types.h", + %{comment}"include/texture_fetch_functions.h", + %{comment}"include/texture_indirect_functions.h", + %{comment}"include/texture_types.h", + %{comment}"include/vector_functions.h", + %{comment}"include/vector_functions.hpp", + %{comment}"include/vector_types.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl new file mode 100644 index 00000000000000..165c5b1579e73f --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn.BUILD.tpl @@ -0,0 +1,73 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_infer.so.%{libcudnn_ops_infer_version}", +) + +cc_import( + name = "cudnn_cnn_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_infer.so.%{libcudnn_cnn_infer_version}", +) + +cc_import( + name = "cudnn_ops_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops_train.so.%{libcudnn_ops_train_version}", +) + +cc_import( + name = "cudnn_cnn_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn_train.so.%{libcudnn_cnn_train_version}", +) + +cc_import( + name = "cudnn_adv_infer", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_infer.so.%{libcudnn_adv_infer_version}", +) + +cc_import( + name = "cudnn_adv_train", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv_train.so.%{libcudnn_adv_train_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_ops_infer", + %{comment}":cudnn_ops_train", + %{comment}":cudnn_cnn_infer", + %{comment}":cudnn_cnn_train", + %{comment}":cudnn_adv_infer", + %{comment}":cudnn_adv_train", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl new file mode 100644 index 00000000000000..7f36054a51bb5b --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cudnn9.BUILD.tpl @@ -0,0 +1,80 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cudnn_ops", + hdrs = [":headers"], + shared_library = "lib/libcudnn_ops.so.%{libcudnn_ops_version}", +) + +cc_import( + name = "cudnn_cnn", + hdrs = [":headers"], + shared_library = "lib/libcudnn_cnn.so.%{libcudnn_cnn_version}", +) + +cc_import( + name = "cudnn_adv", + hdrs = [":headers"], + shared_library = "lib/libcudnn_adv.so.%{libcudnn_adv_version}", +) + +cc_import( + name = "cudnn_graph", + hdrs = [":headers"], + shared_library = "lib/libcudnn_graph.so.%{libcudnn_graph_version}", +) + +cc_import( + name = "cudnn_engines_precompiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_precompiled.so.%{libcudnn_engines_precompiled_version}", +) + +cc_import( + name = "cudnn_engines_runtime_compiled", + hdrs = [":headers"], + shared_library = "lib/libcudnn_engines_runtime_compiled.so.%{libcudnn_engines_runtime_compiled_version}", +) + +cc_import( + name = "cudnn_heuristic", + hdrs = [":headers"], + shared_library = "lib/libcudnn_heuristic.so.%{libcudnn_heuristic_version}", +) + +cc_import( + name = "cudnn_main", + hdrs = [":headers"], + shared_library = "lib/libcudnn.so.%{libcudnn_version}", +) +%{multiline_comment} +cc_library( + name = "cudnn", + %{comment}deps = [ + %{comment}":cudnn_engines_precompiled", + %{comment}":cudnn_ops", + %{comment}":cudnn_graph", + %{comment}":cudnn_cnn", + %{comment}":cudnn_adv", + %{comment}":cudnn_engines_runtime_compiled", + %{comment}":cudnn_heuristic", + %{comment}"@cuda_nvrtc//:nvrtc", + %{comment}":cudnn_main", + %{comment}], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudnn*.h", + %{comment}]), + include_prefix = "third_party/gpus/cudnn", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl new file mode 100644 index 00000000000000..48ccb0ea3cd197 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cufft.BUILD.tpl @@ -0,0 +1,29 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cufft_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcufft.so.%{libcufft_version}", +) +%{multiline_comment} +cc_library( + name = "cufft", + %{comment}deps = [":cufft_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cudalibxt.h", + %{comment}"include/cufft*.h" + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl new file mode 100644 index 00000000000000..3efe76f470953f --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cupti.BUILD.tpl @@ -0,0 +1,59 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cupti_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcupti.so.%{libcupti_version}", +) +%{multiline_comment} +cc_library( + name = "cupti", + %{comment}deps = [":cupti_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/Openacc/**", + %{comment}"include/Openmp/**", + %{comment}"include/cuda_stdint.h", + %{comment}"include/cupti.h", + %{comment}"include/cupti_activity.h", + %{comment}"include/cupti_activity_deprecated.h", + %{comment}"include/cupti_callbacks.h", + %{comment}"include/cupti_checkpoint.h", + %{comment}"include/cupti_driver_cbid.h", + %{comment}"include/cupti_events.h", + %{comment}"include/cupti_metrics.h", + %{comment}"include/cupti_nvtx_cbid.h", + %{comment}"include/cupti_pcsampling.h", + %{comment}"include/cupti_pcsampling_util.h", + %{comment}"include/cupti_profiler_target.h", + %{comment}"include/cupti_result.h", + %{comment}"include/cupti_runtime_cbid.h", + %{comment}"include/cupti_sass_metrics.h", + %{comment}"include/cupti_target.h", + %{comment}"include/cupti_version.h", + %{comment}"include/generated_cudaGL_meta.h", + %{comment}"include/generated_cudaVDPAU_meta.h", + %{comment}"include/generated_cuda_gl_interop_meta.h", + %{comment}"include/generated_cuda_meta.h", + %{comment}"include/generated_cuda_runtime_api_meta.h", + %{comment}"include/generated_cuda_vdpau_interop_meta.h", + %{comment}"include/generated_cudart_removed_meta.h", + %{comment}"include/generated_nvtx_meta.h", + %{comment}"include/nvperf_common.h", + %{comment}"include/nvperf_cuda_host.h", + %{comment}"include/nvperf_host.h", + %{comment}"include/nvperf_target.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/extras/CUPTI/include", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl new file mode 100644 index 00000000000000..50e5a8f18a96fd --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_curand.BUILD.tpl @@ -0,0 +1,26 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "curand_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcurand.so.%{libcurand_version}", +) +%{multiline_comment} +cc_library( + name = "curand", + %{comment}deps = [":curand_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob(["include/curand*.h"]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl new file mode 100644 index 00000000000000..943a08ebeb96e1 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusolver.BUILD.tpl @@ -0,0 +1,34 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusolver_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusolver.so.%{libcusolver_version}", + deps = [ + "@cuda_nvjitlink//:nvjitlink", + "@cuda_cusparse//:cusparse", + "@cuda_cublas//:cublas", + "@cuda_cublas//:cublasLt", + ], +) +%{multiline_comment} +cc_library( + name = "cusolver", + %{comment}deps = [":cusolver_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/cusolver*.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl new file mode 100644 index 00000000000000..46b24366ce1c04 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_cusparse.BUILD.tpl @@ -0,0 +1,27 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "cusparse_shared_library", + hdrs = [":headers"], + shared_library = "lib/libcusparse.so.%{libcusparse_version}", + deps = ["@cuda_nvjitlink//:nvjitlink"], +) +%{multiline_comment} +cc_library( + name = "cusparse", + %{comment}deps = [":cusparse_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = ["include/cusparse.h"], + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl new file mode 100644 index 00000000000000..fdda3aaf92cea5 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_json_init_repository.bzl @@ -0,0 +1,125 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistributions JSON repository initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_JSON_DICT", + "CUDNN_REDIST_JSON_DICT", +) + +def _get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_json_file_content(repository_ctx, url_to_sha256, json_file_name): + if len(url_to_sha256) > 1: + (url, sha256) = url_to_sha256 + else: + url = url_to_sha256[0] + sha256 = "" + repository_ctx.download( + url = tf_mirror_urls(url), + sha256 = sha256, + output = json_file_name, + ) + return repository_ctx.read(repository_ctx.path(json_file_name)) + +def _cuda_redist_json_impl(repository_ctx): + cuda_version = (_get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + _get_env_var(repository_ctx, "TF_CUDA_VERSION")) + local_cuda_path = _get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + cudnn_version = (_get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + _get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + local_cudnn_path = _get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + supported_cuda_versions = repository_ctx.attr.cuda_json_dict.keys() + if (cuda_version and not local_cuda_path and + (cuda_version not in supported_cuda_versions)): + fail( + ("The supported CUDA versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add JSON URL for" + + " CUDA version={version}.") + .format( + supported_versions = supported_cuda_versions, + version = cuda_version, + ), + ) + supported_cudnn_versions = repository_ctx.attr.cudnn_json_dict.keys() + if cudnn_version and not local_cudnn_path and (cudnn_version not in supported_cudnn_versions): + fail( + ("The supported CUDNN versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDNN_VERSION" + + " environment variable or add JSON URL for" + + " CUDNN version={version}.") + .format( + supported_versions = supported_cudnn_versions, + version = cudnn_version, + ), + ) + cuda_redistributions = "{}" + cudnn_redistributions = "{}" + if cuda_version and not local_cuda_path: + cuda_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cuda_json_dict[cuda_version], + "redistrib_cuda_%s.json" % cuda_version, + ) + if cudnn_version and not local_cudnn_path: + cudnn_redistributions = _get_json_file_content( + repository_ctx, + repository_ctx.attr.cudnn_json_dict[cudnn_version], + "redistrib_cudnn_%s.json" % cudnn_version, + ) + + repository_ctx.file( + "distributions.bzl", + """CUDA_REDISTRIBUTIONS = {cuda_redistributions} + +CUDNN_REDISTRIBUTIONS = {cudnn_redistributions} +""".format( + cuda_redistributions = cuda_redistributions, + cudnn_redistributions = cudnn_redistributions, + ), + ) + repository_ctx.file( + "BUILD", + "", + ) + +cuda_redist_json = repository_rule( + implementation = _cuda_redist_json_impl, + attrs = { + "cuda_json_dict": attr.string_list_dict(mandatory = True), + "cudnn_json_dict": attr.string_list_dict(mandatory = True), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "HERMETIC_CUDNN_VERSION", + "TF_CUDA_VERSION", + "TF_CUDNN_VERSION", + "LOCAL_CUDA_PATH", + "LOCAL_CUDNN_PATH", + ], +) + +def cuda_json_init_repository( + cuda_json_dict = CUDA_REDIST_JSON_DICT, + cudnn_json_dict = CUDNN_REDIST_JSON_DICT): + cuda_redist_json( + name = "cuda_redist_json", + cuda_json_dict = cuda_json_dict, + cudnn_json_dict = cudnn_json_dict, + ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl new file mode 100644 index 00000000000000..7757a92a90b795 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvcc.BUILD.tpl @@ -0,0 +1,75 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "bin/nvcc", +]) + +filegroup( + name = "nvvm", + srcs = [ + "nvvm/libdevice/libdevice.10.bc", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "nvlink", + srcs = [ + "bin/nvlink", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "fatbinary", + srcs = [ + "bin/fatbinary", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin2c", + srcs = [ + "bin/bin2c", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "ptxas", + srcs = [ + "bin/ptxas", + ], + visibility = ["//visibility:public"], +) + +filegroup( + name = "bin", + srcs = glob([ + "bin/**", + "nvvm/bin/**", + ]), + visibility = ["//visibility:public"], +) + +filegroup( + name = "link_stub", + srcs = [ + "bin/crt/link.stub", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/crt/**", + %{comment}"include/fatbinary_section.h", + %{comment}"include/nvPTXCompiler.h", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl new file mode 100644 index 00000000000000..9784a84471f1a7 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvjitlink.BUILD.tpl @@ -0,0 +1,17 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nvjitlink_shared_library", + shared_library = "lib/libnvJitLink.so.%{libnvjitlink_version}", +) +%{multiline_comment} +cc_library( + name = "nvjitlink", + %{comment}deps = [":nvjitlink_shared_library"], + visibility = ["//visibility:public"], +) + diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl new file mode 100644 index 00000000000000..23ee30f09f8ff3 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvml.BUILD.tpl @@ -0,0 +1,10 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = ["include/nvml.h"], + include_prefix = "third_party/gpus/cuda/nvml/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl new file mode 100644 index 00000000000000..986ef0c8f76166 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvprune.BUILD.tpl @@ -0,0 +1,9 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +filegroup( + name = "nvprune", + srcs = [ + "bin/nvprune", + ], + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl new file mode 100644 index 00000000000000..de18489b455b79 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvrtc.BUILD.tpl @@ -0,0 +1,20 @@ +licenses(["restricted"]) # NVIDIA proprietary license +%{multiline_comment} +cc_import( + name = "nvrtc_main", + shared_library = "lib/libnvrtc.so.%{libnvrtc_version}", +) + +cc_import( + name = "nvrtc_builtins", + shared_library = "lib/libnvrtc-builtins.so.%{libnvrtc-builtins_version}", +) +%{multiline_comment} +cc_library( + name = "nvrtc", + %{comment}deps = [ + %{comment}":nvrtc_main", + %{comment}":nvrtc_builtins", + %{comment}], + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl new file mode 100644 index 00000000000000..3457f41a502dee --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_nvtx.BUILD.tpl @@ -0,0 +1,13 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nvToolsExt*.h", + %{comment}"include/nvtx3/**", + %{comment}]), + include_prefix = "third_party/gpus/cuda/include", + includes = ["include"], + strip_include_prefix = "include", + visibility = ["@local_config_cuda//cuda:__pkg__"], +) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl new file mode 100644 index 00000000000000..d2015e737540c3 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_init_repositories.bzl @@ -0,0 +1,491 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_REDIST_PATH_PREFIX", + "CUDNN_REDIST_PATH_PREFIX", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +OS_ARCH_DICT = { + "amd64": "x86_64-unknown-linux-gnu", + "aarch64": "aarch64-unknown-linux-gnu", +} +_REDIST_ARCH_DICT = { + "linux-x86_64": "x86_64-unknown-linux-gnu", + "linux-sbsa": "aarch64-unknown-linux-gnu", +} + +SUPPORTED_ARCHIVE_EXTENSIONS = [ + ".zip", + ".jar", + ".war", + ".aar", + ".tar", + ".tar.gz", + ".tgz", + ".tar.xz", + ".txz", + ".tar.zst", + ".tzst", + ".tar.bz2", + ".tbz", + ".ar", + ".deb", + ".whl", +] + +def get_env_var(ctx, name): + return ctx.os.environ.get(name) + +def _get_file_name(url): + last_slash_index = url.rfind("/") + return url[last_slash_index + 1:] + +def get_archive_name(url): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the archive name without extension.""" + filename = _get_file_name(url) + for extension in SUPPORTED_ARCHIVE_EXTENSIONS: + if filename.endswith(extension): + return filename[:-len(extension)] + return filename + +LIB_EXTENSION = ".so." + +def _get_lib_name_and_version(path): + extension_index = path.rfind(LIB_EXTENSION) + last_slash_index = path.rfind("/") + lib_name = path[last_slash_index + 1:extension_index] + lib_version = path[extension_index + len(LIB_EXTENSION):] + return (lib_name, lib_version) + +def _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_dir_path = repository_ctx.path("lib") + if not lib_dir_path.exists: + return [] + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]).lower() + lib_dir_content = lib_dir_path.readdir() + return [ + str(f) + for f in lib_dir_content + if (LIB_EXTENSION in str(f) and + main_lib_name in str(f).lower()) + ] + +def get_lib_name_to_version_dict(repository_ctx): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns a dict of library names and major versions.""" + lib_name_to_version_dict = {} + for path in _get_libraries_by_redist_name_in_dir(repository_ctx): + lib_name, lib_version = _get_lib_name_and_version(path) + key = "%%{%s_version}" % lib_name.lower() + + # We need to find either major or major.minor version if there is no + # file with major version. E.g. if we have the following files: + # libcudart.so + # libcudart.so.12 + # libcudart.so.12.3.2, + # we will save save {"%{libcudart_version}": "12"}. + if len(lib_version.split(".")) == 1: + lib_name_to_version_dict[key] = lib_version + if (len(lib_version.split(".")) == 2 and + key not in lib_name_to_version_dict): + lib_name_to_version_dict[key] = lib_version + return lib_name_to_version_dict + +def create_dummy_build_file(repository_ctx, use_comment_symbols = True): + repository_ctx.template( + "BUILD", + repository_ctx.attr.build_templates[0], + { + "%{multiline_comment}": "'''" if use_comment_symbols else "", + "%{comment}": "#" if use_comment_symbols else "", + }, + ) + +def _get_build_template(repository_ctx, major_lib_version): + template = None + for i in range(0, len(repository_ctx.attr.versions)): + for dist_version in repository_ctx.attr.versions[i].split(","): + if dist_version == major_lib_version: + template = repository_ctx.attr.build_templates[i] + break + if not template: + fail("No build template found for {} version {}".format( + repository_ctx.name, + major_lib_version, + )) + return template + +def get_major_library_version(repository_ctx, lib_name_to_version_dict): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns the major library version provided the versions dict.""" + major_version = "" + if len(lib_name_to_version_dict) == 0: + return major_version + main_lib_name = "lib{}".format(repository_ctx.name.split("_")[1]) + key = "%%{%s_version}" % main_lib_name + major_version = lib_name_to_version_dict[key] + return major_version + +def create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_lib_version): + # buildifier: disable=function-docstring-args + """Creates a BUILD file for the repository.""" + if len(major_lib_version) == 0: + build_template_content = repository_ctx.read( + repository_ctx.attr.build_templates[0], + ) + if "_version}" not in build_template_content: + create_dummy_build_file(repository_ctx, use_comment_symbols = False) + else: + create_dummy_build_file(repository_ctx) + return + build_template = _get_build_template( + repository_ctx, + major_lib_version.split(".")[0], + ) + repository_ctx.template( + "BUILD", + build_template, + lib_name_to_version_dict | { + "%{multiline_comment}": "", + "%{comment}": "", + }, + ) + +def _create_symlinks(repository_ctx, local_path, dirs): + for dir in dirs: + repository_ctx.symlink( + "{path}/{dir}".format( + path = local_path, + dir = dir, + ), + dir, + ) + +def use_local_path(repository_ctx, local_path, dirs): + # buildifier: disable=function-docstring-args + """Creates repository using local redistribution paths.""" + _create_symlinks( + repository_ctx, + local_path, + dirs, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _use_local_cuda_path(repository_ctx, local_cuda_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDA repository.""" + use_local_path( + repository_ctx, + local_cuda_path, + ["include", "lib", "bin", "nvvm"], + ) + +def _use_local_cudnn_path(repository_ctx, local_cudnn_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic CUDNN repository.""" + use_local_path(repository_ctx, local_cudnn_path, ["include", "lib"]) + +def _download_redistribution(repository_ctx, arch_key, path_prefix): + (url, sha256) = repository_ctx.attr.url_dict[arch_key] + + # If url is not relative, then appending prefix is not needed. + if not (url.startswith("http") or url.startswith("file:///")): + url = path_prefix + url + archive_name = get_archive_name(url) + file_name = _get_file_name(url) + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + if repository_ctx.attr.override_strip_prefix: + strip_prefix = repository_ctx.attr.override_strip_prefix + else: + strip_prefix = archive_name + repository_ctx.extract( + archive = file_name, + stripPrefix = strip_prefix, + ) + repository_ctx.delete(file_name) + +def _use_downloaded_cuda_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDA redistribution and initializes hermetic CUDA repository.""" + major_version = "" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cuda_version: + # If no CUDA version is found, comment out all cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cuda_redist_path_prefix, + ) + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version(repository_ctx, lib_name_to_version_dict) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + repository_ctx.file("version.txt", major_version) + +def _cuda_repo_impl(repository_ctx): + local_cuda_path = get_env_var(repository_ctx, "LOCAL_CUDA_PATH") + if local_cuda_path: + _use_local_cuda_path(repository_ctx, local_cuda_path) + else: + _use_downloaded_cuda_redistribution(repository_ctx) + +cuda_repo = repository_rule( + implementation = _cuda_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cuda_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDA_PATH", + ], +) + +def _use_downloaded_cudnn_redistribution(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads CUDNN redistribution and initializes hermetic CUDNN repository.""" + cudnn_version = None + major_version = "" + cudnn_version = (get_env_var(repository_ctx, "HERMETIC_CUDNN_VERSION") or + get_env_var(repository_ctx, "TF_CUDNN_VERSION")) + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + if not cudnn_version: + # If no CUDNN version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + if len(repository_ctx.attr.url_dict) == 0: + print("{} is not found in redistributions list.".format( + repository_ctx.name, + )) # buildifier: disable=print + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch_key = OS_ARCH_DICT[repository_ctx.os.arch] + if arch_key not in repository_ctx.attr.url_dict.keys(): + arch_key = "cuda{version}_{arch}".format( + version = cuda_version.split(".")[0], + arch = arch_key, + ) + if arch_key not in repository_ctx.attr.url_dict.keys(): + fail( + ("The supported platforms are {supported_platforms}." + + " Platform {platform} is not supported for {dist_name}.") + .format( + supported_platforms = repository_ctx.attr.url_dict.keys(), + platform = arch_key, + dist_name = repository_ctx.name, + ), + ) + + _download_redistribution( + repository_ctx, + arch_key, + repository_ctx.attr.cudnn_redist_path_prefix, + ) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _cudnn_repo_impl(repository_ctx): + local_cudnn_path = get_env_var(repository_ctx, "LOCAL_CUDNN_PATH") + if local_cudnn_path: + _use_local_cudnn_path(repository_ctx, local_cudnn_path) + else: + _use_downloaded_cudnn_redistribution(repository_ctx) + +cudnn_repo = repository_rule( + implementation = _cudnn_repo_impl, + attrs = { + "url_dict": attr.string_list_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "override_strip_prefix": attr.string(), + "cudnn_redist_path_prefix": attr.string(), + }, + environ = [ + "HERMETIC_CUDNN_VERSION", + "TF_CUDNN_VERSION", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "LOCAL_CUDNN_PATH", + ], +) + +def _get_redistribution_urls(dist_info): + url_dict = {} + for arch in _REDIST_ARCH_DICT.keys(): + if "relative_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["relative_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + if "full_path" in dist_info[arch]: + url_dict[_REDIST_ARCH_DICT[arch]] = [ + dist_info[arch]["full_path"], + dist_info[arch].get("sha256", ""), + ] + continue + + for cuda_version, data in dist_info[arch].items(): + # CUDNN JSON might contain paths for each CUDA version. + path_key = "relative_path" + if path_key not in data.keys(): + path_key = "full_path" + url_dict["{cuda_version}_{arch}".format( + cuda_version = cuda_version, + arch = _REDIST_ARCH_DICT[arch], + )] = [data[path_key], data.get("sha256", "")] + return url_dict + +def get_version_and_template_lists(version_to_template): + # buildifier: disable=function-docstring-return + # buildifier: disable=function-docstring-args + """Returns lists of versions and templates provided in the dict.""" + template_to_version_map = {} + for version, template in version_to_template.items(): + if template not in template_to_version_map.keys(): + template_to_version_map[template] = [version] + else: + template_to_version_map[template].append(version) + version_list = [] + template_list = [] + for template, versions in template_to_version_map.items(): + version_list.append(",".join(versions)) + template_list.append(Label(template)) + return (version_list, template_list) + +def cudnn_redist_init_repository( + cudnn_redistributions, + cudnn_redist_path_prefix = CUDNN_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDNN repository.""" + if "cudnn" in cudnn_redistributions.keys(): + url_dict = _get_redistribution_urls(cudnn_redistributions["cudnn"]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates["cudnn"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cudnn_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cudnn_redist_path_prefix = cudnn_redist_path_prefix, + ) + +def cuda_redist_init_repositories( + cuda_redistributions, + cuda_redist_path_prefix = CUDA_REDIST_PATH_PREFIX, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes CUDA repositories.""" + for redist_name, _ in redist_versions_to_build_templates.items(): + if redist_name in ["cudnn", "cuda_nccl"]: + continue + if redist_name in cuda_redistributions.keys(): + url_dict = _get_redistribution_urls(cuda_redistributions[redist_name]) + else: + url_dict = {} + repo_data = redist_versions_to_build_templates[redist_name] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_repo( + name = repo_data["repo_name"], + versions = versions, + build_templates = templates, + url_dict = url_dict, + cuda_redist_path_prefix = cuda_redist_path_prefix, + ) diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl new file mode 100644 index 00000000000000..d7ccff736a4801 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda/hermetic/cuda_redist_versions.bzl @@ -0,0 +1,243 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic CUDA redistribution versions.""" + +CUDA_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cuda/redist/" +CUDNN_REDIST_PATH_PREFIX = "https://developer.download.nvidia.com/compute/cudnn/redist/" + +CUDA_REDIST_JSON_DICT = { + "11.8": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_11.8.0.json", + "941a950a4ab3b95311c50df7b3c8bca973e0cdda76fc2f4b456d2d5e4dac0281", + ], + "12.1.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.1.1.json", + "bafea3cb83a4cf5c764eeedcaac0040d0d3c5db3f9a74550da0e7b6ac24d378c", + ], + "12.2.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.2.0.json", + "d883762c6339c8ebb3ffb072facc8f7265cd257d2db16a475fff9a9306ecea89", + ], + "12.3.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.1.json", + "b3cc4181d711cf9b6e3718f323b23813c24f9478119911d7b4bceec9b437dbc3", + ], + "12.3.2": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.3.2.json", + "1b6eacf335dd49803633fed53ef261d62c193e5a56eee5019e7d2f634e39e7ef", + ], + "12.4.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.4.0.json", + "a4f496b8d5299939b34c9ef88dc4274821f8c9451b2d7c9bcee53166932da067", + ], + "12.4.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.4.1.json", + "9cd815f3b71c2e3686ef2219b7794b81044f9dcefaa8e21dacfcb5bc4d931892", + ], + "12.5.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.5.0.json", + "166664b520bfe51f27abcc8c7a934f4cb6ea287f8c399b5f8255f6f4d214569a", + ], + "12.5.1": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.5.1.json", + "7ab9c76014ae4907fa1b51738af599607a5fd8ca3a5c4bb4c3b31338cc642a93", + ], + "12.6.0": [ + "https://developer.download.nvidia.com/compute/cuda/redist/redistrib_12.6.0.json", + "87740b01676b3d18982982ab96ec7fa1a626d03a96df070a6b0f258d01ff5fab", + ], +} + +CUDNN_REDIST_JSON_DICT = { + "8.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.6.0.json", + "7f6f50bed4fd8216dc10d6ef505771dc0ecc99cce813993ab405cb507a21d51d", + ], + "8.9.4.25": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.4.25.json", + "02258dba8384860c9230fe3c78522e7bd8e350e461ccd37a8d932cb64127ba57", + ], + "8.9.6": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.6.json", + "6069ef92a2b9bb18cebfbc944964bd2b024b76f2c2c35a43812982e0bc45cf0c", + ], + "8.9.7.29": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_8.9.7.29.json", + "a0734f26f068522464fa09b2f2c186dfbe6ad7407a88ea0c50dd331f0c3389ec", + ], + "9.1.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.1.1.json", + "d22d569405e5683ff8e563d00d6e8c27e5e6a902c564c23d752b22a8b8b3fe20", + ], + "9.2.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.2.0.json", + "6852eb279b95d2b5775f7a7737ec133bed059107f863cdd8588f3ae6f13eadd7", + ], + "9.2.1": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.2.1.json", + "9a4198c59b2e66b2b115a736ebe4dc8f3dc6d78161bb494702f824da8fc77b99", + ], + "9.3.0": [ + "https://developer.download.nvidia.com/compute/cudnn/redist/redistrib_9.3.0.json", + "d17d9a7878365736758550294f03e633a0b023bec879bf173349bfb34781972e", + ], +} + +# The versions are different for x86 and aarch64 architectures because only +# NCCL release versions 2.20.3 and 2.20.5 have the wheels for aarch64. +CUDA_12_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", + }, + "aarch64-unknown-linux-gnu": { + "version": "2.20.5", + "url": "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", + "sha256": "1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", + }, +} + +CUDA_11_NCCL_WHEEL_DICT = { + "x86_64-unknown-linux-gnu": { + "version": "2.21.5", + "url": "https://files.pythonhosted.org/packages/ac/9a/8b6a28b3b87d5fddab0e92cd835339eb8fbddaa71ae67518c8c1b3d05bae/nvidia_nccl_cu11-2.21.5-py3-none-manylinux2014_x86_64.whl", + "sha256": "49d8350629c7888701d1fd200934942671cb5c728f49acc5a0b3a768820bed29", + }, +} + +CUDA_NCCL_WHEELS = { + "11.8": CUDA_11_NCCL_WHEEL_DICT, + "12.1.1": CUDA_12_NCCL_WHEEL_DICT, + "12.2.0": CUDA_12_NCCL_WHEEL_DICT, + "12.3.1": CUDA_12_NCCL_WHEEL_DICT, + "12.3.2": CUDA_12_NCCL_WHEEL_DICT, + "12.4.0": CUDA_12_NCCL_WHEEL_DICT, + "12.1.0": CUDA_12_NCCL_WHEEL_DICT, + "12.5.0": CUDA_12_NCCL_WHEEL_DICT, + "12.5.1": CUDA_12_NCCL_WHEEL_DICT, + "12.6.0": CUDA_12_NCCL_WHEEL_DICT, +} + +REDIST_VERSIONS_TO_BUILD_TEMPLATES = { + "cuda_nccl": { + "repo_name": "cuda_nccl", + "version_to_template": { + "2": "//third_party/nccl/hermetic:cuda_nccl.BUILD.tpl", + }, + }, + "cudnn": { + "repo_name": "cuda_cudnn", + "version_to_template": { + "9": "//third_party/gpus/cuda/hermetic:cuda_cudnn9.BUILD.tpl", + "8": "//third_party/gpus/cuda/hermetic:cuda_cudnn.BUILD.tpl", + }, + }, + "libcublas": { + "repo_name": "cuda_cublas", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cublas.BUILD.tpl", + }, + }, + "cuda_cudart": { + "repo_name": "cuda_cudart", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cudart.BUILD.tpl", + }, + }, + "libcufft": { + "repo_name": "cuda_cufft", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + "10": "//third_party/gpus/cuda/hermetic:cuda_cufft.BUILD.tpl", + }, + }, + "cuda_cupti": { + "repo_name": "cuda_cupti", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cupti.BUILD.tpl", + }, + }, + "libcurand": { + "repo_name": "cuda_curand", + "version_to_template": { + "10": "//third_party/gpus/cuda/hermetic:cuda_curand.BUILD.tpl", + }, + }, + "libcusolver": { + "repo_name": "cuda_cusolver", + "version_to_template": { + "11": "//third_party/gpus/cuda/hermetic:cuda_cusolver.BUILD.tpl", + }, + }, + "libcusparse": { + "repo_name": "cuda_cusparse", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cusparse.BUILD.tpl", + }, + }, + "libnvjitlink": { + "repo_name": "cuda_nvjitlink", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvjitlink.BUILD.tpl", + }, + }, + "cuda_nvrtc": { + "repo_name": "cuda_nvrtc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvrtc.BUILD.tpl", + }, + }, + "cuda_cccl": { + "repo_name": "cuda_cccl", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_cccl.BUILD.tpl", + }, + }, + "cuda_nvcc": { + "repo_name": "cuda_nvcc", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvcc.BUILD.tpl", + }, + }, + "cuda_nvml_dev": { + "repo_name": "cuda_nvml", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvml.BUILD.tpl", + }, + }, + "cuda_nvprune": { + "repo_name": "cuda_nvprune", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvprune.BUILD.tpl", + }, + }, + "cuda_nvtx": { + "repo_name": "cuda_nvtx", + "version_to_template": { + "12": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + "11": "//third_party/gpus/cuda/hermetic:cuda_nvtx.BUILD.tpl", + }, + }, +} diff --git a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl index fefbf081c87e1c..8bf1db2b0f8f9f 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/cuda_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for CUDA autoconfiguration. +NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. + `cuda_configure` depends on the following environment variables: * `TF_NEED_CUDA`: Whether to enable building with CUDA. @@ -53,6 +55,11 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "get_cxx_inc_directories", + "to_list_of_strings", +) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" _GCC_HOST_COMPILER_PREFIX = "GCC_HOST_COMPILER_PREFIX" @@ -67,20 +74,6 @@ _TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _TF_DOWNLOAD_CLANG = "TF_DOWNLOAD_CLANG" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" -def to_list_of_strings(elements): - """Convert the list of ["a", "b", "c"] into '"a", "b", "c"'. - - This is to be used to put a list of strings into the bzl file templates - so it gets interpreted as list of strings in Starlark. - - Args: - elements: list of string elements - - Returns: - single string of elements wrapped in quotes separated by a comma.""" - quoted_strings = ["\"" + element + "\"" for element in elements] - return ", ".join(quoted_strings) - def verify_build_defines(params): """Verify all variables that crosstool/BUILD.tpl expects are substituted. @@ -238,156 +231,6 @@ def find_cc(repository_ctx, use_cuda_clang): " environment variable").format(target_cc_name, cc_path_envvar)) return cc -_INC_DIR_MARKER_BEGIN = "#include <...>" - -# OSX add " (framework directory)" at the end of line, strip it. -_OSX_FRAMEWORK_SUFFIX = " (framework directory)" -_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX) - -def _cxx_inc_convert(path): - """Convert path returned by cc -E xc++ in a complete path.""" - path = path.strip() - if path.endswith(_OSX_FRAMEWORK_SUFFIX): - path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip() - return path - -def _normalize_include_path(repository_ctx, path): - """Normalizes include paths before writing them to the crosstool. - - If path points inside the 'crosstool' folder of the repository, a relative - path is returned. - If path points outside the 'crosstool' folder, an absolute path is returned. - """ - path = str(repository_ctx.path(path)) - crosstool_folder = str(repository_ctx.path(".").get_child("crosstool")) - - if path.startswith(crosstool_folder): - # We drop the path to "$REPO/crosstool" and a trailing path separator. - return path[len(crosstool_folder) + 1:] - return path - -def _is_compiler_option_supported(repository_ctx, cc, option): - """Checks that `option` is supported by the C compiler. Doesn't %-escape the option.""" - result = repository_ctx.execute([ - cc, - option, - "-o", - "/dev/null", - "-c", - str(repository_ctx.path("tools/cpp/empty.cc")), - ]) - return result.stderr.find(option) == -1 - -def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp, tf_sysroot): - """Compute the list of default C or C++ include directories.""" - if lang_is_cpp: - lang = "c++" - else: - lang = "c" - sysroot = [] - if tf_sysroot: - sysroot += ["--sysroot", tf_sysroot] - result = raw_exec(repository_ctx, [cc, "-E", "-x" + lang, "-", "-v"] + - sysroot) - stderr = err_out(result) - index1 = stderr.find(_INC_DIR_MARKER_BEGIN) - if index1 == -1: - return [] - index1 = stderr.find("\n", index1) - if index1 == -1: - return [] - index2 = stderr.rfind("\n ") - if index2 == -1 or index2 < index1: - return [] - index2 = stderr.find("\n", index2 + 1) - if index2 == -1: - inc_dirs = stderr[index1 + 1:] - else: - inc_dirs = stderr[index1 + 1:index2].strip() - - print_resource_dir_supported = _is_compiler_option_supported( - repository_ctx, - cc, - "-print-resource-dir", - ) - - if print_resource_dir_supported: - resource_dir = repository_ctx.execute( - [cc, "-print-resource-dir"], - ).stdout.strip() + "/share" - inc_dirs += "\n" + resource_dir - - compiler_includes = [ - _normalize_include_path(repository_ctx, _cxx_inc_convert(p)) - for p in inc_dirs.split("\n") - ] - - # The compiler might be on a symlink, e.g. /symlink -> /opt/gcc - # The above keeps only the resolved paths to the default includes (e.g. /opt/gcc/include/c++/11) - # but Bazel might encounter either (usually reported by the compiler) - # especially when a compiler wrapper (e.g. ccache) is used. - # So we need to also include paths where symlinks are not resolved. - - # Try to find real path to CC installation to "see through" compiler wrappers - # GCC has the path to g++ - index1 = result.stderr.find("COLLECT_GCC=") - if index1 != -1: - index1 = result.stderr.find("=", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname.dirname - else: - # Clang has the directory - index1 = result.stderr.find("InstalledDir: ") - if index1 != -1: - index1 = result.stderr.find(" ", index1) - index2 = result.stderr.find("\n", index1) - cc_topdir = repository_ctx.path(result.stderr[index1 + 1:index2]).dirname - else: - # Fallback to the CC path - cc_topdir = repository_ctx.path(cc).dirname.dirname - - # We now have the compiler installation prefix, e.g. /symlink/gcc - # And the resolved installation prefix, e.g. /opt/gcc - cc_topdir_resolved = str(realpath(repository_ctx, cc_topdir)).strip() - cc_topdir = str(cc_topdir).strip() - - # If there is (any!) symlink involved we add paths where the unresolved installation prefix is kept. - # e.g. [/opt/gcc/include/c++/11, /opt/gcc/lib/gcc/x86_64-linux-gnu/11/include, /other/path] - # adds [/symlink/include/c++/11, /symlink/lib/gcc/x86_64-linux-gnu/11/include] - if cc_topdir_resolved != cc_topdir: - unresolved_compiler_includes = [ - cc_topdir + inc[len(cc_topdir_resolved):] - for inc in compiler_includes - if inc.startswith(cc_topdir_resolved) - ] - compiler_includes = compiler_includes + unresolved_compiler_includes - return compiler_includes - -def get_cxx_inc_directories(repository_ctx, cc, tf_sysroot): - """Compute the list of default C and C++ include directories.""" - - # For some reason `clang -xc` sometimes returns include paths that are - # different from the ones from `clang -xc++`. (Symlink and a dir) - # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists - includes_cpp = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - True, - tf_sysroot, - ) - includes_c = _get_cxx_inc_directories_impl( - repository_ctx, - cc, - False, - tf_sysroot, - ) - - return includes_cpp + [ - inc - for inc in includes_c - if inc not in includes_cpp - ] - def auto_configure_fail(msg): """Output failure message when cuda configuration fails.""" red = "\033[0;31m" @@ -1293,6 +1136,7 @@ def _create_local_cuda_repository(repository_ctx): cuda_defines["%{extra_no_canonical_prefixes_flags}"] = "" cuda_defines["%{unfiltered_compile_flags}"] = "" + cuda_defines["%{cuda_nvcc_files}"] = "[]" if is_cuda_clang and not is_nvcc_and_clang: cuda_defines["%{host_compiler_path}"] = str(cc) cuda_defines["%{host_compiler_warnings}"] = """ diff --git a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py index b88694af5c014d..68623bf671da71 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py +++ b/third_party/xla/third_party/tsl/third_party/gpus/find_cuda_config.py @@ -14,6 +14,9 @@ # ============================================================================== """Prints CUDA library and header directories and versions found on the system. +NB: DEPRECATED! This script is a part of the deprecated `cuda_configure` rule. +Please use `hermetic/cuda_configure` instead. + The script searches for CUDA library and header files on the system, inspects them to determine their version and prints the configuration to stdout. The paths to inspect and the required versions are specified through environment diff --git a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl index c185ca7c48cc02..fb63d4db886c1c 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/rocm_configure.bzl @@ -22,12 +22,15 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "enable_cuda", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) load( ":sycl_configure.bzl", @@ -205,6 +208,8 @@ def _rocm_include_path(repository_ctx, rocm_config, bash_bin): inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17.0.0/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/17/include") inc_dirs.append(rocm_toolkit_path + "/llvm/lib/clang/18/include") + if int(rocm_config.rocm_version_number) >= 60200: + inc_dirs.append(rocm_toolkit_path + "/lib/llvm/lib/clang/18/include") # Support hcc based off clang 10.0.0 (for ROCm 3.3) inc_dirs.append(rocm_toolkit_path + "/hcc/compiler/lib/clang/10.0.0/include/") diff --git a/third_party/xla/third_party/tsl/third_party/gpus/sycl_configure.bzl b/third_party/xla/third_party/tsl/third_party/gpus/sycl_configure.bzl index 05330b2fe53195..dd80694e7274f5 100644 --- a/third_party/xla/third_party/tsl/third_party/gpus/sycl_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/gpus/sycl_configure.bzl @@ -16,11 +16,14 @@ load( "realpath", "which", ) +load( + ":compiler_common_tools.bzl", + "to_list_of_strings", +) load( ":cuda_configure.bzl", "make_copy_dir_rule", "make_copy_files_rule", - "to_list_of_strings", ) _GCC_HOST_COMPILER_PATH = "GCC_HOST_COMPILER_PATH" diff --git a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD index 4b3ad84d836933..8c730960bc3ed3 100644 --- a/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD +++ b/third_party/xla/third_party/tsl/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -12,7 +12,7 @@ _CMAKE_COMMON_LIST = { "#cmakedefine DNNL_SYCL_CUDA": "#undef DNNL_SYCL_CUDA", "#cmakedefine DNNL_SYCL_HIP": "#undef DNNL_SYCL_HIP", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", - "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", + "#cmakedefine ONEDNN_BUILD_GRAPH": "#define ONEDNN_BUILD_GRAPH", "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#define DNNL_EXPERIMENTAL_SPARSE", "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", @@ -109,6 +109,7 @@ _COPTS_LIST = select({ "-UUSE_CBLAS", "-DDNNL_ENABLE_MAX_CPU_ISA", "-DDNNL_ENABLE_ITT_TASKS", + "-DDNNL_ENABLE_GRAPH_DUMP", ] + tf_openmp_copts() _INCLUDES_LIST = [ @@ -119,6 +120,7 @@ _INCLUDES_LIST = [ "src/cpu", "src/cpu/gemm", "src/cpu/x64/xbyak", + "src/graph", ] _TEXTUAL_HDRS_LIST = glob([ @@ -129,6 +131,15 @@ _TEXTUAL_HDRS_LIST = glob([ "src/cpu/**/*.hpp", "src/cpu/jit_utils/**/*.hpp", "src/cpu/x64/xbyak/*.h", + "src/graph/interface/*.hpp", + "src/graph/backend/*.hpp", + "src/graph/backend/dnnl/*.hpp", + "src/graph/backend/fake/*.hpp", + "src/graph/backend/dnnl/passes/*.hpp", + "src/graph/backend/dnnl/patterns/*.hpp", + "src/graph/backend/dnnl/kernels/*.hpp", + "src/graph/utils/*.hpp", + "src/graph/utils/pm/*.hpp", ]) + [ ":dnnl_config_h", ":dnnl_version_h", @@ -160,6 +171,16 @@ cc_library( "src/cpu/**/*.cpp", "src/common/ittnotify/*.c", "src/cpu/jit_utils/**/*.cpp", + "src/cpu/x64/**/*.cpp", + "src/graph/interface/*.cpp", + "src/graph/backend/*.cpp", + "src/graph/backend/dnnl/*.cpp", + "src/graph/backend/fake/*.cpp", + "src/graph/backend/dnnl/passes/*.cpp", + "src/graph/backend/dnnl/patterns/*.cpp", + "src/graph/backend/dnnl/kernels/*.cpp", + "src/graph/utils/*.cpp", + "src/graph/utils/pm/*.cpp", ], exclude = [ "src/cpu/aarch64/**", diff --git a/third_party/xla/third_party/tsl/third_party/nccl/build_defs.bzl.tpl b/third_party/xla/third_party/tsl/third_party/nccl/build_defs.bzl.tpl index 53a6d4e1e41890..a0930df34ecec8 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/build_defs.bzl.tpl +++ b/third_party/xla/third_party/tsl/third_party/nccl/build_defs.bzl.tpl @@ -5,7 +5,6 @@ load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") # CUDA toolkit version as tuple (e.g. '(11, 1)'). _cuda_version = %{cuda_version} -_cuda_clang = %{cuda_clang} def _rdc_copts(): """Returns copts for compiling relocatable device code.""" @@ -121,25 +120,25 @@ _device_link = rule( "gpu_archs": attr.string_list(), "nvlink_args": attr.string_list(), "_nvlink": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvlink"), + default = Label("%{nvlink_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_fatbinary": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/fatbinary"), + default = Label("%{fatbinary_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_bin2c": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/bin2c"), + default = Label("%{bin2c_label}"), allow_single_file = True, executable = True, cfg = "host", ), "_link_stub": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/crt/link.stub"), + default = Label("%{link_stub_label}"), allow_single_file = True, ), }, @@ -189,7 +188,7 @@ _prune_relocatable_code = rule( "input": attr.label(mandatory = True, allow_files = True), "gpu_archs": attr.string_list(), "_nvprune": attr.label( - default = Label("@local_config_cuda//cuda:cuda/bin/nvprune"), + default = Label("%{nvprune_label}"), allow_single_file = True, executable = True, cfg = "host", diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/BUILD b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/BUILD new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl new file mode 100644 index 00000000000000..61d7809bcdaad1 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/cuda_nccl.BUILD.tpl @@ -0,0 +1,30 @@ +licenses(["restricted"]) # NVIDIA proprietary license + +exports_files([ + "version.txt", +]) +%{multiline_comment} +cc_import( + name = "nccl_shared_library", + shared_library = "lib/libnccl.so.%{libnccl_version}", + hdrs = [":headers"], + deps = ["@local_config_cuda//cuda:cuda_headers", ":headers"], +) +%{multiline_comment} +cc_library( + name = "nccl", + %{comment}deps = [":nccl_shared_library"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "headers", + %{comment}hdrs = glob([ + %{comment}"include/nccl*.h", + %{comment}]), + include_prefix = "third_party/nccl", + includes = ["include/"], + strip_include_prefix = "include", + visibility = ["//visibility:public"], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl new file mode 100644 index 00000000000000..75f5a10b6fe24e --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_configure.bzl @@ -0,0 +1,183 @@ +"""Repository rule for hermetic NCCL configuration. + +`nccl_configure` depends on the following environment variables: + + * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should + be used, "0" if NCCL should be linked in statically. + * `HERMETIC_CUDA_VERSION`: The version of the CUDA toolkit. If not specified, + the version will be determined by the `TF_CUDA_VERSION`. + +""" + +load( + "//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "HERMETIC_CUDA_VERSION", + "TF_CUDA_VERSION", + "TF_NEED_CUDA", + "enable_cuda", + "get_cuda_version", +) +load( + "//third_party/remote_config:common.bzl", + "get_cpu_value", + "get_host_environ", +) + +_TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" + +_NCCL_DUMMY_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl", + visibility = ["//visibility:public"], +) + +cc_library( + name = "nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +_NCCL_ARCHIVE_STUB_BUILD_CONTENT = """ +filegroup( + name = "LICENSE", + data = ["@nccl_archive//:LICENSE.txt"], + visibility = ["//visibility:public"], +) + +alias( + name = "nccl", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:nccl", + "//conditions:default": "@nccl_archive//:nccl_via_stub", + }), + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_headers", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": "@cuda_nccl//:headers", + "//conditions:default": "@nccl_archive//:nccl_headers", + }), + visibility = ["//visibility:public"], +) + +cc_library( + name = "hermetic_nccl_config", + hdrs = ["nccl_config.h"], + include_prefix = "third_party/nccl", + visibility = ["//visibility:public"], +) + +alias( + name = "nccl_config", + actual = select({ + "@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": ":hermetic_nccl_config", + "//conditions:default": "@nccl_archive//:nccl_config", + }), + visibility = ["//visibility:public"], +) +""" + +def _create_local_nccl_repository(repository_ctx): + cuda_version = get_cuda_version(repository_ctx).split(".")[:2] + nccl_version = repository_ctx.read(repository_ctx.attr.nccl_version) + + if get_host_environ(repository_ctx, _TF_NCCL_USE_STUB, "0") == "0": + repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT) + else: + repository_ctx.file("BUILD", _NCCL_ARCHIVE_STUB_BUILD_CONTENT) + + repository_ctx.template("generated_names.bzl", repository_ctx.attr.generated_names_tpl, {}) + repository_ctx.template( + "build_defs.bzl", + repository_ctx.attr.build_defs_tpl, + { + "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), + "%{nvlink_label}": "@cuda_nvcc//:nvlink", + "%{fatbinary_label}": "@cuda_nvcc//:fatbinary", + "%{bin2c_label}": "@cuda_nvcc//:bin2c", + "%{link_stub_label}": "@cuda_nvcc//:link_stub", + "%{nvprune_label}": "@cuda_nvprune//:nvprune", + }, + ) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"%s\"" % nccl_version) + +def _nccl_autoconf_impl(repository_ctx): + if (not enable_cuda(repository_ctx) or + get_cpu_value(repository_ctx) != "Linux"): + # Add a dummy build file to make bazel query happy. + repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT) + repository_ctx.file("nccl_config.h", "#define TF_NCCL_VERSION \"\"") + else: + _create_local_nccl_repository(repository_ctx) + +_ENVIRONS = [ + TF_NEED_CUDA, + TF_CUDA_VERSION, + _TF_NCCL_USE_STUB, + HERMETIC_CUDA_VERSION, + "LOCAL_NCCL_PATH", +] + +nccl_configure = repository_rule( + environ = _ENVIRONS, + implementation = _nccl_autoconf_impl, + attrs = { + "environ": attr.string_dict(), + "nccl_version": attr.label(default = Label("@cuda_nccl//:version.txt")), + "generated_names_tpl": attr.label(default = Label("//third_party/nccl:generated_names.bzl.tpl")), + "build_defs_tpl": attr.label(default = Label("//third_party/nccl:build_defs.bzl.tpl")), + }, +) +"""Downloads and configures the hermetic NCCL configuration. + +Add the following to your WORKSPACE file: + +```python +nccl_configure(name = "local_config_nccl") +``` + +Args: + name: A unique name for this workspace rule. +""" # buildifier: disable=no-effect diff --git a/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl new file mode 100644 index 00000000000000..244cb851ddf591 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/nccl/hermetic/nccl_redist_init_repository.bzl @@ -0,0 +1,145 @@ +# Copyright 2024 The TensorFlow Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Hermetic NCCL repositories initialization. Consult the WORKSPACE on how to use it.""" + +load("//third_party:repo.bzl", "tf_mirror_urls") +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "OS_ARCH_DICT", + "create_build_file", + "create_dummy_build_file", + "get_archive_name", + "get_env_var", + "get_lib_name_to_version_dict", + "get_major_library_version", + "get_version_and_template_lists", + "use_local_path", +) +load( + "//third_party/gpus/cuda/hermetic:cuda_redist_versions.bzl", + "CUDA_NCCL_WHEELS", + "REDIST_VERSIONS_TO_BUILD_TEMPLATES", +) + +def _use_downloaded_nccl_wheel(repository_ctx): + # buildifier: disable=function-docstring-args + """ Downloads NCCL wheel and inits hermetic NCCL repository.""" + cuda_version = (get_env_var(repository_ctx, "HERMETIC_CUDA_VERSION") or + get_env_var(repository_ctx, "TF_CUDA_VERSION")) + major_version = "" + if not cuda_version: + # If no CUDA version is found, comment out cc_import targets. + create_dummy_build_file(repository_ctx) + repository_ctx.file("version.txt", major_version) + return + + # Download archive only when GPU config is used. + arch = OS_ARCH_DICT[repository_ctx.os.arch] + dict_key = "{cuda_version}-{arch}".format( + cuda_version = cuda_version, + arch = arch, + ) + supported_versions = repository_ctx.attr.url_dict.keys() + if dict_key not in supported_versions: + fail( + ("The supported NCCL versions are {supported_versions}." + + " Please provide a supported version in HERMETIC_CUDA_VERSION" + + " environment variable or add NCCL distribution for" + + " CUDA version={version}, OS={arch}.") + .format( + supported_versions = supported_versions, + version = cuda_version, + arch = arch, + ), + ) + sha256 = repository_ctx.attr.sha256_dict[dict_key] + url = repository_ctx.attr.url_dict[dict_key] + + archive_name = get_archive_name(url) + file_name = archive_name + ".zip" + + print("Downloading and extracting {}".format(url)) # buildifier: disable=print + repository_ctx.download( + url = tf_mirror_urls(url), + output = file_name, + sha256 = sha256, + ) + repository_ctx.extract( + archive = file_name, + stripPrefix = repository_ctx.attr.strip_prefix, + ) + repository_ctx.delete(file_name) + + lib_name_to_version_dict = get_lib_name_to_version_dict(repository_ctx) + major_version = get_major_library_version( + repository_ctx, + lib_name_to_version_dict, + ) + create_build_file( + repository_ctx, + lib_name_to_version_dict, + major_version, + ) + + repository_ctx.file("version.txt", major_version) + +def _use_local_nccl_path(repository_ctx, local_nccl_path): + # buildifier: disable=function-docstring-args + """ Creates symlinks and initializes hermetic NCCL repository.""" + use_local_path(repository_ctx, local_nccl_path, ["include", "lib"]) + +def _cuda_nccl_repo_impl(repository_ctx): + local_nccl_path = get_env_var(repository_ctx, "LOCAL_NCCL_PATH") + if local_nccl_path: + _use_local_nccl_path(repository_ctx, local_nccl_path) + else: + _use_downloaded_nccl_wheel(repository_ctx) + +cuda_nccl_repo = repository_rule( + implementation = _cuda_nccl_repo_impl, + attrs = { + "sha256_dict": attr.string_dict(mandatory = True), + "url_dict": attr.string_dict(mandatory = True), + "versions": attr.string_list(mandatory = True), + "build_templates": attr.label_list(mandatory = True), + "strip_prefix": attr.string(), + }, + environ = ["HERMETIC_CUDA_VERSION", "TF_CUDA_VERSION", "LOCAL_NCCL_PATH"], +) + +def nccl_redist_init_repository( + cuda_nccl_wheels = CUDA_NCCL_WHEELS, + redist_versions_to_build_templates = REDIST_VERSIONS_TO_BUILD_TEMPLATES): + # buildifier: disable=function-docstring-args + """Initializes NCCL repository.""" + nccl_artifacts_dict = {"sha256_dict": {}, "url_dict": {}} + for cuda_version, nccl_wheel_info in cuda_nccl_wheels.items(): + for arch in OS_ARCH_DICT.values(): + if arch in nccl_wheel_info.keys(): + cuda_version_to_arch_key = "%s-%s" % (cuda_version, arch) + nccl_artifacts_dict["sha256_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch].get("sha256", "") + nccl_artifacts_dict["url_dict"][cuda_version_to_arch_key] = nccl_wheel_info[arch]["url"] + repo_data = redist_versions_to_build_templates["cuda_nccl"] + versions, templates = get_version_and_template_lists( + repo_data["version_to_template"], + ) + cuda_nccl_repo( + name = repo_data["repo_name"], + sha256_dict = nccl_artifacts_dict["sha256_dict"], + url_dict = nccl_artifacts_dict["url_dict"], + versions = versions, + build_templates = templates, + strip_prefix = "nvidia/nccl", + ) diff --git a/third_party/xla/third_party/tsl/third_party/nccl/nccl_configure.bzl b/third_party/xla/third_party/tsl/third_party/nccl/nccl_configure.bzl index 22cf64d4771062..59f8b5c08ef0ee 100644 --- a/third_party/xla/third_party/tsl/third_party/nccl/nccl_configure.bzl +++ b/third_party/xla/third_party/tsl/third_party/nccl/nccl_configure.bzl @@ -1,5 +1,7 @@ """Repository rule for NCCL configuration. +NB: DEPRECATED! Use `hermetic/nccl_configure` rule instead. + `nccl_configure` depends on the following environment variables: * `TF_NCCL_VERSION`: Installed NCCL version or empty to build from source. @@ -8,7 +10,6 @@ files. * `TF_CUDA_PATHS`: The base paths to look for CUDA and cuDNN. Default is `/usr/local/cuda,usr/`. - * `TF_CUDA_CLANG`: "1" if using Clang, "0" if using NVCC. * `TF_NCCL_USE_STUB`: "1" if a NCCL stub that loads NCCL dynamically should be used, "0" if NCCL should be linked in statically. @@ -33,7 +34,6 @@ _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" _TF_NCCL_VERSION = "TF_NCCL_VERSION" _TF_NEED_CUDA = "TF_NEED_CUDA" _TF_CUDA_PATHS = "TF_CUDA_PATHS" -_TF_CUDA_CLANG = "TF_CUDA_CLANG" _TF_NCCL_USE_STUB = "TF_NCCL_USE_STUB" _DEFINE_NCCL_MAJOR = "#define NCCL_MAJOR" @@ -129,7 +129,11 @@ def _create_local_nccl_repository(repository_ctx): _label("build_defs.bzl.tpl"), { "%{cuda_version}": "(%s, %s)" % tuple(cuda_version), - "%{cuda_clang}": repr(get_host_environ(repository_ctx, _TF_CUDA_CLANG)), + "%{nvlink_label}": "@local_config_cuda//cuda:cuda/bin/nvlink", + "%{fatbinary_label}": "@local_config_cuda//cuda:cuda/bin/fatbinary", + "%{bin2c_label}": "@local_config_cuda//cuda:cuda/bin/bin2c", + "%{link_stub_label}": "@local_config_cuda//cuda:cuda/bin/crt/link.stub", + "%{nvprune_label}": "@local_config_cuda//cuda:cuda/bin/nvprune", }, ) else: @@ -181,7 +185,6 @@ _ENVIRONS = [ _TF_CUDA_COMPUTE_CAPABILITIES, _TF_NEED_CUDA, _TF_CUDA_PATHS, - _TF_CUDA_CLANG, ] remote_nccl_configure = repository_rule( diff --git a/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl b/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl index f8fdd1033b5e2f..13aed2b687129f 100644 --- a/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl +++ b/third_party/xla/third_party/tsl/third_party/py/python_repo.bzl @@ -255,8 +255,12 @@ def _basic_wildcard_match(name, patterns, expected_match_result, match_all): def _custom_python_interpreter_impl(ctx): version = ctx.attr.version - strip_prefix = ctx.attr.strip_prefix.format(version = version) - urls = [url.format(version = version) for url in ctx.attr.urls] + version_variant = ctx.attr.version_variant + strip_prefix = ctx.attr.strip_prefix.format( + version = version, + version_variant = version_variant, + ) + urls = [url.format(version = version, version_variant = version_variant) for url in ctx.attr.urls] binary_name = ctx.attr.binary_name if not binary_name: ver_chunks = version.split(".") @@ -272,13 +276,12 @@ def _custom_python_interpreter_impl(ctx): output = srcs_dir, ) - configure_params = [] + configure_params = list(ctx.attr.configure_params) if "CC" in ctx.os.environ: configure_params.append("CC={}".format(ctx.os.environ["CC"])) if "CXX" in ctx.os.environ: configure_params.append("CXX={}".format(ctx.os.environ["CXX"])) - configure_params.append("--enable-optimizations") configure_params.append("--prefix=%s" % install_path.realpath) _exec_and_check( ctx, @@ -361,6 +364,11 @@ custom_python_interpreter = repository_rule( "strip_prefix": attr.string(), "binary_name": attr.string(mandatory = False), "version": attr.string(), + "version_variant": attr.string(), + "configure_params": attr.string_list( + mandatory = False, + default = ["--enable-optimizations"], + ), }, ) diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD new file mode 100644 index 00000000000000..8d626dc7635d1a --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/BUILD @@ -0,0 +1,7 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//visibility:public"], + licenses = ["notice"], +) + +# spirv_llvm_translator license placeholder diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD new file mode 100644 index 00000000000000..557e2e8f50edd2 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.BUILD @@ -0,0 +1,34 @@ +cc_library( + name = "spirv_llvm_translator", + srcs = glob([ + "lib/SPIRV/libSPIRV/*.cpp", + "lib/SPIRV/libSPIRV/*.hpp", + "lib/SPIRV/libSPIRV/*.h", + "lib/SPIRV/Mangler/*.cpp", + "lib/SPIRV/Mangler/*.h", + "lib/SPIRV/*.cpp", + "lib/SPIRV/*.hpp", + "lib/SPIRV/*.h", + ]), + hdrs = glob(["include/*"]), + includes = [ + "include/", + "lib/SPIRV/", + "lib/SPIRV/Mangler/", + "lib/SPIRV/libSPIRV/", + ], + visibility = ["//visibility:public"], + deps = [ + "@llvm-project//llvm:Analysis", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:CodeGen", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Demangle", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Linker", + "@llvm-project//llvm:Passes", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@spirv_headers//:spirv_cpp_headers", + ], +) diff --git a/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch new file mode 100644 index 00000000000000..fc843b1b039b09 --- /dev/null +++ b/third_party/xla/third_party/tsl/third_party/spirv_llvm_translator/spirv_llvm_translator.patch @@ -0,0 +1,25 @@ +diff --git a/lib/SPIRV/SPIRVInternal.h b/lib/SPIRV/SPIRVInternal.h +index a828add8..924e13b4 100644 + +Spir backend uses different addrspace representations link with nvptx backend link. +We reorder the enum value here so that we can make XLA LLVM codegen simple(avoiding +changing addrspace based on device backend everywhere) + +--- a/lib/SPIRV/SPIRVInternal.h ++++ b/lib/SPIRV/SPIRVInternal.h +@@ -179,11 +179,12 @@ typedef SPIRVMap IntBoolOpMap; + "-v512:512:512-v1024:1024:1024" + + enum SPIRAddressSpace { +- SPIRAS_Private, ++ SPIRAS_Generic, + SPIRAS_Global, +- SPIRAS_Constant, ++ SPIRAS_Internal, + SPIRAS_Local, +- SPIRAS_Generic, ++ SPIRAS_Constant, ++ SPIRAS_Private, + SPIRAS_GlobalDevice, + SPIRAS_GlobalHost, + SPIRAS_Input, \ No newline at end of file diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl index 0c28198f980b95..9a4dfa2aafdc51 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/configs.bzl @@ -225,8 +225,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -236,8 +236,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "9.1", + cuda_version = "12.3.2", + cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -248,8 +248,8 @@ def initialize_rbe_configs(): name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -258,8 +258,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -270,8 +270,8 @@ def initialize_rbe_configs(): name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -479,7 +479,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -558,7 +558,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -710,11 +710,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -749,11 +749,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -788,12 +788,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -826,12 +826,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -864,12 +864,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "9.1", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "9.1.1", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "10.0", }, ) diff --git a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl index 18a84d96c39f82..dbd7bad8d855c6 100644 --- a/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/third_party/tsl/tools/toolchains/remote_config/rbe_config.bzl @@ -1,8 +1,8 @@ """Macro that creates external repositories for remote config.""" -load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/gpus/cuda/hermetic:cuda_configure.bzl", "cuda_configure") +load("//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") @@ -42,7 +42,7 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_CUDNN_VERSION": cudnn_version, "TF_CUDA_VERSION": cuda_version, "CUDNN_INSTALL_PATH": cudnn_install_path if cudnn_install_path != None else "/usr/lib/x86_64-linux-gnu", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": tensorrt_version if tensorrt_version != None else "", "TENSORRT_INSTALL_PATH": tensorrt_install_path if tensorrt_install_path != None else "/usr/lib/x86_64-linux-gnu", "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", @@ -51,20 +51,26 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_SYSROOT": sysroot if sysroot else "", }) - container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os) + cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) + cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) + container_name = "cuda%s-cudnn%s-%s" % ( + cuda_version_in_container, + cudnn_version_in_container, + os, + ) container_image = _container_image_uri(container_name) exec_properties = { "container-image": container_image, "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, @@ -175,13 +181,13 @@ def sigbuild_tf_configs(name_container_map, env): "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, diff --git a/third_party/xla/third_party/tsl/tsl/lib/core/BUILD b/third_party/xla/third_party/tsl/tsl/lib/core/BUILD index 0fee1ed113152a..a5b0791d5d28ce 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/core/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/core/BUILD @@ -36,7 +36,7 @@ filegroup( srcs = [ "bitmap.h", "bits.h", - "status_test_util.h", + "@local_xla//xla/tsl/lib/core:legacy_lib_core_status_test_util_header", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -67,7 +67,7 @@ filegroup( filegroup( name = "legacy_lib_core_status_test_util_header", srcs = [ - "status_test_util.h", + "@local_xla//xla/tsl/lib/core:legacy_lib_core_status_test_util_header", ], compatible_with = get_compatible_with_portable(), visibility = internal_visibility([ @@ -88,16 +88,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "status_test_util", - testonly = 1, - hdrs = ["status_test_util.h"], - deps = [ - "//tsl/platform:status_matchers", - "//tsl/platform:test", - ], -) - cc_library( name = "bits", hdrs = ["bits.h"], diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h index f04fd0c87e8986..8d5cf7912e9d78 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h +++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatmap.h @@ -80,7 +80,7 @@ class FlatMap { // Move constructor leaves src in a valid but unspecified state (same as // std::unordered_map). - FlatMap(FlatMap&& src) : rep_(std::move(src.rep_)) {} + FlatMap(FlatMap&& src) noexcept : rep_(std::move(src.rep_)) {} template FlatMap(InputIter first, InputIter last, size_t N = 1, @@ -100,14 +100,14 @@ class FlatMap { // Move-assignment operator leaves src in a valid but unspecified state (same // as std::unordered_map). - FlatMap& operator=(FlatMap&& src) { + FlatMap& operator=(FlatMap&& src) noexcept { rep_.MoveFrom(std::move(src.rep_)); return *this; } ~FlatMap() {} - void swap(FlatMap& x) { rep_.swap(x.rep_); } + void swap(FlatMap& x) noexcept { rep_.swap(x.rep_); } void clear_no_resize() { rep_.clear_no_resize(); } void clear() { rep_.clear(); } void reserve(size_t N) { rep_.Resize(std::max(N, size())); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h index dfc65844e68ed3..d6c77e7de363ea 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h +++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatrep.h @@ -58,10 +58,11 @@ class FlatRep { CopyEntries(src.array_, src.end_, CopyEntry()); } - FlatRep(FlatRep&& src) - // Copy rather than move src.hash_ and src.equal_. This is necessary to - // leave src in a valid state -- otherwise e.g. if hash_ is an - // std::function, moving it would null it out. + FlatRep( + FlatRep&& src) noexcept // Copy rather than move src.hash_ and + // src.equal_. This is necessary to leave src in + // a valid state -- otherwise e.g. if hash_ is an + // std::function, moving it would null it out. : hash_(src.hash_), equal_(src.equal_) { // TODO(jlebar): Init(1) still allocates some memory, so this isn't as cheap // as it could be. The fundamental problem is that we need to leave src in @@ -118,7 +119,7 @@ class FlatRep { MaybeResize(); } - void swap(FlatRep& x) { + void swap(FlatRep& x) noexcept { using std::swap; swap(array_, x.array_); swap(end_, x.end_); diff --git a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h index ec8e9ad4be3ee2..b3178225647fe1 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h +++ b/third_party/xla/third_party/tsl/tsl/lib/gtl/flatset.h @@ -63,7 +63,7 @@ class FlatSet { // Move constructor leaves src in a valid but unspecified state (same as // std::unordered_set). - FlatSet(FlatSet&& src) : rep_(std::move(src.rep_)) {} + FlatSet(FlatSet&& src) noexcept : rep_(std::move(src.rep_)) {} template FlatSet(InputIter first, InputIter last, size_t N = 1, @@ -83,14 +83,14 @@ class FlatSet { // Move-assignment operator leaves src in a valid but unspecified state (same // as std::unordered_set). - FlatSet& operator=(FlatSet&& src) { + FlatSet& operator=(FlatSet&& src) noexcept { rep_.MoveFrom(std::move(src.rep_)); return *this; } ~FlatSet() {} - void swap(FlatSet& x) { rep_.swap(x.rep_); } + void swap(FlatSet& x) noexcept { rep_.swap(x.rep_); } void clear_no_resize() { rep_.clear_no_resize(); } void clear() { rep_.clear(); } void reserve(size_t N) { rep_.Resize(std::max(N, size())); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD b/third_party/xla/third_party/tsl/tsl/lib/io/BUILD index c103dcfdc5a417..055931faddf1fd 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/io/BUILD @@ -263,12 +263,12 @@ tsl_cc_test( srcs = ["buffered_file_test.cc"], deps = [ ":buffered_file", - "//tsl/lib/core:status_test_util", "//tsl/platform:env", "//tsl/platform:env_impl", "//tsl/platform:test", "//tsl/platform:test_benchmark", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -443,12 +443,12 @@ tsl_cc_test( deps = [ ":buffered_inputstream", ":random_inputstream", - "//tsl/lib/core:status_test_util", "//tsl/platform:env", "//tsl/platform:env_impl", "//tsl/platform:test", "//tsl/platform:test_benchmark", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -471,7 +471,6 @@ tsl_cc_test( srcs = ["inputbuffer_test.cc"], deps = [ ":inputbuffer", - "//tsl/lib/core:status_test_util", "//tsl/platform:coding", "//tsl/platform:env", "//tsl/platform:env_impl", @@ -482,6 +481,7 @@ tsl_cc_test( "//tsl/platform:strcat", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -491,10 +491,10 @@ tsl_cc_test( srcs = ["inputstream_interface_test.cc"], deps = [ ":inputstream_interface", - "//tsl/lib/core:status_test_util", "//tsl/platform:errors", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -504,11 +504,11 @@ tsl_cc_test( srcs = ["random_inputstream_test.cc"], deps = [ ":random_inputstream", - "//tsl/lib/core:status_test_util", "//tsl/platform:env", "//tsl/platform:env_impl", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -519,7 +519,6 @@ tsl_cc_test( deps = [ ":record_reader", ":record_writer", - "//tsl/lib/core:status_test_util", "//tsl/platform:env", "//tsl/platform:env_impl", "//tsl/platform:errors", @@ -528,6 +527,7 @@ tsl_cc_test( "//tsl/platform:strcat", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", "@zlib", ], ) @@ -539,7 +539,6 @@ tsl_cc_test( deps = [ ":record_reader", ":record_writer", - "//tsl/lib/core:status_test_util", "//tsl/lib/hash:crc32c", "//tsl/lib/random:philox", "//tsl/platform:coding", @@ -549,6 +548,7 @@ tsl_cc_test( "//tsl/platform:str_util", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -580,12 +580,12 @@ tsl_cc_test( ":zlib_compression_options", ":zlib_inputstream", ":zlib_outputbuffer", - "//tsl/lib/core:status_test_util", "//tsl/platform:env", "//tsl/platform:env_impl", "//tsl/platform:errors", "//tsl/platform:strcat", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc index 6fae0b66b8665d..f9fa67dd1572f5 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_file_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc index ab1f58e0b14a83..83e5776d6602d2 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/buffered_inputstream_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/lib/io/buffered_inputstream.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/lib/io/random_inputstream.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc index e38460405befb2..d23f06a260fb04 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputbuffer_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/coding.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc index 23d4fb0ddf50bc..c9c34dba55364e 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/inputstream_interface_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/lib/io/inputstream_interface.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc index 0b47ef2e9075e7..dfa4ec80e20a17 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/random_inputstream_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/lib/io/random_inputstream.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc index 67df783112f9ee..45934c9f355576 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/record_reader_writer_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc index 42adf76f7ef0d3..51c2be62301455 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/recordio_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/lib/hash/crc32c.h" #include "tsl/lib/io/record_reader.h" #include "tsl/lib/io/record_writer.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD index 3f42c5fa8b03ae..0adc5e2fa467aa 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/BUILD @@ -90,12 +90,12 @@ tsl_cc_test( ":snappy_inputbuffer", ":snappy_inputstream", ":snappy_outputbuffer", - "//tsl/lib/core:status_test_util", "//tsl/lib/io:inputbuffer", "//tsl/lib/io:random_inputstream", "//tsl/platform:env", "//tsl/platform:env_impl", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc index d04d8d184549b3..7844b8993fd98d 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_inputbuffer.cc @@ -170,7 +170,7 @@ absl::Status SnappyInputBuffer::ReadFromFile() { bytes_to_read -= avail_in_; read_location += avail_in_; } - StringPiece data; + absl::string_view data; // Try to read enough data to fill up input_buffer_. absl::Status s = file_->Read(file_pos_, bytes_to_read, &data, read_location); if (data.data() != read_location) { diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc index 6d19c60839995e..e851f58f1b9cda 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.cc @@ -40,7 +40,7 @@ SnappyOutputBuffer::~SnappyOutputBuffer() { } } -absl::Status SnappyOutputBuffer::Append(StringPiece data) { +absl::Status SnappyOutputBuffer::Append(absl::string_view data) { return Write(data); } @@ -58,7 +58,7 @@ absl::Status SnappyOutputBuffer::Close() { return Flush(); } -absl::Status SnappyOutputBuffer::Name(StringPiece* result) const { +absl::Status SnappyOutputBuffer::Name(absl::string_view* result) const { return file_->Name(result); } @@ -71,7 +71,7 @@ absl::Status SnappyOutputBuffer::Tell(int64_t* position) { return file_->Tell(position); } -absl::Status SnappyOutputBuffer::Write(StringPiece data) { +absl::Status SnappyOutputBuffer::Write(absl::string_view data) { // // The deflated output is accumulated in output_buffer_ and gets written to // file as and when needed. @@ -121,7 +121,7 @@ int32 SnappyOutputBuffer::AvailableInputSpace() const { return input_buffer_capacity_ - avail_in_; } -void SnappyOutputBuffer::AddToInputBuffer(StringPiece data) { +void SnappyOutputBuffer::AddToInputBuffer(absl::string_view data) { size_t bytes_to_write = data.size(); DCHECK_LE(bytes_to_write, AvailableInputSpace()); @@ -182,7 +182,7 @@ absl::Status SnappyOutputBuffer::DeflateBuffered() { absl::Status SnappyOutputBuffer::FlushOutputBufferToFile() { size_t bytes_to_write = output_buffer_capacity_ - avail_out_; if (bytes_to_write > 0) { - absl::Status s = file_->Append(StringPiece( + absl::Status s = file_->Append(absl::string_view( reinterpret_cast(output_buffer_.get()), bytes_to_write)); if (s.ok()) { next_out_ = output_buffer_.get(); diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h index 4c4d664d014a07..a3bd44748c152f 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h +++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_outputbuffer.h @@ -64,7 +64,7 @@ class SnappyOutputBuffer : public WritableFile { // // The input data is buffered internally and will be written to disk at a // later time. To immediately write contents to file call `Flush()`. - absl::Status Append(StringPiece data) override; + absl::Status Append(absl::string_view data) override; #if defined(TF_CORD_SUPPORT) absl::Status Append(const absl::Cord& cord) override; @@ -81,7 +81,7 @@ class SnappyOutputBuffer : public WritableFile { absl::Status Close() override; // Returns the name of the underlying file. - absl::Status Name(StringPiece* result) const override; + absl::Status Name(absl::string_view* result) const override; // Deflates any cached input, writes all output to file and syncs it. absl::Status Sync() override; @@ -98,7 +98,7 @@ class SnappyOutputBuffer : public WritableFile { // to file when the buffer is full. // // To immediately write contents to file call `Flush()`. - absl::Status Write(StringPiece data); + absl::Status Write(absl::string_view data); // Compresses any cached input and writes all output to file. This must be // called before the destructor to avoid any data loss. @@ -107,7 +107,7 @@ class SnappyOutputBuffer : public WritableFile { private: // Appends `data` to `input_buffer_`. // Throws if `data.size()` > AvailableInputSpace(). - void AddToInputBuffer(StringPiece data); + void AddToInputBuffer(absl::string_view data); // Appends `data` to `output_buffer_`. Flushes buffer contents to file when // buffer gets full. diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc index 33f42bd1b90683..78eecf360d9489 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/snappy/snappy_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/lib/io/inputbuffer.h" #include "tsl/lib/io/random_inputstream.h" #include "tsl/lib/io/snappy/snappy_inputbuffer.h" @@ -77,7 +77,7 @@ absl::Status TestMultipleWritesWriteFile(size_t compress_input_buf_size, compress_output_buf_size); for (int i = 0; i < num_writes; i++) { - TF_RETURN_IF_ERROR(out.Write(StringPiece(data))); + TF_RETURN_IF_ERROR(out.Write(absl::string_view(data))); if (with_flush) { TF_RETURN_IF_ERROR(out.Flush()); } @@ -96,7 +96,7 @@ absl::Status TestMultipleWritesWriteFile(size_t compress_input_buf_size, std::unique_ptr file_reader; TF_RETURN_IF_ERROR(env->NewRandomAccessFile(fname, &file_reader)); - StringPiece data; + absl::string_view data; size_t file_pos = 0; size_t bytes_to_read = 256; char* scratch = new char[bytes_to_read]; @@ -106,14 +106,14 @@ absl::Status TestMultipleWritesWriteFile(size_t compress_input_buf_size, while ((file_reader->Read(file_pos, bytes_to_read, &data, scratch)).ok()) { file_pos += data.size(); TF_CHECK_OK( - corrupt_file_writer->Append(StringPiece(buffer, buffer_size))); + corrupt_file_writer->Append(absl::string_view(buffer, buffer_size))); memcpy(buffer, data.data(), data.size()); buffer_size = data.size(); } // Drop the last byte. File is now corrupt. - TF_CHECK_OK( - corrupt_file_writer->Append(StringPiece(buffer, buffer_size - 1))); + TF_CHECK_OK(corrupt_file_writer->Append( + absl::string_view(buffer, buffer_size - 1))); TF_CHECK_OK(corrupt_file_writer->Flush()); TF_CHECK_OK(corrupt_file_writer->Close()); delete[] scratch; @@ -216,7 +216,7 @@ void TestTellWriteFile(size_t compress_input_buf_size, TF_CHECK_OK(env->NewWritableFile(fname, &file_writer)); io::SnappyOutputBuffer out(file_writer.get(), compress_input_buf_size, compress_output_buf_size); - TF_CHECK_OK(out.Write(StringPiece(data))); + TF_CHECK_OK(out.Write(absl::string_view(data))); TF_CHECK_OK(out.Flush()); TF_CHECK_OK(file_writer->Flush()); TF_CHECK_OK(file_writer->Close()); @@ -296,7 +296,7 @@ void TestTellInputStream(size_t compress_input_buf_size, static bool SnappyCompressionSupported() { string out; - StringPiece in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; + absl::string_view in = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"; return port::Snappy_Compress(in.data(), in.size(), &out); } diff --git a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc b/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc index 0aa65e8e747a89..c2ff61d347df40 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc +++ b/third_party/xla/third_party/tsl/tsl/lib/io/zlib_buffers_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/lib/io/random_inputstream.h" #include "tsl/lib/io/zlib_compression_options.h" #include "tsl/lib/io/zlib_inputstream.h" diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD b/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD index c15c9293dbdbf6..302c0c412ef11b 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/BUILD @@ -71,7 +71,6 @@ cc_library( deps = [ ":collection_registry", ":metric_def", - "//tsl/lib/histogram", "//tsl/platform", "//tsl/platform:macros", "//tsl/platform:mutex", @@ -81,6 +80,7 @@ cc_library( "//tsl/protobuf:histogram_proto_cc", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/lib/histogram", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/monitoring/sampler.h b/third_party/xla/third_party/tsl/tsl/lib/monitoring/sampler.h index 63f583bd4f1b44..e17f3cdae00d91 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/monitoring/sampler.h +++ b/third_party/xla/third_party/tsl/tsl/lib/monitoring/sampler.h @@ -122,7 +122,7 @@ class Sampler { #include #include -#include "tsl/lib/histogram/histogram.h" +#include "xla/tsl/lib/histogram/histogram.h" #include "tsl/lib/monitoring/collection_registry.h" #include "tsl/lib/monitoring/metric_def.h" #include "tsl/platform/macros.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/BUILD b/third_party/xla/third_party/tsl/tsl/platform/BUILD index f2a3d51ba6d593..bc788812475163 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/BUILD @@ -1465,7 +1465,7 @@ tsl_cc_test( ":subprocess", ":test", ":test_main", - "//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -1768,7 +1768,7 @@ tsl_cc_test( ":str_util", ":test", ":test_main", - "//tsl/lib/core:status_test_util", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -1784,7 +1784,7 @@ tsl_cc_test( ":str_util", ":test", ":test_main", - "//tsl/lib/core:status_test_util", "@com_google_absl//absl/time", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD index bff2db41d43626..e92fd047cc0018 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/BUILD @@ -231,7 +231,6 @@ cc_library( copts = tsl_copts(), deps = [ ":curl_http_request", - "//tsl/lib/core:status_test_util", "//tsl/platform:errors", "//tsl/platform:macros", "//tsl/platform:protobuf", @@ -240,6 +239,7 @@ cc_library( "//tsl/platform:test", "//tsl/platform:types", "@curl", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -359,10 +359,10 @@ tsl_cc_test( deps = [ ":expiring_lru_cache", ":now_seconds_env", - "//tsl/lib/core:status_test_util", "//tsl/platform:env_impl", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -373,13 +373,13 @@ tsl_cc_test( deps = [ ":now_seconds_env", ":ram_file_block_cache", - "//tsl/lib/core:status_test_util", "//tsl/platform:blocking_counter", "//tsl/platform:env", "//tsl/platform:env_impl", "//tsl/platform:notification", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -390,7 +390,6 @@ tsl_cc_test( deps = [ ":gcs_file_system", ":http_request_fake", - "//tsl/lib/core:status_test_util", "//tsl/platform:env_impl", "//tsl/platform:errors", "//tsl/platform:str_util", @@ -399,6 +398,7 @@ tsl_cc_test( "//tsl/platform:test_main", "//tsl/profiler/backends/cpu:traceme_recorder_impl", "//tsl/profiler/utils:time_utils_impl", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -423,11 +423,11 @@ tsl_cc_test( linkopts = if_windows(["-DEFAULTLIB:ws2_32.lib"]), deps = [ ":gcs_throttle", - "//tsl/lib/core:status_test_util", "//tsl/platform:env_impl", "//tsl/platform:str_util", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -437,13 +437,13 @@ tsl_cc_test( srcs = ["curl_http_request_test.cc"], deps = [ ":curl_http_request", - "//tsl/lib/core:status_test_util", "//tsl/platform:env_impl", "//tsl/platform:path", "//tsl/platform:platform_port", "//tsl/platform:test", "//tsl/platform:test_main", "@com_google_absl//absl/status", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -459,7 +459,6 @@ tsl_cc_test( deps = [ ":http_request_fake", ":oauth_client", - "//tsl/lib/core:status_test_util", "//tsl/platform:base64", "//tsl/platform:env", "//tsl/platform:env_impl", @@ -468,6 +467,7 @@ tsl_cc_test( "//tsl/platform:test", "//tsl/platform:test_main", "@boringssl//:crypto", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -484,11 +484,11 @@ tsl_cc_test( ":google_auth_provider", ":http_request_fake", ":oauth_client", - "//tsl/lib/core:status_test_util", "//tsl/platform:env_impl", "//tsl/platform:path", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) @@ -525,8 +525,8 @@ tsl_cc_test( srcs = ["time_util_test.cc"], deps = [ ":time_util", - "//tsl/lib/core:status_test_util", "//tsl/platform:test", "//tsl/platform:test_main", + "@local_xla//xla/tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h index 969b8bc0cea23e..4b1b2927581b72 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/auth_provider.h @@ -31,9 +31,9 @@ class AuthProvider { /// \brief Returns the short-term authentication bearer token. /// /// Safe for concurrent use by multiple threads. - virtual Status GetToken(string* t) = 0; + virtual absl::Status GetToken(string* t) = 0; - static Status GetToken(AuthProvider* provider, string* token) { + static absl::Status GetToken(AuthProvider* provider, string* token) { if (!provider) { return errors::Internal("Auth provider is required."); } @@ -44,9 +44,9 @@ class AuthProvider { /// No-op auth provider, which will only work for public objects. class EmptyAuthProvider : public AuthProvider { public: - Status GetToken(string* token) override { + absl::Status GetToken(string* token) override { *token = ""; - return OkStatus(); + return absl::OkStatus(); } }; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc index 7be3af7019050f..7a41c8f37b7536 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.cc @@ -40,7 +40,7 @@ ComputeEngineMetadataClient::ComputeEngineMetadataClient( : http_request_factory_(std::move(http_request_factory)), retry_config_(config) {} -Status ComputeEngineMetadataClient::GetMetadata( +absl::Status ComputeEngineMetadataClient::GetMetadata( const string& path, std::vector* response_buffer) { const auto get_metadata_from_gce = [path, response_buffer, this]() { string metadata_url; @@ -56,7 +56,7 @@ Status ComputeEngineMetadataClient::GetMetadata( request->AddHeader("Metadata-Flavor", "Google"); request->SetResultBuffer(response_buffer); TF_RETURN_IF_ERROR(request->Send()); - return OkStatus(); + return absl::OkStatus(); }; return RetryingUtils::CallWithRetries(get_metadata_from_gce, retry_config_); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h index fac94cdb1c9c8c..1337d33c8e1895 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_metadata_client.h @@ -51,8 +51,8 @@ class ComputeEngineMetadataClient { /// To get the zone of an instance: /// compute_engine_metadata_client.GetMetadata( /// "instance/zone", response_buffer); - virtual Status GetMetadata(const string& path, - std::vector* response_buffer); + virtual absl::Status GetMetadata(const string& path, + std::vector* response_buffer); private: std::shared_ptr http_request_factory_; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc index 7720784db75eac..19f27556d44991 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.cc @@ -28,15 +28,15 @@ ComputeEngineZoneProvider::ComputeEngineZoneProvider( std::shared_ptr google_metadata_client) : google_metadata_client_(std::move(google_metadata_client)) {} -Status ComputeEngineZoneProvider::GetZone(string* zone) { +absl::Status ComputeEngineZoneProvider::GetZone(string* zone) { if (!cached_zone.empty()) { *zone = cached_zone; - return OkStatus(); + return absl::OkStatus(); } std::vector response_buffer; TF_RETURN_IF_ERROR(google_metadata_client_->GetMetadata(kGceMetadataZonePath, &response_buffer)); - StringPiece location(&response_buffer[0], response_buffer.size()); + absl::string_view location(&response_buffer[0], response_buffer.size()); std::vector elems = str_util::Split(location, "/"); if (elems.size() == 4) { @@ -47,7 +47,7 @@ Status ComputeEngineZoneProvider::GetZone(string* zone) { << string(location); } - return OkStatus(); + return absl::OkStatus(); } ComputeEngineZoneProvider::~ComputeEngineZoneProvider() {} diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h index a37b43c22a484f..99ed41fca7d881 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/compute_engine_zone_provider.h @@ -27,7 +27,7 @@ class ComputeEngineZoneProvider : public ZoneProvider { std::shared_ptr google_metadata_client); virtual ~ComputeEngineZoneProvider(); - Status GetZone(string* zone) override; + absl::Status GetZone(string* zone) override; private: std::shared_ptr google_metadata_client_; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc index c41f967c04b055..44eeab7f511fb9 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.cc @@ -230,8 +230,8 @@ void CurlHttpRequest::SetDeleteRequest() { libcurl_->curl_easy_setopt(curl_, CURLOPT_CUSTOMREQUEST, "DELETE")); } -Status CurlHttpRequest::SetPutFromFile(const string& body_filepath, - size_t offset) { +absl::Status CurlHttpRequest::SetPutFromFile(const string& body_filepath, + size_t offset) { CheckNotSent(); CheckMethodNotSet(); is_method_set_ = true; @@ -257,7 +257,7 @@ Status CurlHttpRequest::SetPutFromFile(const string& body_filepath, reinterpret_cast(put_body_))); // Using the default CURLOPT_READFUNCTION, which is doing an fread() on the // FILE * userdata set with CURLOPT_READDATA. - return OkStatus(); + return absl::OkStatus(); } void CurlHttpRequest::SetPutEmptyBody() { @@ -286,7 +286,7 @@ void CurlHttpRequest::SetPostFromBuffer(const char* buffer, size_t size) { reinterpret_cast(this))); CHECK_CURL_OK(libcurl_->curl_easy_setopt(curl_, CURLOPT_READFUNCTION, &CurlHttpRequest::ReadCallback)); - post_body_buffer_ = StringPiece(buffer, size); + post_body_buffer_ = absl::string_view(buffer, size); } void CurlHttpRequest::SetPostEmptyBody() { @@ -397,8 +397,8 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size, size_t nmemb, void* this_object) { CHECK(ptr); auto that = reinterpret_cast(this_object); - StringPiece header(reinterpret_cast(ptr), size * nmemb); - StringPiece name, value; + absl::string_view header(reinterpret_cast(ptr), size * nmemb); + absl::string_view name, value; // The supplied header has the form ": ", parse it. if (strings::Scanner(header) .ScanEscapedUntil(':') @@ -412,7 +412,7 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size, return size * nmemb; } -Status CurlHttpRequest::Send() { +absl::Status CurlHttpRequest::Send() { CheckNotSent(); CHECK(is_uri_set_) << "URI has not been set."; @@ -457,7 +457,7 @@ Status CurlHttpRequest::Send() { auto get_error_message = [this]() -> string { string error_message = strings::StrCat( "Error executing an HTTP request: HTTP response code ", response_code_); - StringPiece body = GetResponse(); + absl::string_view body = GetResponse(); if (!body.empty()) { return strings::StrCat( error_message, " with body '", @@ -466,7 +466,7 @@ Status CurlHttpRequest::Send() { return error_message; }; - Status result; + absl::Status result; switch (response_code_) { // The group of response codes indicating that the request achieved // the expected goal. @@ -474,7 +474,7 @@ Status CurlHttpRequest::Send() { case 201: // Created case 204: // No Content case 206: // Partial Content - result = OkStatus(); + result = absl::OkStatus(); break; case 416: // Requested Range Not Satisfiable @@ -485,7 +485,7 @@ Status CurlHttpRequest::Send() { if (IsDirectResponse()) { direct_response_.bytes_transferred_ = 0; } - result = OkStatus(); + result = absl::OkStatus(); break; // INVALID_ARGUMENT indicates a problem with how the request is constructed. @@ -556,13 +556,14 @@ void CurlHttpRequest::CheckNotSent() const { CHECK(!is_sent_) << "The request has already been sent."; } -StringPiece CurlHttpRequest::GetResponse() const { - StringPiece response; +absl::string_view CurlHttpRequest::GetResponse() const { + absl::string_view response; if (IsDirectResponse()) { - response = StringPiece(direct_response_.buffer_, - direct_response_.bytes_transferred_); + response = absl::string_view(direct_response_.buffer_, + direct_response_.bytes_transferred_); } else { - response = StringPiece(response_buffer_->data(), response_buffer_->size()); + response = + absl::string_view(response_buffer_->data(), response_buffer_->size()); } return response; } @@ -627,10 +628,10 @@ int CurlHttpRequest::ProgressCallback(void* this_object, curl_off_t dltotal, return 0; } -Status CurlHttpRequest::CURLcodeToStatus(CURLcode code, - const char* error_buffer) { +absl::Status CurlHttpRequest::CURLcodeToStatus(CURLcode code, + const char* error_buffer) { if (code == CURLE_OK) { - return OkStatus(); + return absl::OkStatus(); } string error_message = strings::StrCat( "Error executing an HTTP request: libcurl code ", code, " meaning '", @@ -648,7 +649,7 @@ Status CurlHttpRequest::CURLcodeToStatus(CURLcode code, // a response body (e.g. GCS sends one with an error message) but we // pretend as though they don't, so actually ignore this error. if (get_response_result == CURLE_OK && response_code == 416) { - return OkStatus(); + return absl::OkStatus(); } return errors::FailedPrecondition( strings::StrCat(error_message, overflow_message)); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h index b5c728520dc693..4c64758a0bda8b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request.h @@ -86,7 +86,8 @@ class CurlHttpRequest : public HttpRequest { /// /// The request body will be taken from the specified file starting from /// the given offset. - Status SetPutFromFile(const string& body_filepath, size_t offset) override; + absl::Status SetPutFromFile(const string& body_filepath, + size_t offset) override; /// Makes the request a PUT request with an empty body. void SetPutEmptyBody() override; @@ -140,7 +141,7 @@ class CurlHttpRequest : public HttpRequest { /// /// If the result buffer was defined, the response will be written there. /// The object is not designed to be re-used after Send() is executed. - Status Send() override; + absl::Status Send() override; // Url encodes str and returns a new string. string EscapeString(const string& str) override; @@ -167,18 +168,18 @@ class CurlHttpRequest : public HttpRequest { curl_off_t ulnow); void CheckMethodNotSet() const; void CheckNotSent() const; - StringPiece GetResponse() const; + absl::string_view GetResponse() const; /// Helper to convert the given CURLcode and error buffer, representing the /// result of performing a transfer, into a Status with an error message. - Status CURLcodeToStatus(CURLcode code, const char* error_buffer); + absl::Status CURLcodeToStatus(CURLcode code, const char* error_buffer); LibCurl* libcurl_; Env* env_; FILE* put_body_ = nullptr; - StringPiece post_body_buffer_; + absl::string_view post_body_buffer_; size_t post_body_read_ = 0; std::vector* response_buffer_ = nullptr; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc index 36d710804607d9..31cde679f4978b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/curl_http_request_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "absl/status/status.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/mem.h" #include "tsl/platform/path.h" #include "tsl/platform/test.h" @@ -151,8 +151,8 @@ class FakeLibCurl : public LibCurl { posted_content_ = ""; do { bytes_read = read_callback_(buffer, 1, sizeof(buffer), read_data_); - posted_content_ = - strings::StrCat(posted_content_, StringPiece(buffer, bytes_read)); + posted_content_ = strings::StrCat( + posted_content_, absl::string_view(buffer, bytes_read)); } while (bytes_read > 0); } if (write_data_ || write_callback_) { @@ -366,7 +366,7 @@ TEST(CurlHttpRequestTest, GetRequest_Direct_ResponseTooLarge) { http_request.SetUri("http://www.testuri.com"); http_request.SetResultBufferDirect(scratch.data(), scratch.size()); - const Status& status = http_request.Send(); + const absl::Status& status = http_request.Send(); EXPECT_EQ(error::FAILED_PRECONDITION, status.code()); EXPECT_EQ( "Error executing an HTTP request: libcurl code 23 meaning " @@ -770,7 +770,7 @@ class TestStats : public HttpRequest::RequestStats { void RecordResponse(const HttpRequest* request, const string& uri, HttpRequest::RequestMethod method, - const Status& result) override { + const absl::Status& result) override { has_recorded_response_ = true; record_response_request_ = request; record_response_uri_ = uri; @@ -787,7 +787,7 @@ class TestStats : public HttpRequest::RequestStats { string record_response_uri_ = "http://www.testuri.com"; HttpRequest::RequestMethod record_response_method_ = HttpRequest::RequestMethod::kGet; - Status record_response_result_; + absl::Status record_response_result_; bool has_recorded_request_ = false; bool has_recorded_response_ = false; @@ -864,7 +864,7 @@ TEST(CurlHttpRequestTest, StatsGetNotFound) { http_request.AddAuthBearerHeader("fake-bearer"); http_request.SetRange(100, 199); http_request.SetResultBuffer(&scratch); - Status s = http_request.Send(); + absl::Status s = http_request.Send(); // Check interaction with stats. ASSERT_TRUE(stats.has_recorded_request_); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h index 1def81b6d0d562..d3a15bc9fa3781 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache.h @@ -71,13 +71,13 @@ class ExpiringLRUCache { return LookupLocked(key, value); } - typedef std::function ComputeFunc; + typedef std::function ComputeFunc; /// Look up the entry with key `key` and copy it to `value` if found. If not /// found, call `compute_func`. If `compute_func` returns successfully, store /// a copy of the output parameter in the cache, and another copy in `value`. - Status LookupOrCompute(const string& key, T* value, - const ComputeFunc& compute_func) { + absl::Status LookupOrCompute(const string& key, T* value, + const ComputeFunc& compute_func) { if (max_age_ == 0) { return compute_func(key, value); } @@ -88,9 +88,9 @@ class ExpiringLRUCache { // key if this proves to be a significant performance bottleneck. mutex_lock lock(mu_); if (LookupLocked(key, value)) { - return OkStatus(); + return absl::OkStatus(); } - Status s = compute_func(key, value); + absl::Status s = compute_func(key, value); if (s.ok()) { InsertLocked(key, *value); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc index ce3e0fc29d1684..7225dca6e50ce1 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/expiring_lru_cache_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/cloud/now_seconds_env.h" #include "tsl/platform/test.h" @@ -97,7 +97,7 @@ TEST(ExpiringLRUCacheTest, LookupOrCompute) { [&num_compute_calls](const string& key, int* value) { *value = num_compute_calls; num_compute_calls++; - return OkStatus(); + return absl::OkStatus(); }; ExpiringLRUCache cache1(0, 4); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h index e336a42835c9f9..59927545f7eb76 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/file_block_cache.h @@ -70,9 +70,9 @@ class FileBlockCache { /// cache is constructed. The returned Status should be OK as long as the /// read from the remote filesystem succeeded (similar to the semantics of the /// read(2) system call). - typedef std::function + typedef std::function BlockFetcher; virtual ~FileBlockCache() {} @@ -91,8 +91,8 @@ class FileBlockCache { /// placed in `out`. /// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed /// in `out`). - virtual Status Read(const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred) = 0; + virtual absl::Status Read(const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred) = 0; // Validate the given file signature with the existing file signature in the // cache. Returns true if the signature doesn't change or the file did not diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc index 4819b49fc31338..594703f31ffef2 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache.cc @@ -41,7 +41,8 @@ namespace { const std::vector& kCachedDomainNames = *new std::vector{"www.googleapis.com", "storage.googleapis.com"}; -inline void print_getaddrinfo_error(const string& name, Status return_status) { +inline void print_getaddrinfo_error(const string& name, + absl::Status return_status) { // Status doesn't map well to EAI type errors. LOG(ERROR) << "Error resolving " << name << ": " << return_status; } @@ -104,13 +105,13 @@ void GcsDnsCache::AnnotateRequest(HttpRequest* request) { /* max_delay_time_us = */ 50 * 1000 * 5000, /* max_retries = */ 5); - const Status getaddrinfo_status = RetryingUtils::CallWithRetries( + const absl::Status getaddrinfo_status = RetryingUtils::CallWithRetries( [&name, &hints, &result]() { int return_code = getaddrinfo(name.c_str(), nullptr, &hints, &result); absl::Status return_status; switch (return_code) { case 0: - return_status = OkStatus(); + return_status = absl::OkStatus(); break; #ifndef _WIN32 case EAI_ADDRFAMILY: @@ -175,7 +176,7 @@ void GcsDnsCache::AnnotateRequest(HttpRequest* request) { #endif } - return Status(return_status); + return absl::Status(return_status); }, retryConfig); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc index 069dcb546f0ebb..a5ce0882ba5688 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_dns_cache_test.cc @@ -40,8 +40,9 @@ class TestHttpRequest : public HttpRequest { void SetRequestStats(HttpRequest::RequestStats* stats) override {} void SetDeleteRequest() override {} - Status SetPutFromFile(const string& body_filepath, size_t offset) override { - return OkStatus(); + absl::Status SetPutFromFile(const string& body_filepath, + size_t offset) override { + return absl::OkStatus(); } void SetPutEmptyBody() override {} void SetPostFromBuffer(const char* buffer, size_t size) override {} @@ -52,7 +53,7 @@ class TestHttpRequest : public HttpRequest { string GetResponseHeader(const string& name) const override { return ""; } uint64 GetResponseCode() const override { return 0; } - Status Send() override { return OkStatus(); } + absl::Status Send() override { return absl::OkStatus(); } string EscapeString(const string& str) override { return ""; } void SetTimeouts(uint32 connection, uint32 inactivity, diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc index ea65028a96cd22..c1cc244e0dc173 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.cc @@ -163,9 +163,9 @@ constexpr char kAppendMode[] = "GCS_APPEND_MODE"; // objects. constexpr char kComposeAppend[] = "compose"; -Status GetTmpFilename(string* filename) { +absl::Status GetTmpFilename(string* filename) { *filename = io::GetTempFilename(""); - return OkStatus(); + return absl::OkStatus(); } /// Appends a trailing slash if the name doesn't already have one. @@ -199,7 +199,7 @@ std::set AddAllSubpaths(const std::vector& paths) { std::set result; result.insert(paths.begin(), paths.end()); for (const string& path : paths) { - StringPiece subpath = io::Dirname(path); + absl::string_view subpath = io::Dirname(path); // If `path` starts with `/`, `subpath` will be `/` and then we get into an // infinite loop. Same behavior happens if there is a `//` pattern in // `path`, so we check for that and leave the loop quicker. @@ -211,32 +211,32 @@ std::set AddAllSubpaths(const std::vector& paths) { return result; } -Status ParseJson(StringPiece json, Json::Value* result) { +absl::Status ParseJson(absl::string_view json, Json::Value* result) { Json::Reader reader; if (!reader.parse(json.data(), json.data() + json.size(), *result)) { return errors::Internal("Couldn't parse JSON response from GCS."); } - return OkStatus(); + return absl::OkStatus(); } -Status ParseJson(const std::vector& json, Json::Value* result) { - return ParseJson(StringPiece{json.data(), json.size()}, result); +absl::Status ParseJson(const std::vector& json, Json::Value* result) { + return ParseJson(absl::string_view{json.data(), json.size()}, result); } /// Reads a JSON value with the given name from a parent JSON value. -Status GetValue(const Json::Value& parent, const char* name, - Json::Value* result) { +absl::Status GetValue(const Json::Value& parent, const char* name, + Json::Value* result) { *result = parent.get(name, Json::Value::null); if (result->isNull()) { return errors::Internal("The field '", name, "' was expected in the JSON response."); } - return OkStatus(); + return absl::OkStatus(); } /// Reads a string JSON value with the given name from a parent JSON value. -Status GetStringValue(const Json::Value& parent, const char* name, - string* result) { +absl::Status GetStringValue(const Json::Value& parent, const char* name, + string* result) { Json::Value result_value; TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value)); if (!result_value.isString()) { @@ -245,21 +245,21 @@ Status GetStringValue(const Json::Value& parent, const char* name, "' in the JSON response was expected to be a string."); } *result = result_value.asString(); - return OkStatus(); + return absl::OkStatus(); } /// Reads a long JSON value with the given name from a parent JSON value. -Status GetInt64Value(const Json::Value& parent, const char* name, - int64_t* result) { +absl::Status GetInt64Value(const Json::Value& parent, const char* name, + int64_t* result) { Json::Value result_value; TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value)); if (result_value.isNumeric()) { *result = result_value.asInt64(); - return OkStatus(); + return absl::OkStatus(); } if (result_value.isString() && strings::safe_strto64(result_value.asCString(), result)) { - return OkStatus(); + return absl::OkStatus(); } return errors::Internal( "The field '", name, @@ -267,7 +267,8 @@ Status GetInt64Value(const Json::Value& parent, const char* name, } /// Reads a boolean JSON value with the given name from a parent JSON value. -Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) { +absl::Status GetBoolValue(const Json::Value& parent, const char* name, + bool* result) { Json::Value result_value; TF_RETURN_IF_ERROR(GetValue(parent, name, &result_value)); if (!result_value.isBool()) { @@ -276,7 +277,7 @@ Status GetBoolValue(const Json::Value& parent, const char* name, bool* result) { "' in the JSON response was expected to be a boolean."); } *result = result_value.asBool(); - return OkStatus(); + return absl::OkStatus(); } /// Get GCS Retry Config by applying user overrides through env if any. @@ -314,21 +315,21 @@ RetryConfig GetGcsRetryConfig() { /// A GCS-based implementation of a random access file with an LRU block cache. class GcsRandomAccessFile : public RandomAccessFile { public: - using ReadFn = - std::function; + using ReadFn = std::function; GcsRandomAccessFile(const string& filename, ReadFn read_fn) : filename_(filename), read_fn_(std::move(read_fn)) {} - Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { *result = filename_; - return OkStatus(); + return absl::OkStatus(); } /// The implementation of reads with an LRU block cache. Thread safe. - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override { + absl::Status Read(uint64 offset, size_t n, absl::string_view* result, + char* scratch) const override { return read_fn_(filename_, offset, n, result, scratch); } @@ -342,9 +343,9 @@ class GcsRandomAccessFile : public RandomAccessFile { /// A GCS-based implementation of a random access file with a read buffer. class BufferedGcsRandomAccessFile : public RandomAccessFile { public: - using ReadFn = - std::function; + using ReadFn = std::function; // Initialize the reader. Provided read_fn should be thread safe. BufferedGcsRandomAccessFile(const string& filename, uint64 buffer_size, @@ -355,16 +356,16 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { buffer_start_(0), buffer_end_is_past_eof_(false) {} - Status Name(StringPiece* result) const override { + absl::Status Name(absl::string_view* result) const override { *result = filename_; - return OkStatus(); + return absl::OkStatus(); } /// The implementation of reads with an read buffer. Thread safe. /// Returns `OUT_OF_RANGE` if fewer than n bytes were stored in `*result` /// because of EOF. - Status Read(uint64 offset, size_t n, StringPiece* result, - char* scratch) const override { + absl::Status Read(uint64 offset, size_t n, absl::string_view* result, + char* scratch) const override { if (n > buffer_size_) { return read_fn_(filename_, offset, n, result, scratch); } @@ -375,12 +376,12 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { if (offset < buffer_end && offset >= buffer_start_) { copy_size = std::min(n, static_cast(buffer_end - offset)); memcpy(scratch, buffer_.data() + (offset - buffer_start_), copy_size); - *result = StringPiece(scratch, copy_size); + *result = absl::string_view(scratch, copy_size); } bool consumed_buffer_to_eof = offset + copy_size >= buffer_end && buffer_end_is_past_eof_; if (copy_size < n && !consumed_buffer_to_eof) { - Status status = FillBuffer(offset + copy_size); + absl::Status status = FillBuffer(offset + copy_size); if (!status.ok() && !absl::IsOutOfRange(status)) { // Empty the buffer to avoid caching bad reads. buffer_.resize(0); @@ -389,7 +390,7 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { size_t remaining_copy = std::min(n - copy_size, buffer_.size()); memcpy(scratch + copy_size, buffer_.data(), remaining_copy); copy_size += remaining_copy; - *result = StringPiece(scratch, copy_size); + *result = absl::string_view(scratch, copy_size); } if (copy_size < n) { // Forget the end-of-file flag to allow for clients that poll on the @@ -399,17 +400,17 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { " bytes from ", offset, "."); } } - return OkStatus(); + return absl::OkStatus(); } private: - Status FillBuffer(uint64 start) const + absl::Status FillBuffer(uint64 start) const TF_EXCLUSIVE_LOCKS_REQUIRED(buffer_mutex_) { buffer_start_ = start; buffer_.resize(buffer_size_); - StringPiece str_piece; - Status status = read_fn_(filename_, buffer_start_, buffer_size_, &str_piece, - &(buffer_[0])); + absl::string_view str_piece; + absl::Status status = read_fn_(filename_, buffer_start_, buffer_size_, + &str_piece, &(buffer_[0])); buffer_end_is_past_eof_ = absl::IsOutOfRange(status); buffer_.resize(str_piece.size()); return status; @@ -437,28 +438,28 @@ class BufferedGcsRandomAccessFile : public RandomAccessFile { }; // Function object declaration with params needed to create upload sessions. -typedef std::function SessionCreator; // Function object declaration with params needed to upload objects. -typedef std::function +typedef std::function ObjectUploader; // Function object declaration with params needed to poll upload status. -typedef std::function +typedef std::function StatusPoller; // Function object declaration with params needed to poll upload status. -typedef std::function +typedef std::function GenerationGetter; /// \brief GCS-based implementation of a writeable file. @@ -534,7 +535,7 @@ class GcsWritableFile : public WritableFile { std::remove(tmp_content_filename_.c_str()); } - Status Append(StringPiece data) override { + absl::Status Append(absl::string_view data) override { TF_RETURN_IF_ERROR(CheckWritable()); VLOG(3) << "Append: " << GetGcsPath() << " size " << data.length(); sync_needed_ = true; @@ -543,37 +544,38 @@ class GcsWritableFile : public WritableFile { return errors::Internal( "Could not append to the internal temporary file."); } - return OkStatus(); + return absl::OkStatus(); } - Status Close() override { + absl::Status Close() override { VLOG(3) << "Close:" << GetGcsPath(); if (outfile_.is_open()) { - Status sync_status = Sync(); + absl::Status sync_status = Sync(); if (sync_status.ok()) { outfile_.close(); } return sync_status; } - return OkStatus(); + return absl::OkStatus(); } - Status Flush() override { + absl::Status Flush() override { VLOG(3) << "Flush:" << GetGcsPath(); return Sync(); } - Status Name(StringPiece* result) const override { - return errors::Unimplemented("GCSWritableFile does not support Name()"); + absl::Status Name(absl::string_view* result) const override { + *result = object_; + return absl::OkStatus(); } - Status Sync() override { + absl::Status Sync() override { VLOG(3) << "Sync started:" << GetGcsPath(); TF_RETURN_IF_ERROR(CheckWritable()); if (!sync_needed_) { - return OkStatus(); + return absl::OkStatus(); } - Status status = SyncImpl(); + absl::Status status = SyncImpl(); VLOG(3) << "Sync finished " << GetGcsPath(); if (status.ok()) { sync_needed_ = false; @@ -581,12 +583,12 @@ class GcsWritableFile : public WritableFile { return status; } - Status Tell(int64_t* position) override { + absl::Status Tell(int64_t* position) override { *position = outfile_.tellp(); if (*position == -1) { return errors::Internal("tellp on the internal temporary file failed"); } - return OkStatus(); + return absl::OkStatus(); } private: @@ -596,7 +598,7 @@ class GcsWritableFile : public WritableFile { /// In case of a failure, it resumes failed uploads as recommended by the GCS /// resumable API documentation. When the whole upload needs to be /// restarted, Sync() returns UNAVAILABLE and relies on RetryingFileSystem. - Status SyncImpl() { + absl::Status SyncImpl() { outfile_.flush(); if (!outfile_.good()) { return errors::Internal( @@ -620,7 +622,7 @@ class GcsWritableFile : public WritableFile { &session_handle)); uint64 already_uploaded = 0; bool first_attempt = true; - const Status upload_status = RetryingUtils::CallWithRetries( + const absl::Status upload_status = RetryingUtils::CallWithRetries( [&first_attempt, &already_uploaded, &session_handle, &start_offset, this]() { if (session_handle.resumable && !first_attempt) { @@ -637,7 +639,7 @@ class GcsWritableFile : public WritableFile { // It's unclear why UploadToSession didn't return OK in the // previous attempt, but GCS reports that the file is fully // uploaded, so succeed. - return OkStatus(); + return absl::OkStatus(); } } first_attempt = false; @@ -661,28 +663,28 @@ class GcsWritableFile : public WritableFile { return upload_status; } - Status CheckWritable() const { + absl::Status CheckWritable() const { if (!outfile_.is_open()) { return errors::FailedPrecondition( "The internal temporary file is not writable."); } - return OkStatus(); + return absl::OkStatus(); } - Status GetCurrentFileSize(uint64* size) { + absl::Status GetCurrentFileSize(uint64* size) { const auto tellp = outfile_.tellp(); if (tellp == static_cast(-1)) { return errors::Internal( "Could not get the size of the internal temporary file."); } *size = tellp; - return OkStatus(); + return absl::OkStatus(); } /// Initiates a new resumable upload session. - Status CreateNewUploadSession(uint64 start_offset, - std::string object_to_upload, - UploadSessionHandle* session_handle) { + absl::Status CreateNewUploadSession(uint64 start_offset, + std::string object_to_upload, + UploadSessionHandle* session_handle) { uint64 file_size; TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); return session_creator_(start_offset, object_to_upload, bucket_, file_size, @@ -691,7 +693,7 @@ class GcsWritableFile : public WritableFile { /// Appends the data of append_object to the original object and deletes /// append_object. - Status AppendObject(string append_object) { + absl::Status AppendObject(string append_object) { const string append_object_path = GetGcsPathWithObject(append_object); VLOG(3) << "AppendObject: " << append_object_path << " to " << GetGcsPath(); @@ -718,7 +720,7 @@ class GcsWritableFile : public WritableFile { request->SetPostFromBuffer(request_body.c_str(), request_body.size()); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when composing to ", GetGcsPath()); - return OkStatus(); + return absl::OkStatus(); }, retry_config_)); @@ -734,8 +736,8 @@ class GcsWritableFile : public WritableFile { /// If the upload has already succeeded, sets 'completed' to true. /// Otherwise sets 'completed' to false and 'uploaded' to the currently /// uploaded size in bytes. - Status RequestUploadSessionStatus(const string& session_uri, bool* completed, - uint64* uploaded) { + absl::Status RequestUploadSessionStatus(const string& session_uri, + bool* completed, uint64* uploaded) { uint64 file_size; TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); return status_poller_(session_uri, file_size, GetGcsPath(), completed, @@ -743,11 +745,11 @@ class GcsWritableFile : public WritableFile { } /// Uploads data to object. - Status UploadToSession(const string& session_uri, uint64 start_offset, - uint64 already_uploaded) { + absl::Status UploadToSession(const string& session_uri, uint64 start_offset, + uint64 already_uploaded) { uint64 file_size; TF_RETURN_IF_ERROR(GetCurrentFileSize(&file_size)); - Status status = + absl::Status status = object_uploader_(session_uri, start_offset, already_uploaded, tmp_content_filename_, file_size, GetGcsPath()); if (status.ok()) { @@ -795,14 +797,14 @@ class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { uint64 length_; }; -bool StringPieceIdentity(StringPiece str, StringPiece* value) { +bool StringPieceIdentity(absl::string_view str, absl::string_view* value) { *value = str; return true; } /// \brief Utility function to split a comma delimited list of strings to an /// unordered set, lowercasing all values. -bool SplitByCommaToLowercaseSet(StringPiece list, +bool SplitByCommaToLowercaseSet(absl::string_view list, std::unordered_set* set) { std::vector vector = absl::StrSplit(absl::AsciiStrToLower(list), ','); *set = std::unordered_set(vector.begin(), vector.end()); @@ -897,14 +899,14 @@ GcsFileSystem::GcsFileSystem(bool make_default_cache) { } // Get the additional header - StringPiece add_header_contents; + absl::string_view add_header_contents; if (GetEnvVar(kAdditionalRequestHeader, StringPieceIdentity, &add_header_contents)) { size_t split = add_header_contents.find(':', 0); - if (split != StringPiece::npos) { - StringPiece header_name = add_header_contents.substr(0, split); - StringPiece header_value = add_header_contents.substr(split + 1); + if (split != absl::string_view::npos) { + absl::string_view header_name = add_header_contents.substr(0, split); + absl::string_view header_value = add_header_contents.substr(split + 1); if (!header_name.empty() && !header_value.empty()) { additional_header_.reset(new std::pair( @@ -968,7 +970,7 @@ GcsFileSystem::GcsFileSystem(bool make_default_cache) { GetEnvVar(kAllowedBucketLocations, SplitByCommaToLowercaseSet, &allowed_locations_); - StringPiece append_mode; + absl::string_view append_mode; GetEnvVar(kAppendMode, StringPieceIdentity, &append_mode); if (append_mode == kComposeAppend) { compose_append_ = true; @@ -1006,7 +1008,7 @@ GcsFileSystem::GcsFileSystem( compose_append_(compose_append), additional_header_(additional_header) {} -Status GcsFileSystem::NewRandomAccessFile( +absl::Status GcsFileSystem::NewRandomAccessFile( const string& fname, TransactionToken* token, std::unique_ptr* result) { string bucket, object; @@ -1016,7 +1018,7 @@ Status GcsFileSystem::NewRandomAccessFile( result->reset(new GcsRandomAccessFile(fname, [this, bucket, object]( const string& fname, uint64 offset, size_t n, - StringPiece* result, + absl::string_view* result, char* scratch) { tf_shared_lock l(block_cache_lock_); GcsFileStat stat; @@ -1031,37 +1033,37 @@ Status GcsFileSystem::NewRandomAccessFile( << "File signature has been changed. Refreshing the cache. Path: " << fname; } - *result = StringPiece(); + *result = absl::string_view(); size_t bytes_transferred; TF_RETURN_IF_ERROR(file_block_cache_->Read(fname, offset, n, scratch, &bytes_transferred)); - *result = StringPiece(scratch, bytes_transferred); + *result = absl::string_view(scratch, bytes_transferred); if (bytes_transferred < n) { return errors::OutOfRange("EOF reached, ", result->size(), " bytes were read out of ", n, " bytes requested."); } - return OkStatus(); + return absl::OkStatus(); })); } else { result->reset(new BufferedGcsRandomAccessFile( fname, block_size_, [this, bucket, object](const string& fname, uint64 offset, size_t n, - StringPiece* result, char* scratch) { - *result = StringPiece(); + absl::string_view* result, char* scratch) { + *result = absl::string_view(); size_t bytes_transferred; TF_RETURN_IF_ERROR( LoadBufferFromGCS(fname, offset, n, scratch, &bytes_transferred)); - *result = StringPiece(scratch, bytes_transferred); + *result = absl::string_view(scratch, bytes_transferred); if (bytes_transferred < n) { return errors::OutOfRange("EOF reached, ", result->size(), " bytes were read out of ", n, " bytes requested."); } - return OkStatus(); + return absl::OkStatus(); })); } - return OkStatus(); + return absl::OkStatus(); } void GcsFileSystem::ResetFileBlockCache(size_t block_size_bytes, @@ -1092,9 +1094,10 @@ std::unique_ptr GcsFileSystem::MakeFileBlockCache( } // A helper function to actually read the data from GCS. -Status GcsFileSystem::LoadBufferFromGCS(const string& fname, size_t offset, - size_t n, char* buffer, - size_t* bytes_transferred) { +absl::Status GcsFileSystem::LoadBufferFromGCS(const string& fname, + size_t offset, size_t n, + char* buffer, + size_t* bytes_transferred) { *bytes_transferred = 0; string bucket, object; @@ -1148,11 +1151,11 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& fname, size_t offset, } } - return OkStatus(); + return absl::OkStatus(); } /// Initiates a new upload session. -Status GcsFileSystem::CreateNewUploadSession( +absl::Status GcsFileSystem::CreateNewUploadSession( uint64 start_offset, const std::string& object_to_upload, const std::string& bucket, uint64 file_size, const std::string& gcs_path, UploadSessionHandle* session_handle) { @@ -1179,15 +1182,13 @@ Status GcsFileSystem::CreateNewUploadSession( gcs_path, ": 'Location' header not returned."); } } - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::UploadToSession(const std::string& session_uri, - uint64 start_offset, - uint64 already_uploaded, - const std::string& tmp_content_filename, - uint64 file_size, - const std::string& file_path) { +absl::Status GcsFileSystem::UploadToSession( + const std::string& session_uri, uint64 start_offset, + uint64 already_uploaded, const std::string& tmp_content_filename, + uint64 file_size, const std::string& file_path) { std::unique_ptr request; TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(session_uri); @@ -1203,14 +1204,12 @@ Status GcsFileSystem::UploadToSession(const std::string& session_uri, start_offset + already_uploaded)); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when uploading ", file_path); - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::RequestUploadSessionStatus(const string& session_uri, - uint64 file_size, - const std::string& gcs_path, - bool* completed, - uint64* uploaded) { +absl::Status GcsFileSystem::RequestUploadSessionStatus( + const string& session_uri, uint64 file_size, const std::string& gcs_path, + bool* completed, uint64* uploaded) { CHECK(completed != nullptr) << "RequestUploadSessionStatus() called with out " "param 'completed' == nullptr."; // Crash ok CHECK(uploaded != nullptr) << "RequestUploadSessionStatus() called with out " @@ -1221,10 +1220,10 @@ Status GcsFileSystem::RequestUploadSessionStatus(const string& session_uri, request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata); request->AddHeader("Content-Range", strings::StrCat("bytes */", file_size)); request->SetPutEmptyBody(); - Status status = request->Send(); + absl::Status status = request->Send(); if (status.ok()) { *completed = true; - return OkStatus(); + return absl::OkStatus(); } *completed = false; if (request->GetResponseCode() != HTTP_CODE_RESUME_INCOMPLETE) { @@ -1235,7 +1234,7 @@ Status GcsFileSystem::RequestUploadSessionStatus(const string& session_uri, // This means GCS doesn't have any bytes of the file yet. *uploaded = 0; } else { - StringPiece range_piece(received_range); + absl::string_view range_piece(received_range); absl::ConsumePrefix(&range_piece, "bytes="); // May or may not be present. @@ -1269,13 +1268,15 @@ Status GcsFileSystem::RequestUploadSessionStatus(const string& session_uri, // If GCS returned "Range: 0-10", this means 11 bytes were uploaded. *uploaded = range_parts[1] + 1; } - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::ParseGcsPathForScheme(StringPiece fname, string scheme, - bool empty_object_ok, - string* bucket, string* object) { - StringPiece parsed_scheme, bucketp, objectp; +absl::Status GcsFileSystem::ParseGcsPathForScheme(absl::string_view fname, + string scheme, + bool empty_object_ok, + string* bucket, + string* object) { + absl::string_view parsed_scheme, bucketp, objectp; io::ParseURI(fname, &parsed_scheme, &bucketp, &objectp); if (parsed_scheme != scheme) { return errors::InvalidArgument("GCS path doesn't start with 'gs://': ", @@ -1292,11 +1293,12 @@ Status GcsFileSystem::ParseGcsPathForScheme(StringPiece fname, string scheme, return errors::InvalidArgument("GCS path doesn't contain an object name: ", fname); } - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::ParseGcsPath(StringPiece fname, bool empty_object_ok, - string* bucket, string* object) { +absl::Status GcsFileSystem::ParseGcsPath(absl::string_view fname, + bool empty_object_ok, string* bucket, + string* object) { return ParseGcsPathForScheme(fname, "gs", empty_object_ok, bucket, object); } @@ -1308,9 +1310,9 @@ void GcsFileSystem::ClearFileCaches(const string& fname) { // MatchingPathsCache as well. } -Status GcsFileSystem::NewWritableFile(const string& fname, - TransactionToken* token, - std::unique_ptr* result) { +absl::Status GcsFileSystem::NewWritableFile( + const string& fname, TransactionToken* token, + std::unique_ptr* result) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); @@ -1344,7 +1346,7 @@ Status GcsFileSystem::NewWritableFile(const string& fname, }, retry_config_)); *generation = stat.generation_number; - return OkStatus(); + return absl::OkStatus(); }; result->reset(new GcsWritableFile( @@ -1352,20 +1354,20 @@ Status GcsFileSystem::NewWritableFile(const string& fname, [this, fname]() { ClearFileCaches(fname); }, retry_config_, compose_append_, session_creator, object_uploader, status_poller, generation_getter)); - return OkStatus(); + return absl::OkStatus(); } // Reads the file from GCS in chunks and stores it in a tmp file, // which is then passed to GcsWritableFile. -Status GcsFileSystem::NewAppendableFile(const string& fname, - TransactionToken* token, - std::unique_ptr* result) { +absl::Status GcsFileSystem::NewAppendableFile( + const string& fname, TransactionToken* token, + std::unique_ptr* result) { std::unique_ptr reader; TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, token, &reader)); std::unique_ptr buffer(new char[kReadAppendableFileBufferSize]); - Status status; + absl::Status status; uint64 offset = 0; - StringPiece read_chunk; + absl::string_view read_chunk; // Read the file from GCS in chunks and save it to a tmp file. string old_content_filename; @@ -1421,7 +1423,7 @@ Status GcsFileSystem::NewAppendableFile(const string& fname, }, retry_config_)); *generation = stat.generation_number; - return OkStatus(); + return absl::OkStatus(); }; // Create a writable file and pass the old content to it. @@ -1432,10 +1434,10 @@ Status GcsFileSystem::NewAppendableFile(const string& fname, [this, fname]() { ClearFileCaches(fname); }, retry_config_, compose_append_, session_creator, object_uploader, status_poller, generation_getter)); - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile( +absl::Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile( const string& fname, TransactionToken* token, std::unique_ptr* result) { uint64 size; @@ -1445,21 +1447,22 @@ Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile( std::unique_ptr file; TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, token, &file)); - StringPiece piece; + absl::string_view piece; TF_RETURN_IF_ERROR(file->Read(0, size, &piece, data.get())); result->reset(new GcsReadOnlyMemoryRegion(std::move(data), size)); - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::FileExists(const string& fname, TransactionToken* token) { +absl::Status GcsFileSystem::FileExists(const string& fname, + TransactionToken* token) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object)); if (object.empty()) { bool result; TF_RETURN_IF_ERROR(BucketExists(bucket, &result)); if (result) { - return OkStatus(); + return absl::OkStatus(); } else { return absl::NotFoundError( absl::StrCat("The specified bucket ", fname, " was not found.")); @@ -1468,7 +1471,7 @@ Status GcsFileSystem::FileExists(const string& fname, TransactionToken* token) { // Check if the object exists. GcsFileStat stat; - const Status status = StatForObject(fname, bucket, object, &stat); + const absl::Status status = StatForObject(fname, bucket, object, &stat); if (!absl::IsNotFound(status)) { return status; } @@ -1477,31 +1480,32 @@ Status GcsFileSystem::FileExists(const string& fname, TransactionToken* token) { bool result; TF_RETURN_IF_ERROR(FolderExists(fname, &result)); if (result) { - return OkStatus(); + return absl::OkStatus(); } return errors::NotFound("The specified path ", fname, " was not found."); } -Status GcsFileSystem::ObjectExists(const string& fname, const string& bucket, - const string& object, bool* result) { +absl::Status GcsFileSystem::ObjectExists(const string& fname, + const string& bucket, + const string& object, bool* result) { GcsFileStat stat; - const Status status = StatForObject(fname, bucket, object, &stat); + const absl::Status status = StatForObject(fname, bucket, object, &stat); switch (static_cast(status.code())) { case static_cast(error::Code::OK): *result = !stat.base.is_directory; - return OkStatus(); + return absl::OkStatus(); case static_cast(error::Code::NOT_FOUND): *result = false; - return OkStatus(); + return absl::OkStatus(); default: return status; } } -Status GcsFileSystem::UncachedStatForObject(const string& fname, - const string& bucket, - const string& object, - GcsFileStat* stat) { +absl::Status GcsFileSystem::UncachedStatForObject(const string& fname, + const string& bucket, + const string& object, + GcsFileStat* stat) { std::vector output_buffer; std::unique_ptr request; TF_RETURN_WITH_CONTEXT_IF_ERROR(CreateHttpRequest(&request), @@ -1542,7 +1546,7 @@ Status GcsFileSystem::UncachedStatForObject(const string& fname, << "; mtime_nsec: " << stat->base.mtime_nsec << "; updated: " << updated; - if (str_util::EndsWith(fname, "/")) { + if (absl::EndsWith(fname, "/")) { // In GCS a path can be both a directory and a file, both it is uncommon for // other file systems. To avoid the ambiguity, if a path ends with "/" in // GCS, we always regard it as a directory mark or a virtual directory. @@ -1550,11 +1554,13 @@ Status GcsFileSystem::UncachedStatForObject(const string& fname, } else { stat->base.is_directory = false; } - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, - const string& object, GcsFileStat* stat) { +absl::Status GcsFileSystem::StatForObject(const string& fname, + const string& bucket, + const string& object, + GcsFileStat* stat) { if (object.empty()) { return errors::InvalidArgument(strings::Printf( "'object' must be a non-empty string. (File: %s)", fname.c_str())); @@ -1565,26 +1571,27 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket, [this, &bucket, &object](const string& fname, GcsFileStat* stat) { return UncachedStatForObject(fname, bucket, object, stat); })); - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::BucketExists(const string& bucket, bool* result) { - const Status status = GetBucketMetadata(bucket, nullptr); +absl::Status GcsFileSystem::BucketExists(const string& bucket, bool* result) { + const absl::Status status = GetBucketMetadata(bucket, nullptr); switch (static_cast(status.code())) { case absl::StatusCode::kOk: *result = true; - return OkStatus(); + return absl::OkStatus(); case absl::StatusCode::kNotFound: *result = false; - return OkStatus(); + return absl::OkStatus(); default: return status; } } -Status GcsFileSystem::CheckBucketLocationConstraint(const string& bucket) { +absl::Status GcsFileSystem::CheckBucketLocationConstraint( + const string& bucket) { if (allowed_locations_.empty()) { - return OkStatus(); + return absl::OkStatus(); } // Avoid calling external API's in the constructor @@ -1597,7 +1604,7 @@ Status GcsFileSystem::CheckBucketLocationConstraint(const string& bucket) { string location; TF_RETURN_IF_ERROR(GetBucketLocation(bucket, &location)); if (allowed_locations_.find(location) != allowed_locations_.end()) { - return OkStatus(); + return absl::OkStatus(); } return errors::FailedPrecondition(strings::Printf( @@ -1606,11 +1613,11 @@ Status GcsFileSystem::CheckBucketLocationConstraint(const string& bucket) { absl::StrJoin(allowed_locations_, ", ").c_str())); } -Status GcsFileSystem::GetBucketLocation(const string& bucket, - string* location) { +absl::Status GcsFileSystem::GetBucketLocation(const string& bucket, + string* location) { auto compute_func = [this](const string& bucket, string* location) { std::vector result_buffer; - Status status = GetBucketMetadata(bucket, &result_buffer); + absl::Status status = GetBucketMetadata(bucket, &result_buffer); Json::Value result; TF_RETURN_IF_ERROR(ParseJson(result_buffer, &result)); string bucket_location; @@ -1618,17 +1625,17 @@ Status GcsFileSystem::GetBucketLocation(const string& bucket, GetStringValue(result, kBucketMetadataLocationKey, &bucket_location)); // Lowercase the GCS location to be case insensitive for allowed locations. *location = absl::AsciiStrToLower(bucket_location); - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR( bucket_location_cache_->LookupOrCompute(bucket, location, compute_func)); - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::GetBucketMetadata(const string& bucket, - std::vector* result_buffer) { +absl::Status GcsFileSystem::GetBucketMetadata( + const string& bucket, std::vector* result_buffer) { std::unique_ptr request; TF_RETURN_IF_ERROR(CreateHttpRequest(&request)); request->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket)); @@ -1641,7 +1648,7 @@ Status GcsFileSystem::GetBucketMetadata(const string& bucket, return request->Send(); } -Status GcsFileSystem::FolderExists(const string& dirname, bool* result) { +absl::Status GcsFileSystem::FolderExists(const string& dirname, bool* result) { StatCache::ComputeFunc compute_func = [this](const string& dirname, GcsFileStat* stat) { std::vector children; @@ -1650,36 +1657,36 @@ Status GcsFileSystem::FolderExists(const string& dirname, bool* result) { true /* include_self_directory_marker */)); if (!children.empty()) { stat->base = DIRECTORY_STAT; - return OkStatus(); + return absl::OkStatus(); } else { return errors::InvalidArgument("Not a directory!"); } }; GcsFileStat stat; - Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname), &stat, - compute_func); + absl::Status s = stat_cache_->LookupOrCompute(MaybeAppendSlash(dirname), + &stat, compute_func); if (s.ok()) { *result = stat.base.is_directory; - return OkStatus(); + return absl::OkStatus(); } if (absl::IsInvalidArgument(s)) { *result = false; - return OkStatus(); + return absl::OkStatus(); } return s; } -Status GcsFileSystem::GetChildren(const string& dirname, - TransactionToken* token, - std::vector* result) { +absl::Status GcsFileSystem::GetChildren(const string& dirname, + TransactionToken* token, + std::vector* result) { return GetChildrenBounded(dirname, UINT64_MAX, result, false /* recursively */, false /* include_self_directory_marker */); } -Status GcsFileSystem::GetMatchingPaths(const string& pattern, - TransactionToken* token, - std::vector* results) { +absl::Status GcsFileSystem::GetMatchingPaths(const string& pattern, + TransactionToken* token, + std::vector* results) { MatchingPathsCache::ComputeFunc compute_func = [this](const string& pattern, std::vector* results) { results->clear(); @@ -1700,7 +1707,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, // To handle `/` in the object names, we need to remove it from `dir` // and then use `StrCat` to insert it back. - const StringPiece dir_no_slash = str_util::StripSuffix(dir, "/"); + const absl::string_view dir_no_slash = absl::StripSuffix(dir, "/"); // Match all obtained paths to the input pattern. for (const auto& path : files_and_folders) { @@ -1715,18 +1722,16 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, results->push_back(full_path); } } - return OkStatus(); + return absl::OkStatus(); }; TF_RETURN_IF_ERROR( matching_paths_cache_->LookupOrCompute(pattern, results, compute_func)); - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::GetChildrenBounded(const string& dirname, - uint64 max_results, - std::vector* result, - bool recursive, - bool include_self_directory_marker) { +absl::Status GcsFileSystem::GetChildrenBounded( + const string& dirname, uint64 max_results, std::vector* result, + bool recursive, bool include_self_directory_marker) { if (!result) { return errors::InvalidArgument("'result' cannot be null"); } @@ -1786,7 +1791,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, // The names should be relative to the 'dirname'. That means the // 'object_prefix', which is part of 'dirname', should be removed from // the beginning of 'name'. - StringPiece relative_path(name); + absl::string_view relative_path(name); if (!absl::ConsumePrefix(&relative_path, object_prefix)) { return errors::Internal(strings::StrCat( "Unexpected response: the returned file name ", name, @@ -1796,7 +1801,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, result->emplace_back(relative_path); } if (++retrieved_results >= max_results) { - return OkStatus(); + return absl::OkStatus(); } } } @@ -1815,7 +1820,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, "response."); } const string& prefix_str = prefix.asString(); - StringPiece relative_path(prefix_str); + absl::string_view relative_path(prefix_str); if (!absl::ConsumePrefix(&relative_path, object_prefix)) { return errors::Internal( "Unexpected response: the returned folder name ", prefix_str, @@ -1823,13 +1828,13 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, } result->emplace_back(relative_path); if (++retrieved_results >= max_results) { - return OkStatus(); + return absl::OkStatus(); } } } const auto token = root.get("nextPageToken", Json::Value::null); if (token.isNull()) { - return OkStatus(); + return absl::OkStatus(); } if (!token.isString()) { return errors::Internal( @@ -1839,8 +1844,8 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, } } -Status GcsFileSystem::Stat(const string& fname, TransactionToken* token, - FileStatistics* stat) { +absl::Status GcsFileSystem::Stat(const string& fname, TransactionToken* token, + FileStatistics* stat) { if (!stat) { return errors::Internal("'stat' cannot be nullptr."); } @@ -1851,16 +1856,16 @@ Status GcsFileSystem::Stat(const string& fname, TransactionToken* token, TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); if (is_bucket) { *stat = DIRECTORY_STAT; - return OkStatus(); + return absl::OkStatus(); } return errors::NotFound("The specified bucket ", fname, " was not found."); } GcsFileStat gcs_stat; - const Status status = StatForObject(fname, bucket, object, &gcs_stat); + const absl::Status status = StatForObject(fname, bucket, object, &gcs_stat); if (status.ok()) { *stat = gcs_stat.base; - return OkStatus(); + return absl::OkStatus(); } if (!absl::IsNotFound(status)) { return status; @@ -1869,12 +1874,13 @@ Status GcsFileSystem::Stat(const string& fname, TransactionToken* token, TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder)); if (is_folder) { *stat = DIRECTORY_STAT; - return OkStatus(); + return absl::OkStatus(); } return errors::NotFound("The specified path ", fname, " was not found."); } -Status GcsFileSystem::DeleteFile(const string& fname, TransactionToken* token) { +absl::Status GcsFileSystem::DeleteFile(const string& fname, + TransactionToken* token) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); @@ -1887,11 +1893,11 @@ Status GcsFileSystem::DeleteFile(const string& fname, TransactionToken* token) { TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when deleting ", fname); ClearFileCaches(fname); - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::CreateDir(const string& dirname, - TransactionToken* token) { +absl::Status GcsFileSystem::CreateDir(const string& dirname, + TransactionToken* token) { string dirname_with_slash = MaybeAppendSlash(dirname); VLOG(3) << "CreateDir: creating directory with dirname: " << dirname << " and dirname_with_slash: " << dirname_with_slash; @@ -1901,7 +1907,7 @@ Status GcsFileSystem::CreateDir(const string& dirname, if (object.empty()) { bool is_bucket; TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); - return is_bucket ? OkStatus() + return is_bucket ? absl::OkStatus() : errors::NotFound("The specified bucket ", dirname_with_slash, " was not found."); } @@ -1924,10 +1930,10 @@ Status GcsFileSystem::CreateDir(const string& dirname, request->SetPostEmptyBody(); request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.metadata); - const Status& status = request->Send(); + const absl::Status& status = request->Send(); if (status.ok()) { VLOG(3) << "CreateDir: finished uploading directory " << dirname; - return OkStatus(); + return absl::OkStatus(); } if (request->GetResponseCode() != HTTP_CODE_PRECONDITION_FAILED) { TF_RETURN_WITH_CONTEXT_IF_ERROR(status, " when uploading ", @@ -1940,8 +1946,8 @@ Status GcsFileSystem::CreateDir(const string& dirname, // Checks that the directory is empty (i.e no objects with this prefix exist). // Deletes the GCS directory marker if it exists. -Status GcsFileSystem::DeleteDir(const string& dirname, - TransactionToken* token) { +absl::Status GcsFileSystem::DeleteDir(const string& dirname, + TransactionToken* token) { std::vector children; // A directory is considered empty either if there are no matching objects // with the corresponding name prefix or if there is exactly one matching @@ -1958,11 +1964,12 @@ Status GcsFileSystem::DeleteDir(const string& dirname, // This is the directory marker object. Delete it. return DeleteFile(MaybeAppendSlash(dirname), token); } - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::GetFileSize(const string& fname, TransactionToken* token, - uint64* file_size) { +absl::Status GcsFileSystem::GetFileSize(const string& fname, + TransactionToken* token, + uint64* file_size) { if (!file_size) { return errors::Internal("'file_size' cannot be nullptr."); } @@ -1974,11 +1981,11 @@ Status GcsFileSystem::GetFileSize(const string& fname, TransactionToken* token, FileStatistics stat; TF_RETURN_IF_ERROR(Stat(fname, token, &stat)); *file_size = stat.length; - return OkStatus(); + return absl::OkStatus(); } -Status GcsFileSystem::RenameFile(const string& src, const string& target, - TransactionToken* token) { +absl::Status GcsFileSystem::RenameFile(const string& src, const string& target, + TransactionToken* token) { if (!IsDirectory(src, token).ok()) { return RenameObject(src, target); } @@ -1991,11 +1998,12 @@ Status GcsFileSystem::RenameFile(const string& src, const string& target, TF_RETURN_IF_ERROR( RenameObject(JoinGcsPath(src, subpath), JoinGcsPath(target, subpath))); } - return OkStatus(); + return absl::OkStatus(); } // Uses a GCS API command to copy the object and then deletes the old one. -Status GcsFileSystem::RenameObject(const string& src, const string& target) { +absl::Status GcsFileSystem::RenameObject(const string& src, + const string& target) { VLOG(3) << "RenameObject: started gs://" << src << " to " << target; string src_bucket, src_object, target_bucket, target_object; TF_RETURN_IF_ERROR(ParseGcsPath(src, false, &src_bucket, &src_object)); @@ -2040,15 +2048,15 @@ Status GcsFileSystem::RenameObject(const string& src, const string& target) { [this, &src]() { return DeleteFile(src, nullptr); }, retry_config_); } -Status GcsFileSystem::IsDirectory(const string& fname, - TransactionToken* token) { +absl::Status GcsFileSystem::IsDirectory(const string& fname, + TransactionToken* token) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, true, &bucket, &object)); if (object.empty()) { bool is_bucket; TF_RETURN_IF_ERROR(BucketExists(bucket, &is_bucket)); if (is_bucket) { - return OkStatus(); + return absl::OkStatus(); } return errors::NotFound("The specified bucket gs://", bucket, " was not found."); @@ -2056,7 +2064,7 @@ Status GcsFileSystem::IsDirectory(const string& fname, bool is_folder; TF_RETURN_IF_ERROR(FolderExists(fname, &is_folder)); if (is_folder) { - return OkStatus(); + return absl::OkStatus(); } bool is_object; TF_RETURN_IF_ERROR(ObjectExists(fname, bucket, object, &is_object)); @@ -2067,10 +2075,10 @@ Status GcsFileSystem::IsDirectory(const string& fname, return errors::NotFound("The specified path ", fname, " was not found."); } -Status GcsFileSystem::DeleteRecursively(const string& dirname, - TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs) { +absl::Status GcsFileSystem::DeleteRecursively(const string& dirname, + TransactionToken* token, + int64_t* undeleted_files, + int64_t* undeleted_dirs) { if (!undeleted_files || !undeleted_dirs) { return errors::Internal( "'undeleted_files' and 'undeleted_dirs' cannot be nullptr."); @@ -2079,7 +2087,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname, *undeleted_dirs = 0; if (!IsDirectory(dirname, token).ok()) { *undeleted_dirs = 1; - return Status( + return absl::Status( absl::StatusCode::kNotFound, strings::StrCat(dirname, " doesn't exist or not a directory.")); } @@ -2106,7 +2114,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname, } } } - return OkStatus(); + return absl::OkStatus(); } // Flushes all caches for filesystem metadata and file contents. Useful for @@ -2148,7 +2156,8 @@ void GcsFileSystem::SetAuthProvider( // Creates an HttpRequest and sets several parameters that are common to all // requests. All code (in GcsFileSystem) that creates an HttpRequest should // go through this method, rather than directly using http_request_factory_. -Status GcsFileSystem::CreateHttpRequest(std::unique_ptr* request) { +absl::Status GcsFileSystem::CreateHttpRequest( + std::unique_ptr* request) { std::unique_ptr new_request{http_request_factory_->Create()}; if (dns_cache_) { dns_cache_->AnnotateRequest(new_request.get()); @@ -2177,7 +2186,7 @@ Status GcsFileSystem::CreateHttpRequest(std::unique_ptr* request) { } *request = std::move(new_request); - return OkStatus(); + return absl::OkStatus(); } RetryingGcsFileSystem::RetryingGcsFileSystem() diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h index 17725e8d5b01e6..f7452a4eb69989 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system.h @@ -66,7 +66,7 @@ constexpr uint64 kDefaultMaxStaleness = 0; // Helper function to extract an environment variable and convert it into a // value of type T. template -bool GetEnvVar(const char* varname, bool (*convert)(StringPiece, T*), +bool GetEnvVar(const char* varname, bool (*convert)(absl::string_view, T*), T* value) { const char* env_value = std::getenv(varname); if (env_value == nullptr) { @@ -144,48 +144,54 @@ class GcsFileSystem : public FileSystem { TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT; - Status NewRandomAccessFile( + absl::Status NewRandomAccessFile( const string& fname, TransactionToken* token, std::unique_ptr* result) override; - Status NewWritableFile(const string& fname, TransactionToken* token, - std::unique_ptr* result) override; + absl::Status NewWritableFile(const string& fname, TransactionToken* token, + std::unique_ptr* result) override; - Status NewAppendableFile(const string& fname, TransactionToken* token, - std::unique_ptr* result) override; + absl::Status NewAppendableFile( + const string& fname, TransactionToken* token, + std::unique_ptr* result) override; - Status NewReadOnlyMemoryRegionFromFile( + absl::Status NewReadOnlyMemoryRegionFromFile( const string& fname, TransactionToken* token, std::unique_ptr* result) override; - Status FileExists(const string& fname, TransactionToken* token) override; + absl::Status FileExists(const string& fname, + TransactionToken* token) override; - Status Stat(const string& fname, TransactionToken* token, - FileStatistics* stat) override; + absl::Status Stat(const string& fname, TransactionToken* token, + FileStatistics* stat) override; - Status GetChildren(const string& dir, TransactionToken* token, - std::vector* result) override; + absl::Status GetChildren(const string& dir, TransactionToken* token, + std::vector* result) override; - Status GetMatchingPaths(const string& pattern, TransactionToken* token, - std::vector* results) override; + absl::Status GetMatchingPaths(const string& pattern, TransactionToken* token, + std::vector* results) override; - Status DeleteFile(const string& fname, TransactionToken* token) override; + absl::Status DeleteFile(const string& fname, + TransactionToken* token) override; - Status CreateDir(const string& dirname, TransactionToken* token) override; + absl::Status CreateDir(const string& dirname, + TransactionToken* token) override; - Status DeleteDir(const string& dirname, TransactionToken* token) override; + absl::Status DeleteDir(const string& dirname, + TransactionToken* token) override; - Status GetFileSize(const string& fname, TransactionToken* token, - uint64* file_size) override; + absl::Status GetFileSize(const string& fname, TransactionToken* token, + uint64* file_size) override; - Status RenameFile(const string& src, const string& target, - TransactionToken* token) override; + absl::Status RenameFile(const string& src, const string& target, + TransactionToken* token) override; - Status IsDirectory(const string& fname, TransactionToken* token) override; + absl::Status IsDirectory(const string& fname, + TransactionToken* token) override; - Status DeleteRecursively(const string& dirname, TransactionToken* token, - int64_t* undeleted_files, - int64_t* undeleted_dirs) override; + absl::Status DeleteRecursively(const string& dirname, TransactionToken* token, + int64_t* undeleted_files, + int64_t* undeleted_dirs) override; void FlushCaches(TransactionToken* token) override; @@ -267,7 +273,7 @@ class GcsFileSystem : public FileSystem { write(write) {} }; - Status CreateHttpRequest(std::unique_ptr* request); + absl::Status CreateHttpRequest(std::unique_ptr* request); /// \brief Sets a new AuthProvider on the GCS FileSystem. /// @@ -289,37 +295,38 @@ class GcsFileSystem : public FileSystem { size_t block_size, size_t max_bytes, uint64 max_staleness); /// Loads file contents from GCS for a given filename, offset, and length. - virtual Status LoadBufferFromGCS(const string& fname, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred); + virtual absl::Status LoadBufferFromGCS(const string& fname, size_t offset, + size_t n, char* buffer, + size_t* bytes_transferred); // Creates an upload session for an upcoming GCS object upload. - virtual Status CreateNewUploadSession(uint64 start_offset, - const std::string& object_to_upload, - const std::string& bucket, - uint64 file_size, - const std::string& gcs_path, - UploadSessionHandle* session_handle); + virtual absl::Status CreateNewUploadSession( + uint64 start_offset, const std::string& object_to_upload, + const std::string& bucket, uint64 file_size, const std::string& gcs_path, + UploadSessionHandle* session_handle); // Uploads object data to session. - virtual Status UploadToSession(const std::string& session_uri, - uint64 start_offset, uint64 already_uploaded, - const std::string& tmp_content_filename, - uint64 file_size, - const std::string& file_path); + virtual absl::Status UploadToSession(const std::string& session_uri, + uint64 start_offset, + uint64 already_uploaded, + const std::string& tmp_content_filename, + uint64 file_size, + const std::string& file_path); /// \brief Requests status of a previously initiated upload session. /// /// If the upload has already succeeded, sets 'completed' to true. /// Otherwise sets 'completed' to false and 'uploaded' to the currently /// uploaded size in bytes. - virtual Status RequestUploadSessionStatus(const string& session_uri, - uint64 file_size, - const std::string& gcs_path, - bool* completed, uint64* uploaded); + virtual absl::Status RequestUploadSessionStatus(const string& session_uri, + uint64 file_size, + const std::string& gcs_path, + bool* completed, + uint64* uploaded); - Status ParseGcsPathForScheme(StringPiece fname, string scheme, - bool empty_object_ok, string* bucket, - string* object); + absl::Status ParseGcsPathForScheme(absl::string_view fname, string scheme, + bool empty_object_ok, string* bucket, + string* object); /// \brief Splits a GCS path to a bucket and an object. /// @@ -327,8 +334,9 @@ class GcsFileSystem : public FileSystem { /// "bucket-name" and "path/to/file.txt". /// If fname only contains the bucket and empty_object_ok = true, the returned /// object is empty. - virtual Status ParseGcsPath(StringPiece fname, bool empty_object_ok, - string* bucket, string* object); + virtual absl::Status ParseGcsPath(absl::string_view fname, + bool empty_object_ok, string* bucket, + string* object); std::shared_ptr compute_engine_metadata_client_; @@ -348,7 +356,7 @@ class GcsFileSystem : public FileSystem { /// \brief Checks if the bucket exists. Returns OK if the check succeeded. /// /// 'result' is set if the function returns OK. 'result' cannot be nullptr. - Status BucketExists(const string& bucket, bool* result); + absl::Status BucketExists(const string& bucket, bool* result); /// \brief Retrieves the GCS bucket location. Returns OK if the location was /// retrieved. @@ -359,28 +367,28 @@ class GcsFileSystem : public FileSystem { /// This requires the bucket metadata permission. /// Repeated calls for the same bucket are cached so this function can be /// called frequently without causing an extra API call - Status GetBucketLocation(const string& bucket, string* location); + absl::Status GetBucketLocation(const string& bucket, string* location); /// \brief Check if the GCS buckets location is allowed with the current /// constraint configuration - Status CheckBucketLocationConstraint(const string& bucket); + absl::Status CheckBucketLocationConstraint(const string& bucket); /// \brief Given the input bucket `bucket`, fills `result_buffer` with the /// results of the metadata. Returns OK if the API call succeeds without /// error. - Status GetBucketMetadata(const string& bucket, - std::vector* result_buffer); + absl::Status GetBucketMetadata(const string& bucket, + std::vector* result_buffer); /// \brief Checks if the object exists. Returns OK if the check succeeded. /// /// 'result' is set if the function returns OK. 'result' cannot be nullptr. - Status ObjectExists(const string& fname, const string& bucket, - const string& object, bool* result); + absl::Status ObjectExists(const string& fname, const string& bucket, + const string& object, bool* result); /// \brief Checks if the folder exists. Returns OK if the check succeeded. /// /// 'result' is set if the function returns OK. 'result' cannot be nullptr. - Status FolderExists(const string& dirname, bool* result); + absl::Status FolderExists(const string& dirname, bool* result); /// \brief Internal version of GetChildren with more knobs. /// @@ -390,19 +398,19 @@ class GcsFileSystem : public FileSystem { /// If 'include_self_directory_marker' is true and there is a GCS directory /// marker at the path 'dir', GetChildrenBound will return an empty string /// as one of the children that represents this marker. - Status GetChildrenBounded(const string& dir, uint64 max_results, - std::vector* result, bool recursively, - bool include_self_directory_marker); + absl::Status GetChildrenBounded(const string& dir, uint64 max_results, + std::vector* result, bool recursively, + bool include_self_directory_marker); /// Retrieves file statistics assuming fname points to a GCS object. The data /// may be read from cache or from GCS directly. - Status StatForObject(const string& fname, const string& bucket, - const string& object, GcsFileStat* stat); + absl::Status StatForObject(const string& fname, const string& bucket, + const string& object, GcsFileStat* stat); /// Retrieves file statistics of file fname directly from GCS. - Status UncachedStatForObject(const string& fname, const string& bucket, - const string& object, GcsFileStat* stat); + absl::Status UncachedStatForObject(const string& fname, const string& bucket, + const string& object, GcsFileStat* stat); - Status RenameObject(const string& src, const string& target); + absl::Status RenameObject(const string& src, const string& target); // Clear all the caches related to the file with name `filename`. void ClearFileCaches(const string& fname); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc index 9221128276af9e..9d9d3088467df0 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_file_system_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/cloud/http_request_fake.h" #include "tsl/platform/errors.h" #include "tsl/platform/str_util.h" @@ -45,17 +45,17 @@ static std::unordered_set* kAllowedLocationsAuto = class FakeAuthProvider : public AuthProvider { public: - Status GetToken(string* token) override { + absl::Status GetToken(string* token) override { *token = "fake_token"; - return OkStatus(); + return absl::OkStatus(); } }; class FakeZoneProvider : public ZoneProvider { public: - Status GetZone(string* zone) override { + absl::Status GetZone(string* zone) override { *zone = "us-east1-b"; - return OkStatus(); + return absl::OkStatus(); } }; @@ -88,12 +88,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); - StringPiece filename; + absl::string_view filename; TF_EXPECT_OK(file->Name(&filename)); EXPECT_EQ(filename, "gs://bucket/random_access.txt"); char scratch[6]; - StringPiece result; + absl::string_view result; // Read the first chunk. TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch)); @@ -135,12 +135,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); - StringPiece filename; + absl::string_view filename; TF_EXPECT_OK(file->Name(&filename)); EXPECT_EQ(filename, "gs://bucket/random_access.txt"); char scratch[6]; - StringPiece result; + absl::string_view result; // Read the first chunk. TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch)); @@ -183,12 +183,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Errors) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); - StringPiece filename; + absl::string_view filename; TF_EXPECT_OK(file->Name(&filename)); EXPECT_EQ(filename, "gs://bucket/random_access.txt"); char scratch[6]; - StringPiece result; + absl::string_view result; // Read the first chunk. EXPECT_TRUE( @@ -230,12 +230,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadAtEOF) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); - StringPiece filename; + absl::string_view filename; TF_EXPECT_OK(file->Name(&filename)); EXPECT_EQ(filename, "gs://bucket/random_access.txt"); char scratch[10]; - StringPiece result; + absl::string_view result; // Read the first chunk. TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch)); @@ -271,12 +271,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedOutOfRange) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); - StringPiece filename; + absl::string_view filename; TF_EXPECT_OK(file->Name(&filename)); EXPECT_EQ(filename, "gs://bucket/random_access.txt"); char scratch[5]; - StringPiece result; + absl::string_view result; // Read the first chunk. Even though the backend response is out-of-range, // we should get a OK status since we're just reading the first 5 bytes. @@ -323,12 +323,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_CachedNotSequential) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); - StringPiece filename; + absl::string_view filename; TF_EXPECT_OK(file->Name(&filename)); EXPECT_EQ(filename, "gs://bucket/random_access.txt"); char scratch[5]; - StringPiece result; + absl::string_view result; TF_EXPECT_OK(file->Read(1, sizeof(scratch), &result, scratch)); EXPECT_EQ("12345", result); @@ -365,12 +365,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_Growing) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); - StringPiece filename; + absl::string_view filename; TF_EXPECT_OK(file->Name(&filename)); EXPECT_EQ(filename, "gs://bucket/random_access.txt"); char scratch[10]; - StringPiece result; + absl::string_view result; // Read the first chunk. Since the first read is out-of-range, // we don't cache the out-of-range flag and each subsequent read triggers a @@ -413,12 +413,12 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_Buffered_ReadBackwards) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); - StringPiece filename; + absl::string_view filename; TF_EXPECT_OK(file->Name(&filename)); EXPECT_EQ(filename, "gs://bucket/random_access.txt"); char scratch[10]; - StringPiece result; + absl::string_view result; // Read the first chunk. EXPECT_TRUE( @@ -574,7 +574,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) { fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); char small_scratch[3]; - StringPiece result; + absl::string_view result; // Read the first chunk. TF_EXPECT_OK(file->Read(0, sizeof(small_scratch), &result, small_scratch)); @@ -629,7 +629,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) { nullptr /* gcs additional header */, false /* compose append */); char scratch[100]; - StringPiece result; + absl::string_view result; { // We are instantiating this in an enclosed scope to make sure after the // unique ptr goes out of scope, we can still access result. @@ -716,7 +716,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) { nullptr /* gcs additional header */, false /* compose append */); char scratch[100]; - StringPiece result; + absl::string_view result; std::unique_ptr file; TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); @@ -766,7 +766,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) { kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */, false /* compose append */); char scratch[100]; - StringPiece result; + absl::string_view result; // There should only be two HTTP requests issued to GCS even though we iterate // this loop 10 times. This shows that the underlying FileBlockCache persists // across file close/open boundaries. @@ -841,7 +841,7 @@ TEST(GcsFileSystemTest, fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); char scratch[5]; - StringPiece result; + absl::string_view result; // First read. TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch)); @@ -908,7 +908,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) { fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); char scratch[6]; - StringPiece result; + absl::string_view result; EXPECT_TRUE( errors::IsInternal(file->Read(0, sizeof(scratch), &result, scratch))); @@ -972,7 +972,7 @@ TEST(GcsFileSystemTest, NewWritableFile) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/path/writeable", nullptr, &rfile)); char scratch[100]; - StringPiece result; + absl::string_view result; TF_EXPECT_OK(rfile->Read(0, 4, &result, scratch)); EXPECT_EQ("0123", result); // Open the writable file. @@ -1107,7 +1107,7 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { "Timeouts: 5 1 10\n" "Header Content-Range: bytes */17\n" "Put: yes\n", - "", OkStatus(), nullptr, {}, 201), + "", absl::OkStatus(), nullptr, {}, 201), new FakeHttpRequest( "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/" "path%2Fwriteable?fields=size%2Cgeneration%2Cupdated\n" @@ -1138,7 +1138,7 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/path/writeable", nullptr, &rfile)); char scratch[100]; - StringPiece result; + absl::string_view result; TF_EXPECT_OK(rfile->Read(0, 4, &result, scratch)); EXPECT_EQ("0123", result); // Now write to the same file. Once the write succeeds, the cached block will @@ -1402,7 +1402,7 @@ TEST(GcsFileSystemTest, NewAppendableFile) { TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/path/appendable", nullptr, &rfile)); char scratch[100]; - StringPiece result; + absl::string_view result; TF_EXPECT_OK(rfile->Read(0, 8, &result, scratch)); EXPECT_EQ("content1", result); // Closing the appendable file will flush its contents to GCS, triggering HTTP @@ -1496,8 +1496,9 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile( "gs://bucket/path/random_access.txt", nullptr, ®ion)); - EXPECT_EQ(content, StringPiece(reinterpret_cast(region->data()), - region->length())); + EXPECT_EQ(content, + absl::string_view(reinterpret_cast(region->data()), + region->length())); } TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) { @@ -2262,7 +2263,7 @@ TEST(GcsFileSystemTest, DeleteFile) { // Do an initial read of the file to load its contents into the block cache. char scratch[100]; - StringPiece result; + absl::string_view result; std::unique_ptr file; TF_EXPECT_OK( fs.NewRandomAccessFile("gs://bucket/path/file1.txt", nullptr, &file)); @@ -2656,7 +2657,7 @@ TEST(GcsFileSystemTest, RenameFile_Object) { // Do an initial read of the source and destination files to load their // contents into the block cache. char scratch[100]; - StringPiece result; + absl::string_view result; std::unique_ptr src; std::unique_ptr dst; TF_EXPECT_OK( @@ -3798,7 +3799,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) { fs.NewRandomAccessFile("gs://bucket/random_access.txt", nullptr, &file)); char scratch[6]; - StringPiece result; + absl::string_view result; TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch)); EXPECT_EQ("012345", result); diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc index dfd8310953272f..658629fab6bb2b 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/gcs_throttle_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/platform/cloud/gcs_throttle.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/str_util.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc index f1b62fb0b26c9b..7f1f94dc778e1e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.cc @@ -82,7 +82,7 @@ bool IsFile(const string& filename) { } /// Returns the credentials file name from the env variable. -Status GetEnvironmentVariableFileName(string* filename) { +absl::Status GetEnvironmentVariableFileName(string* filename) { if (!filename) { return errors::FailedPrecondition("'filename' cannot be nullptr."); } @@ -92,11 +92,11 @@ Status GetEnvironmentVariableFileName(string* filename) { " is not set or corrupt.")); } *filename = result; - return OkStatus(); + return absl::OkStatus(); } /// Returns the well known file produced by command 'gcloud auth login'. -Status GetWellKnownFileName(string* filename) { +absl::Status GetWellKnownFileName(string* filename) { if (!filename) { return errors::FailedPrecondition("'filename' cannot be nullptr."); } @@ -118,7 +118,7 @@ Status GetWellKnownFileName(string* filename) { "Could not find the credentials file in the standard gcloud location."); } *filename = result; - return OkStatus(); + return absl::OkStatus(); } } // namespace @@ -138,42 +138,42 @@ GoogleAuthProvider::GoogleAuthProvider( std::move(compute_engine_metadata_client)), env_(env) {} -Status GoogleAuthProvider::GetToken(string* t) { +absl::Status GoogleAuthProvider::GetToken(string* t) { mutex_lock lock(mu_); const uint64 now_sec = env_->NowSeconds(); if (now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) { *t = current_token_; - return OkStatus(); + return absl::OkStatus(); } if (GetTokenForTesting().ok()) { *t = current_token_; - return OkStatus(); + return absl::OkStatus(); } auto token_from_files_status = GetTokenFromFiles(); if (token_from_files_status.ok()) { *t = current_token_; - return OkStatus(); + return absl::OkStatus(); } char* no_gce_check_var = std::getenv(kNoGceCheck); bool skip_gce_check = no_gce_check_var != nullptr && absl::EqualsIgnoreCase(no_gce_check_var, "true"); - Status token_from_gce_status; + absl::Status token_from_gce_status; if (skip_gce_check) { token_from_gce_status = - Status(absl::StatusCode::kCancelled, - strings::StrCat("GCE check skipped due to presence of $", - kNoGceCheck, " environment variable.")); + absl::Status(absl::StatusCode::kCancelled, + strings::StrCat("GCE check skipped due to presence of $", + kNoGceCheck, " environment variable.")); } else { token_from_gce_status = GetTokenFromGce(); } if (token_from_gce_status.ok()) { *t = current_token_; - return OkStatus(); + return absl::OkStatus(); } if (skip_gce_check) { @@ -203,10 +203,10 @@ Status GoogleAuthProvider::GetToken(string* t) { } current_token_ = ""; - return OkStatus(); + return absl::OkStatus(); } -Status GoogleAuthProvider::GetTokenFromFiles() { +absl::Status GoogleAuthProvider::GetTokenFromFiles() { string credentials_filename; if (!GetEnvironmentVariableFileName(&credentials_filename).ok() && !GetWellKnownFileName(&credentials_filename).ok()) { @@ -231,33 +231,33 @@ Status GoogleAuthProvider::GetTokenFromFiles() { return errors::FailedPrecondition( "Unexpected content of the JSON credentials file."); } - return OkStatus(); + return absl::OkStatus(); } -Status GoogleAuthProvider::GetTokenFromGce() { +absl::Status GoogleAuthProvider::GetTokenFromGce() { std::vector response_buffer; const uint64 request_timestamp_sec = env_->NowSeconds(); TF_RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata( kGceTokenPath, &response_buffer)); - StringPiece response = - StringPiece(&response_buffer[0], response_buffer.size()); + absl::string_view response = + absl::string_view(&response_buffer[0], response_buffer.size()); TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse( response, request_timestamp_sec, ¤t_token_, &expiration_timestamp_sec_)); - return OkStatus(); + return absl::OkStatus(); } -Status GoogleAuthProvider::GetTokenForTesting() { +absl::Status GoogleAuthProvider::GetTokenForTesting() { const char* token = std::getenv(kGoogleAuthTokenForTesting); if (!token) { return errors::NotFound("The env variable for testing was not set."); } expiration_timestamp_sec_ = UINT64_MAX; current_token_ = token; - return OkStatus(); + return absl::OkStatus(); } } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h index 63b7ea63abf5f3..38ab66df63c25c 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider.h @@ -40,20 +40,20 @@ class GoogleAuthProvider : public AuthProvider { /// \brief Returns the short-term authentication bearer token. /// /// Safe for concurrent use by multiple threads. - Status GetToken(string* token) override; + absl::Status GetToken(string* token) override; private: /// \brief Gets the bearer token from files. /// /// Tries the file from $GOOGLE_APPLICATION_CREDENTIALS and the /// standard gcloud tool's location. - Status GetTokenFromFiles() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status GetTokenFromFiles() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); /// Gets the bearer token from Google Compute Engine environment. - Status GetTokenFromGce() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status GetTokenFromGce() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); /// Gets the bearer token from the system env variable, for testing purposes. - Status GetTokenForTesting() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); + absl::Status GetTokenForTesting() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); std::unique_ptr oauth_client_; std::shared_ptr compute_engine_metadata_client_; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc index 6f3072fdf3b83e..e7d6c4aab68634 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/google_auth_provider_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/cloud/http_request_fake.h" #include "tsl/platform/path.h" #include "tsl/platform/test.h" @@ -40,23 +40,24 @@ class FakeEnv : public EnvWrapper { class FakeOAuthClient : public OAuthClient { public: - Status GetTokenFromServiceAccountJson( - Json::Value json, StringPiece oauth_server_uri, StringPiece scope, - string* token, uint64* expiration_timestamp_sec) override { + absl::Status GetTokenFromServiceAccountJson( + Json::Value json, absl::string_view oauth_server_uri, + absl::string_view scope, string* token, + uint64* expiration_timestamp_sec) override { provided_credentials_json = json; *token = return_token; *expiration_timestamp_sec = return_expiration_timestamp; - return OkStatus(); + return absl::OkStatus(); } /// Retrieves a bearer token using a refresh token. - Status GetTokenFromRefreshTokenJson( - Json::Value json, StringPiece oauth_server_uri, string* token, + absl::Status GetTokenFromRefreshTokenJson( + Json::Value json, absl::string_view oauth_server_uri, string* token, uint64* expiration_timestamp_sec) override { provided_credentials_json = json; *token = return_token; *expiration_timestamp_sec = return_expiration_timestamp; - return OkStatus(); + return absl::OkStatus(); } string return_token; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h index a3a3136d66e6f7..8102dd666a7c08 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request.h @@ -85,7 +85,8 @@ class HttpRequest { /// RecordResponse is called after the response has been received. virtual void RecordResponse(const HttpRequest* request, const string& uri, - RequestMethod method, const Status& result) = 0; + RequestMethod method, + const absl::Status& result) = 0; }; HttpRequest() {} @@ -124,7 +125,8 @@ class HttpRequest { /// /// The request body will be taken from the specified file starting from /// the given offset. - virtual Status SetPutFromFile(const string& body_filepath, size_t offset) = 0; + virtual absl::Status SetPutFromFile(const string& body_filepath, + size_t offset) = 0; /// Makes the request a PUT request with an empty body. virtual void SetPutEmptyBody() = 0; @@ -169,7 +171,7 @@ class HttpRequest { /// /// If the result buffer was defined, the response will be written there. /// The object is not designed to be re-used after Send() is executed. - virtual Status Send() = 0; + virtual absl::Status Send() = 0; // Url encodes str and returns a new string. virtual string EscapeString(const string& str) = 0; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h index ea1f487516795e..869d2abca6ee70 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/http_request_fake.h @@ -21,7 +21,7 @@ limitations under the License. #include #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/cloud/curl_http_request.h" #include "tsl/platform/errors.h" #include "tsl/platform/macros.h" @@ -38,12 +38,13 @@ class FakeHttpRequest : public CurlHttpRequest { public: /// Return the response for the given request. FakeHttpRequest(const string& request, const string& response) - : FakeHttpRequest(request, response, OkStatus(), nullptr, {}, 200) {} + : FakeHttpRequest(request, response, absl::OkStatus(), nullptr, {}, 200) { + } /// Return the response with headers for the given request. FakeHttpRequest(const string& request, const string& response, const std::map& response_headers) - : FakeHttpRequest(request, response, OkStatus(), nullptr, + : FakeHttpRequest(request, response, absl::OkStatus(), nullptr, response_headers, 200) {} /// \brief Return the response for the request and capture the POST body. @@ -51,12 +52,12 @@ class FakeHttpRequest : public CurlHttpRequest { /// Post body is not expected to be a part of the 'request' parameter. FakeHttpRequest(const string& request, const string& response, string* captured_post_body) - : FakeHttpRequest(request, response, OkStatus(), captured_post_body, {}, - 200) {} + : FakeHttpRequest(request, response, absl::OkStatus(), captured_post_body, + {}, 200) {} /// \brief Return the response and the status for the given request. FakeHttpRequest(const string& request, const string& response, - Status response_status, uint64 response_code) + absl::Status response_status, uint64 response_code) : FakeHttpRequest(request, response, response_status, nullptr, {}, response_code) {} @@ -65,7 +66,7 @@ class FakeHttpRequest : public CurlHttpRequest { /// /// Post body is not expected to be a part of the 'request' parameter. FakeHttpRequest(const string& request, const string& response, - Status response_status, string* captured_post_body, + absl::Status response_status, string* captured_post_body, const std::map& response_headers, uint64 response_code) : expected_request_(request), @@ -88,20 +89,21 @@ class FakeHttpRequest : public CurlHttpRequest { actual_request_ += "Auth Token: " + auth_token + "\n"; } void SetDeleteRequest() override { actual_request_ += "Delete: yes\n"; } - Status SetPutFromFile(const string& body_filepath, size_t offset) override { + absl::Status SetPutFromFile(const string& body_filepath, + size_t offset) override { std::ifstream stream(body_filepath); const string& content = string(std::istreambuf_iterator(stream), std::istreambuf_iterator()) .substr(offset); actual_request_ += "Put body: " + content + "\n"; - return OkStatus(); + return absl::OkStatus(); } void SetPostFromBuffer(const char* buffer, size_t size) override { if (captured_post_body_) { *captured_post_body_ = string(buffer, size); } else { actual_request_ += - strings::StrCat("Post body: ", StringPiece(buffer, size), "\n"); + strings::StrCat("Post body: ", absl::string_view(buffer, size), "\n"); } } void SetPutEmptyBody() override { actual_request_ += "Put: yes\n"; } @@ -123,7 +125,7 @@ class FakeHttpRequest : public CurlHttpRequest { size_t GetResultBufferDirectBytesTransferred() override { return direct_result_bytes_transferred_; } - Status Send() override { + absl::Status Send() override { EXPECT_EQ(expected_request_, actual_request()) << "Unexpected HTTP request."; if (buffer_) { @@ -182,7 +184,7 @@ class FakeHttpRequest : public CurlHttpRequest { string actual_uri_; string actual_request_; string response_; - Status response_status_; + absl::Status response_status_; string* captured_post_body_ = nullptr; std::map response_headers_; uint64 response_code_ = 0; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc index c983577204b436..74806805abea51 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.cc @@ -49,8 +49,8 @@ constexpr char kJwtType[] = "JWT"; constexpr char kGrantType[] = "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer"; -Status ReadJsonValue(const Json::Value& json, const string& name, - Json::Value* value) { +absl::Status ReadJsonValue(const Json::Value& json, const string& name, + Json::Value* value) { if (!value) { return errors::FailedPrecondition("'value' cannot be nullptr."); } @@ -59,11 +59,11 @@ Status ReadJsonValue(const Json::Value& json, const string& name, return errors::FailedPrecondition( strings::StrCat("Couldn't read a JSON value '", name, "'.")); } - return OkStatus(); + return absl::OkStatus(); } -Status ReadJsonString(const Json::Value& json, const string& name, - string* value) { +absl::Status ReadJsonString(const Json::Value& json, const string& name, + string* value) { Json::Value json_value; TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value)); if (!json_value.isString()) { @@ -71,11 +71,11 @@ Status ReadJsonString(const Json::Value& json, const string& name, strings::StrCat("JSON value '", name, "' is not string.")); } *value = json_value.asString(); - return OkStatus(); + return absl::OkStatus(); } -Status ReadJsonInt(const Json::Value& json, const string& name, - int64_t* value) { +absl::Status ReadJsonInt(const Json::Value& json, const string& name, + int64_t* value) { Json::Value json_value; TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value)); if (!json_value.isIntegral()) { @@ -83,11 +83,11 @@ Status ReadJsonInt(const Json::Value& json, const string& name, strings::StrCat("JSON value '", name, "' is not integer.")); } *value = json_value.asInt64(); - return OkStatus(); + return absl::OkStatus(); } -Status CreateSignature(RSA* private_key, StringPiece to_sign, - string* signature) { +absl::Status CreateSignature(RSA* private_key, absl::string_view to_sign, + string* signature) { if (!private_key || !signature) { return errors::FailedPrecondition( "'private_key' and 'signature' cannot be nullptr."); @@ -126,14 +126,15 @@ Status CreateSignature(RSA* private_key, StringPiece to_sign, if (EVP_DigestSignFinal(md_ctx.get(), sig.get(), &sig_len) != 1) { return errors::Internal("DigestFinal (signature compute) failed."); } - return Base64Encode(StringPiece(reinterpret_cast(sig.get()), sig_len), - signature); + return Base64Encode( + absl::string_view(reinterpret_cast(sig.get()), sig_len), + signature); } /// Encodes a claim for a JSON web token (JWT) to make an OAuth request. -Status EncodeJwtClaim(StringPiece client_email, StringPiece scope, - StringPiece audience, uint64 request_timestamp_sec, - string* encoded) { +absl::Status EncodeJwtClaim(absl::string_view client_email, + absl::string_view scope, absl::string_view audience, + uint64 request_timestamp_sec, string* encoded) { // Step 1: create the JSON with the claim. Json::Value root; root["iss"] = Json::Value(client_email.data(), @@ -155,7 +156,7 @@ Status EncodeJwtClaim(StringPiece client_email, StringPiece scope, } /// Encodes a header for a JSON web token (JWT) to make an OAuth request. -Status EncodeJwtHeader(StringPiece key_id, string* encoded) { +absl::Status EncodeJwtHeader(absl::string_view key_id, string* encoded) { // Step 1: create the JSON with the header. Json::Value root; root["alg"] = kCryptoAlgorithm; @@ -180,9 +181,9 @@ OAuthClient::OAuthClient( std::unique_ptr http_request_factory, Env* env) : http_request_factory_(std::move(http_request_factory)), env_(env) {} -Status OAuthClient::GetTokenFromServiceAccountJson( - Json::Value json, StringPiece oauth_server_uri, StringPiece scope, - string* token, uint64* expiration_timestamp_sec) { +absl::Status OAuthClient::GetTokenFromServiceAccountJson( + Json::Value json, absl::string_view oauth_server_uri, + absl::string_view scope, string* token, uint64* expiration_timestamp_sec) { if (!token || !expiration_timestamp_sec) { return errors::FailedPrecondition( "'token' and 'expiration_timestamp_sec' cannot be nullptr."); @@ -228,15 +229,15 @@ Status OAuthClient::GetTokenFromServiceAccountJson( request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); - StringPiece response = - StringPiece(response_buffer.data(), response_buffer.size()); + absl::string_view response = + absl::string_view(response_buffer.data(), response_buffer.size()); TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token, expiration_timestamp_sec)); - return OkStatus(); + return absl::OkStatus(); } -Status OAuthClient::GetTokenFromRefreshTokenJson( - Json::Value json, StringPiece oauth_server_uri, string* token, +absl::Status OAuthClient::GetTokenFromRefreshTokenJson( + Json::Value json, absl::string_view oauth_server_uri, string* token, uint64* expiration_timestamp_sec) { if (!token || !expiration_timestamp_sec) { return errors::FailedPrecondition( @@ -260,17 +261,17 @@ Status OAuthClient::GetTokenFromRefreshTokenJson( request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); - StringPiece response = - StringPiece(response_buffer.data(), response_buffer.size()); + absl::string_view response = + absl::string_view(response_buffer.data(), response_buffer.size()); TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token, expiration_timestamp_sec)); - return OkStatus(); + return absl::OkStatus(); } -Status OAuthClient::ParseOAuthResponse(StringPiece response, - uint64 request_timestamp_sec, - string* token, - uint64* expiration_timestamp_sec) { +absl::Status OAuthClient::ParseOAuthResponse(absl::string_view response, + uint64 request_timestamp_sec, + string* token, + uint64* expiration_timestamp_sec) { if (!token || !expiration_timestamp_sec) { return errors::FailedPrecondition( "'token' and 'expiration_timestamp_sec' cannot be nullptr."); @@ -292,7 +293,7 @@ Status OAuthClient::ParseOAuthResponse(StringPiece response, *expiration_timestamp_sec = request_timestamp_sec + expires_in; TF_RETURN_IF_ERROR(ReadJsonString(root, "access_token", token)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h index 895c2d01cd5f8f..19d8b4fb1589c2 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client.h @@ -37,20 +37,20 @@ class OAuthClient { /// /// Retrieves the authentication bearer token using a JSON file /// with the client's private key. - virtual Status GetTokenFromServiceAccountJson( - Json::Value json, StringPiece oauth_server_uri, StringPiece scope, - string* token, uint64* expiration_timestamp_sec); + virtual absl::Status GetTokenFromServiceAccountJson( + Json::Value json, absl::string_view oauth_server_uri, + absl::string_view scope, string* token, uint64* expiration_timestamp_sec); /// Retrieves a bearer token using a refresh token. - virtual Status GetTokenFromRefreshTokenJson(Json::Value json, - StringPiece oauth_server_uri, - string* token, - uint64* expiration_timestamp_sec); + virtual absl::Status GetTokenFromRefreshTokenJson( + Json::Value json, absl::string_view oauth_server_uri, string* token, + uint64* expiration_timestamp_sec); /// Parses the JSON response with the token from an OAuth 2.0 server. - virtual Status ParseOAuthResponse(StringPiece response, - uint64 request_timestamp_sec, string* token, - uint64* expiration_timestamp_sec); + virtual absl::Status ParseOAuthResponse(absl::string_view response, + uint64 request_timestamp_sec, + string* token, + uint64* expiration_timestamp_sec); private: std::unique_ptr http_request_factory_; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc index 8979f4442ef13b..dc4c116583b400 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/oauth_client_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/base64.h" #include "tsl/platform/cloud/http_request_fake.h" #include "tsl/platform/env.h" @@ -118,7 +118,7 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { EXPECT_EQ(13920, expiration_timestamp); // Now look at the JWT claim that was sent to the OAuth server. - StringPiece grant_type, assertion; + absl::string_view grant_type, assertion; ASSERT_TRUE(strings::Scanner(post_body) .OneLiteral("grant_type=") .RestartCapture() diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc index f16ab818b92687..57330d1ad30afe 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.cc @@ -68,12 +68,12 @@ void RamFileBlockCache::Trim() { } /// Move the block to the front of the LRU list if it isn't already there. -Status RamFileBlockCache::UpdateLRU(const Key& key, - const std::shared_ptr& block) { +absl::Status RamFileBlockCache::UpdateLRU(const Key& key, + const std::shared_ptr& block) { mutex_lock lock(mu_); if (block->timestamp == 0) { // The block was evicted from another thread. Allow it to remain evicted. - return OkStatus(); + return absl::OkStatus(); } if (block->lru_iterator != lru_list_.begin()) { lru_list_.erase(block->lru_iterator); @@ -95,11 +95,11 @@ Status RamFileBlockCache::UpdateLRU(const Key& key, Trim(); - return OkStatus(); + return absl::OkStatus(); } -Status RamFileBlockCache::MaybeFetch(const Key& key, - const std::shared_ptr& block) { +absl::Status RamFileBlockCache::MaybeFetch( + const Key& key, const std::shared_ptr& block) { bool downloaded_block = false; auto reconcile_state = absl::MakeCleanup([this, &downloaded_block, &key, &block] { @@ -123,7 +123,7 @@ Status RamFileBlockCache::MaybeFetch(const Key& key, // Loop until either block content is successfully fetched, or our request // encounters an error. mutex_lock l(block->mu); - Status status = OkStatus(); + absl::Status status = absl::OkStatus(); while (true) { switch (block->state) { case FetchState::ERROR: @@ -155,23 +155,24 @@ Status RamFileBlockCache::MaybeFetch(const Key& key, case FetchState::FETCHING: block->cond_var.wait_for(l, std::chrono::seconds(60)); if (block->state == FetchState::FINISHED) { - return OkStatus(); + return absl::OkStatus(); } // Re-loop in case of errors. break; case FetchState::FINISHED: - return OkStatus(); + return absl::OkStatus(); } } return errors::Internal( "Control flow should never reach the end of RamFileBlockCache::Fetch."); } -Status RamFileBlockCache::Read(const string& filename, size_t offset, size_t n, - char* buffer, size_t* bytes_transferred) { +absl::Status RamFileBlockCache::Read(const string& filename, size_t offset, + size_t n, char* buffer, + size_t* bytes_transferred) { *bytes_transferred = 0; if (n == 0) { - return OkStatus(); + return absl::OkStatus(); } if (!IsCacheEnabled() || (n > max_bytes_)) { // The cache is effectively disabled, so we pass the read through to the @@ -226,7 +227,7 @@ Status RamFileBlockCache::Read(const string& filename, size_t offset, size_t n, } } *bytes_transferred = total_bytes_transferred; - return OkStatus(); + return absl::OkStatus(); } bool RamFileBlockCache::ValidateAndUpdateFileSignature(const string& filename, diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h index 76cf7eb237dc53..627cf6f2c808fa 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache.h @@ -45,9 +45,9 @@ class RamFileBlockCache : public FileBlockCache { /// cache is constructed. The returned Status should be OK as long as the /// read from the remote filesystem succeeded (similar to the semantics of the /// read(2) system call). - typedef std::function + typedef std::function BlockFetcher; RamFileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness, @@ -88,8 +88,8 @@ class RamFileBlockCache : public FileBlockCache { /// placed in `out`. /// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed /// in `out`). - Status Read(const string& filename, size_t offset, size_t n, char* buffer, - size_t* bytes_transferred) override; + absl::Status Read(const string& filename, size_t offset, size_t n, + char* buffer, size_t* bytes_transferred) override; // Validate the given file signature with the existing file signature in the // cache. Returns true if the signature doesn't change or the file doesn't @@ -197,14 +197,14 @@ class RamFileBlockCache : public FileBlockCache { /// Look up a Key in the block cache. std::shared_ptr Lookup(const Key& key) TF_LOCKS_EXCLUDED(mu_); - Status MaybeFetch(const Key& key, const std::shared_ptr& block) + absl::Status MaybeFetch(const Key& key, const std::shared_ptr& block) TF_LOCKS_EXCLUDED(mu_); /// Trim the block cache to make room for another entry. void Trim() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); /// Update the LRU iterator for the block at `key`. - Status UpdateLRU(const Key& key, const std::shared_ptr& block) + absl::Status UpdateLRU(const Key& key, const std::shared_ptr& block) TF_LOCKS_EXCLUDED(mu_); /// Remove all blocks of a file, with mu_ already held. diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc index 5d17d737661e70..cc716011a9b26e 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/ram_file_block_cache_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/cloud/now_seconds_env.h" #include "tsl/platform/env.h" @@ -27,12 +27,12 @@ limitations under the License. namespace tsl { namespace { -Status ReadCache(RamFileBlockCache* cache, const string& filename, - size_t offset, size_t n, std::vector* out) { +absl::Status ReadCache(RamFileBlockCache* cache, const string& filename, + size_t offset, size_t n, std::vector* out) { out->clear(); out->resize(n, 0); size_t bytes_transferred = 0; - Status status = + absl::Status status = cache->Read(filename, offset, n, out->data(), &bytes_transferred); EXPECT_LE(bytes_transferred, n); out->resize(bytes_transferred, n); @@ -43,7 +43,7 @@ TEST(RamFileBlockCacheTest, IsCacheEnabled) { auto fetcher = [](const string& filename, size_t offset, size_t n, char* buffer, size_t* bytes_transferred) { // Do nothing. - return OkStatus(); + return absl::OkStatus(); }; RamFileBlockCache cache1(0, 0, 0, fetcher); RamFileBlockCache cache2(16, 0, 0, fetcher); @@ -63,7 +63,7 @@ TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) { calls++; memset(buffer, 'x', n); *bytes_transferred = n; - return OkStatus(); + return absl::OkStatus(); }; string filename = "file"; RamFileBlockCache cache(16, 32, 0, fetcher); @@ -99,7 +99,7 @@ TEST(RamFileBlockCacheTest, PassThrough) { calls++; memset(buffer, 'x', got_n); *bytes_transferred = got_n; - return OkStatus(); + return absl::OkStatus(); }; // If block_size, max_bytes, or both are zero, or want_n is larger than // max_bytes the cache is a pass-through. @@ -136,7 +136,7 @@ TEST(RamFileBlockCacheTest, BlockAlignment) { } else { *bytes_transferred = 0; } - return OkStatus(); + return absl::OkStatus(); }; for (size_t block_size = 2; block_size <= 4; block_size++) { // Make a cache of N-byte block size (1 block) and verify that reads of @@ -181,7 +181,7 @@ TEST(RamFileBlockCacheTest, CacheHits) { calls.insert(offset); memset(buffer, 'x', n); *bytes_transferred = n; - return OkStatus(); + return absl::OkStatus(); }; const uint32 block_count = 256; RamFileBlockCache cache(block_size, block_count * block_size, 0, fetcher); @@ -222,7 +222,7 @@ TEST(RamFileBlockCacheTest, OutOfRange) { second_block = true; } *bytes_transferred = bytes_to_copy; - return OkStatus(); + return absl::OkStatus(); }; RamFileBlockCache cache(block_size, block_size, 0, fetcher); std::vector out; @@ -233,7 +233,7 @@ TEST(RamFileBlockCacheTest, OutOfRange) { // Reading at offset file_size + 4 will read the second block (since the read // at file_size + 4 = 28 will be aligned to an offset of 16) but will return // OutOfRange because the offset is past the end of the 24-byte file. - Status status = ReadCache(&cache, "", file_size + 4, 4, &out); + absl::Status status = ReadCache(&cache, "", file_size + 4, 4, &out); EXPECT_EQ(status.code(), error::OUT_OF_RANGE); EXPECT_TRUE(second_block); // Reading the second full block will return 8 bytes, from a cache hit. @@ -255,7 +255,7 @@ TEST(RamFileBlockCacheTest, Inconsistent) { EXPECT_GE(n, 1); memset(buffer, 'x', 1); *bytes_transferred = 1; - return OkStatus(); + return absl::OkStatus(); }; RamFileBlockCache cache(block_size, 2 * block_size, 0, fetcher); std::vector out; @@ -264,7 +264,7 @@ TEST(RamFileBlockCacheTest, Inconsistent) { EXPECT_EQ(out.size(), 1); // Now read the first block; this should yield an INTERNAL error because we // had already cached a partial block at a later position. - Status status = ReadCache(&cache, "", 0, block_size, &out); + absl::Status status = ReadCache(&cache, "", 0, block_size, &out); EXPECT_EQ(status.code(), error::INTERNAL); } @@ -282,7 +282,7 @@ TEST(RamFileBlockCacheTest, LRU) { } memset(buffer, 'x', n); *bytes_transferred = n; - return OkStatus(); + return absl::OkStatus(); }; const uint32 block_count = 2; RamFileBlockCache cache(block_size, block_count * block_size, 0, fetcher); @@ -324,7 +324,7 @@ TEST(RamFileBlockCacheTest, MaxStaleness) { calls++; memset(buffer, 'x', n); *bytes_transferred = n; - return OkStatus(); + return absl::OkStatus(); }; std::vector out; std::unique_ptr env(new NowSecondsEnv); @@ -369,7 +369,7 @@ TEST(RamFileBlockCacheTest, RemoveFile) { } memset(buffer, c, n); *bytes_transferred = n; - return OkStatus(); + return absl::OkStatus(); }; // This cache has space for 4 blocks; we'll read from two files. const size_t n = 3; @@ -426,7 +426,7 @@ TEST(RamFileBlockCacheTest, Prune) { calls++; memset(buffer, 'x', n); *bytes_transferred = n; - return OkStatus(); + return absl::OkStatus(); }; std::vector out; // Our fake environment is initialized with the current timestamp. @@ -493,7 +493,7 @@ TEST(RamFileBlockCacheTest, ParallelReads) { } memset(buffer, 'x', n); *bytes_transferred = n; - return OkStatus(); + return absl::OkStatus(); }; const int block_size = 8; RamFileBlockCache cache(block_size, 2 * callers * block_size, 0, fetcher); @@ -529,7 +529,7 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) { notification.Notify(); // Wait for other thread to issue read. Env::Default()->SleepForMicroseconds(100000); // 0.1 secs - return OkStatus(); + return absl::OkStatus(); }; RamFileBlockCache cache(block_size, block_size, 0, fetcher); // Fork off thread for parallel read. @@ -554,7 +554,7 @@ TEST(RamFileBlockCacheTest, Flush) { calls++; memset(buffer, 'x', n); *bytes_transferred = n; - return OkStatus(); + return absl::OkStatus(); }; RamFileBlockCache cache(16, 32, 0, fetcher); std::vector out; diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc index fecba6b1c8f742..62f8258bbf3ec0 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.cc @@ -34,7 +34,7 @@ constexpr int64_t kNanosecondsPerSecond = 1000 * 1000 * 1000; // Only implements one special case of RFC 3339 which is returned by // GCS API, e.g 2016-04-29T23:15:24.896Z. -Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec) { +absl::Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec) { tm parsed{0}; float seconds; if (sscanf(time.c_str(), "%4d-%2d-%2dT%2d:%2d:%fZ", &(parsed.tm_year), @@ -52,7 +52,7 @@ Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec) { static_cast(std::floor((seconds - int_seconds) * kNanosecondsPerSecond)); - return OkStatus(); + return absl::OkStatus(); } } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h index 5eb116c6aca75d..4dd2d29ff15772 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util.h @@ -22,7 +22,7 @@ namespace tsl { /// Parses the timestamp in RFC 3339 format and returns it /// as nanoseconds since epoch. -Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec); +absl::Status ParseRfc3339Time(const string& time, int64_t* mtime_nsec); } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc index 3a9655589539e1..6b54787f6bd309 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/time_util_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tsl/platform/cloud/time_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace tsl { diff --git a/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h b/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h index 8c000e08437d4e..14b64ea9955aae 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h +++ b/third_party/xla/third_party/tsl/tsl/platform/cloud/zone_provider.h @@ -34,9 +34,9 @@ class ZoneProvider { /// Returns an empty string in the case where the zone does not match the /// expected format /// Safe for concurrent use by multiple threads. - virtual Status GetZone(string* zone) = 0; + virtual absl::Status GetZone(string* zone) = 0; - static Status GetZone(ZoneProvider* provider, string* zone) { + static absl::Status GetZone(ZoneProvider* provider, string* zone) { if (!provider) { return errors::Internal("Zone provider is required."); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD index f777279540bc2b..20d43489eefa67 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/BUILD +++ b/third_party/xla/third_party/tsl/tsl/platform/default/BUILD @@ -3,6 +3,7 @@ load("@bazel_skylib//:bzl_library.bzl", "bzl_library") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load( "@local_xla//xla/tsl:tsl.bzl", + "if_hermetic_cuda_tools", "if_not_fuchsia", "if_not_windows", "if_oss", @@ -59,6 +60,9 @@ cc_library( srcs = ["cuda_libdevice_path.cc"], hdrs = ["//tsl/platform:cuda_libdevice_path.h"], compatible_with = [], + data = if_hermetic_cuda_tools([ + "@cuda_nvcc//:nvvm", + ]), tags = [ "manual", "no_oss", @@ -66,6 +70,7 @@ cc_library( ], deps = [ "//tsl/platform", + "//tsl/platform:env", "//tsl/platform:logging", "//tsl/platform:path", "//tsl/platform:types", diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc index 46321e74b5dc38..ac0a804b4dfd42 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/cuda_libdevice_path.cc @@ -31,6 +31,7 @@ limitations under the License. #if !defined(PLATFORM_GOOGLE) #include "third_party/gpus/cuda/cuda_config.h" +#include "tsl/platform/env.h" #endif #include "tsl/platform/logging.h" @@ -38,8 +39,25 @@ namespace tsl { std::vector CandidateCudaRoots() { #if !defined(PLATFORM_GOOGLE) - auto roots = std::vector{TF_CUDA_TOOLKIT_PATH, - std::string("/usr/local/cuda")}; + auto roots = std::vector{}; + std::string runfiles_suffix = "runfiles"; + + // The CUDA candidate root for c++ targets. + std::string executable_path = tsl::Env::Default()->GetExecutablePath(); + std::string cuda_nvcc_dir = + io::JoinPath(executable_path + "." + runfiles_suffix, "cuda_nvcc"); + roots.emplace_back(cuda_nvcc_dir); + + // The CUDA candidate root for python targets. + std::string runfiles_dir = tsl::Env::Default()->GetRunfilesDir(); + std::size_t runfiles_ind = runfiles_dir.rfind(runfiles_suffix); + cuda_nvcc_dir = io::JoinPath( + runfiles_dir.substr(0, runfiles_ind + runfiles_suffix.length()), + "cuda_nvcc"); + roots.emplace_back(cuda_nvcc_dir); + + roots.emplace_back(TF_CUDA_TOOLKIT_PATH); + roots.emplace_back(std::string("/usr/local/cuda")); #if defined(PLATFORM_POSIX) && !defined(__APPLE__) Dl_info info; @@ -53,13 +71,17 @@ std::vector CandidateCudaRoots() { // relative to the current binary for the wheel-based nvcc package. for (auto path : {"../nvidia/cuda_nvcc", "../../nvidia/cuda_nvcc"}) roots.emplace_back(io::JoinPath(dir, path)); + + // Also add the path to the copy of libdevice.10.bc that we include within + // the Python wheel. + roots.emplace_back(io::JoinPath(dir, "cuda")); } #endif // defined(PLATFORM_POSIX) && !defined(__APPLE__) for (auto root : roots) VLOG(3) << "CUDA root = " << root; return roots; #else // !defined(PLATFORM_GOOGLE) - return {std::string("/usr/local/cuda")}; + return {}; #endif //! defined(PLATFORM_GOOGLE) } diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc index 868fb35f887dab..e5dbff497ad710 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/port.cc @@ -411,6 +411,13 @@ void* AlignedMalloc(size_t size, int minimum_alignment) { void AlignedFree(void* aligned_memory) { Free(aligned_memory); } +void AlignedSizedFree(void* aligned_memory, size_t alignment, size_t size) { + (void)alignment; + (void)size; + + Free(aligned_memory); +} + void* Malloc(size_t size) { return malloc(size); } void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); } diff --git a/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc b/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc index c5373b9fb2e859..a1934f81e35723 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/default/rocm_rocdl_path.cc @@ -26,7 +26,7 @@ limitations under the License. namespace tsl { -string RocmRoot() { +std::string RocmRoot() { #if TENSORFLOW_USE_ROCM if (const char* rocm_path_env = std::getenv("ROCM_PATH")) { VLOG(3) << "ROCM root = " << rocm_path_env; @@ -40,12 +40,11 @@ string RocmRoot() { #endif } -string RocdlRoot() { +std::string RocdlRoot() { if (const char* device_lib_path_env = std::getenv("HIP_DEVICE_LIB_PATH")) { - return device_lib_path_env; - } - else{ - return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); + return device_lib_path_env; + } else { + return io::JoinPath(RocmRoot(), "amdgcn/bitcode"); } } diff --git a/third_party/xla/third_party/tsl/tsl/platform/mem.h b/third_party/xla/third_party/tsl/tsl/platform/mem.h index 0f32727f0f753d..6d0dc803e93b80 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/mem.h +++ b/third_party/xla/third_party/tsl/tsl/platform/mem.h @@ -28,6 +28,7 @@ namespace port { // and a multiple of sizeof(void*). void* AlignedMalloc(size_t size, int minimum_alignment); void AlignedFree(void* aligned_memory); +void AlignedSizedFree(void* aligned_memory, size_t alignment, size_t size); void* Malloc(size_t size); void* Realloc(void* ptr, size_t size); diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc index 8477cdb353e21f..33792c8ecfd293 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_file_system_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/str_util.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc index 5d55ec31cc2f20..00241685d00d5d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/retrying_utils_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "absl/time/time.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/str_util.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc b/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc index 1b1bbcb3113e17..1b01a8c4c4d54d 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/subprocess_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/path.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc index f8e19503edb305..57600173577329 100644 --- a/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc +++ b/third_party/xla/third_party/tsl/tsl/platform/windows/port.cc @@ -211,6 +211,13 @@ void* AlignedMalloc(size_t size, int minimum_alignment) { void AlignedFree(void* aligned_memory) { _aligned_free(aligned_memory); } +void AlignedSizedFree(void* aligned_memory, size_t alignment, size_t size) { + (void)alignment; + (void)size; + + _aligned_free(aligned_memory); +} + void* Malloc(size_t size) { return malloc(size); } void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc index 438f98c2b3ef24..4943fba0c1bfea 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/nvtx_utils.cc @@ -36,7 +36,7 @@ namespace { // nvtxNameOsThreadA: // https://nvidia.github.io/NVTX/doxygen/group___r_e_s_o_u_r_c_e___n_a_m_i_n_g.html // This convention may not match the one in tsl::Env::GetCurrentThreadId(). -std::optional GetCurrentThreadId() { +std::optional MaybeGetCurrentThreadId() { #ifdef __linux__ return syscall(SYS_gettid); #else @@ -57,7 +57,8 @@ ProfilerDomainHandle DefaultProfilerDomain() { } void NameCurrentThread(const std::string& thread_name) { - if (std::optional tid = GetCurrentThreadId(); tid.has_value()) { + if (std::optional tid = MaybeGetCurrentThreadId(); + tid.has_value()) { nvtxNameOsThreadA(*tid, thread_name.c_str()); } } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h index 478dae87b8a399..ef303663b3d142 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/profiler_lock.h @@ -46,9 +46,9 @@ class ProfilerLock { ProfilerLock& operator=(const ProfilerLock&) = delete; // Movable. - ProfilerLock(ProfilerLock&& other) + ProfilerLock(ProfilerLock&& other) noexcept : active_(std::exchange(other.active_, false)) {} - ProfilerLock& operator=(ProfilerLock&& other) { + ProfilerLock& operator=(ProfilerLock&& other) noexcept { active_ = std::exchange(other.active_, false); return *this; } diff --git a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h index 75c2902f323d05..da9fe210737dd9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/lib/traceme.h @@ -146,8 +146,8 @@ class TraceMe { } // Movable. - TraceMe(TraceMe&& other) { *this = std::move(other); } - TraceMe& operator=(TraceMe&& other) { + TraceMe(TraceMe&& other) noexcept { *this = std::move(other); } + TraceMe& operator=(TraceMe&& other) noexcept { #if !defined(IS_MOBILE_PLATFORM) if (TF_PREDICT_FALSE(other.start_time_ != kUntracedActivity)) { name_.Emplace(std::move(other.name_).Consume()); diff --git a/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD index fdf0979d82ea69..141f1b6e6edf82 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/BUILD @@ -29,7 +29,6 @@ cc_library( "@local_xla//xla/python:__pkg__", "//tensorflow/core/profiler/rpc/client:__pkg__", "//tensorflow/python/profiler/internal:__pkg__", - "//learning/pathways/data_parallel:__pkg__", ]), deps = [ ":profiler_client_for_pybind", @@ -62,6 +61,7 @@ cc_library( "@local_xla//xla/python:__pkg__", "//tsl/profiler:internal", "//tsl/profiler/rpc:__pkg__", + "//learning/pathways/data_parallel:__pkg__", ]), deps = [ "//tsl/lib/io:zlib_compression_options", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc b/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc index 8bc9a1986effb7..47d8638005931c 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/rpc/client/profiler_client.cc @@ -17,10 +17,10 @@ limitations under the License. #include #include -#include "grpcpp/grpcpp.h" #include "absl/memory/memory.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "grpcpp/grpcpp.h" // IWYU pragma: keep #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/rpc/profiler_server.cc b/third_party/xla/third_party/tsl/tsl/profiler/rpc/profiler_server.cc index 8b598fa450cdc6..f619c1346a0af4 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/rpc/profiler_server.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/rpc/profiler_server.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include -#include "grpcpp/grpcpp.h" #include "absl/strings/str_cat.h" +#include "grpcpp/grpcpp.h" // IWYU pragma: keep #include "tsl/platform/logging.h" #include "tsl/platform/types.h" #include "tsl/profiler/protobuf/profiler_service.grpc.pb.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc b/third_party/xla/third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc index 8deee9782aa9fe..efb544ebdf2278 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/rpc/profiler_service_impl.cc @@ -17,9 +17,9 @@ limitations under the License. #include -#include "grpcpp/support/status.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/str_replace.h" +#include "grpcpp/support/status.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" #include "tsl/platform/errors.h" diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD index 203657d0744e82..39113ebb7fc07f 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/BUILD @@ -363,6 +363,7 @@ tsl_cc_test( ":tpu_xplane_utils", ":xplane_schema", ":xplane_utils", + ":xplane_visitor", "//tsl/platform:test", "//tsl/platform:test_main", "//tsl/profiler/protobuf:xplane_proto_cc", diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc index 19841f53ce7fdb..9274a1da941743 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "tsl/platform/regexp.h" #include "tsl/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/utils/xplane_schema.h" @@ -48,5 +49,11 @@ std::optional GetTensorCoreId(absl::string_view plane_name) { return std::nullopt; } +std::optional GetSparseCoreId(absl::string_view plane_name) { + std::optional core_id; + RE2::FullMatch(plane_name, {kSparseCorePlaneRegex}, &core_id); + return core_id; +} + } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h index f3a150ca37e607..2fb7c677e3a058 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils.h @@ -36,6 +36,10 @@ std::vector FindMutableTensorCorePlanes( // TensorCore plane name. std::optional GetTensorCoreId(absl::string_view plane_name); +// Get Sparsecore Id from SparseCore plane name if plane name is a valid +// SparseCore plane name. +std::optional GetSparseCoreId(absl::string_view plane_name); + } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc index a385c77821c347..e5bcd73c339be9 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/tpu_xplane_utils_test.cc @@ -21,11 +21,13 @@ limitations under the License. #include "tsl/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/utils/xplane_schema.h" #include "tsl/profiler/utils/xplane_utils.h" +#include "tsl/profiler/utils/xplane_visitor.h" namespace tsl { namespace profiler { namespace { +using ::testing::Optional; using ::testing::UnorderedElementsAre; TEST(TpuXPlaneUtilsTest, GetTensorCoreXPlanesFromXSpace) { @@ -65,6 +67,22 @@ TEST(TpuXPlaneUtilsTest, IsNotTensorCorePlaneNameWithPrefix) { GetTensorCoreId(absl::StrCat("/prefix", TpuPlaneName(0))).has_value()); } +TEST(TpuXplaneUtilsTest, GetSparseCorePlanesFromXSpace) { + XSpace space; + XPlane* p1 = FindOrAddMutablePlaneWithName(&space, TpuPlaneName(0)); + XPlane* p2 = FindOrAddMutablePlaneWithName(&space, TpuPlaneName(1)); + XPlane* p3 = FindOrAddMutablePlaneWithName( + &space, absl::StrCat(TpuPlaneName(0), " SparseCore 0")); + XPlane* p4 = FindOrAddMutablePlaneWithName( + &space, absl::StrCat(TpuPlaneName(0), " SparseCore 1")); + + EXPECT_THAT(FindTensorCorePlanes(space), UnorderedElementsAre(p1, p2)); + EXPECT_THAT(FindPlanesWithPrefix(space, kTpuPlanePrefix), + UnorderedElementsAre(p1, p2, p3, p4)); + EXPECT_THAT(GetSparseCoreId(p3->name()), Optional(0)); + EXPECT_THAT(GetSparseCoreId(p4->name()), Optional(1)); +} + } // namespace } // namespace profiler } // namespace tsl diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc index 33de2b0f6c3e19..2cd8aaa74b55b0 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.cc @@ -33,6 +33,8 @@ const absl::string_view kGpuPlanePrefix = "/device:GPU:"; const absl::string_view kTpuPlanePrefix = "/device:TPU:"; const absl::string_view kTpuNonCorePlaneNamePrefix = "#Chip"; const char kTpuPlaneRegex[] = {"/device:TPU:([0-9]*)$"}; +const char kSparseCorePlaneRegex[] = { + "/device:TPU:[0-9]+ SparseCore ([0-9]+)$"}; // TODO(b/195582092): change it to /device:custom once all literals are // migrated. const absl::string_view kCustomPlanePrefix = "/device:CUSTOM:"; diff --git a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h index 2e693b4474b92d..edf808b864648e 100644 --- a/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h +++ b/third_party/xla/third_party/tsl/tsl/profiler/utils/xplane_schema.h @@ -41,6 +41,8 @@ TF_CONST_INIT extern const absl::string_view kGpuPlanePrefix; TF_CONST_INIT extern const absl::string_view kTpuPlanePrefix; // Regex for XPlanes that contain TensorCore planes. TF_CONST_INIT extern const char kTpuPlaneRegex[]; +// Regex for XPlanes that contain TPU Core planes. +TF_CONST_INIT extern const char kSparseCorePlaneRegex[]; // Name prefix of XPlane that contains custom device events. TF_CONST_INIT extern const absl::string_view kCustomPlanePrefix; // Name prefix of XPlane that contains TPU non-core events such as HBM, ICI etc. diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD index 65000ff408801c..10e1dac5abc717 100644 --- a/third_party/xla/third_party/tsl/tsl/protobuf/BUILD +++ b/third_party/xla/third_party/tsl/tsl/protobuf/BUILD @@ -20,13 +20,6 @@ package( licenses = ["notice"], ) -tf_proto_library( - name = "bfc_memory_map_proto", - srcs = ["bfc_memory_map.proto"], - make_default_target_header_only = True, - visibility = ["//visibility:public"], -) - tf_proto_library( name = "dnn_proto", srcs = ["dnn.proto"], @@ -123,7 +116,7 @@ tf_proto_library( protodeps = [ # TODO(tlongeri): Conceptually, these fit into protos_all but adding them currently causes # breakages (and they are not actually used). - ":bfc_memory_map_proto", + "@local_xla//xla/tsl/protobuf:bfc_memory_map_proto", ":coordination_config_proto", ":distributed_runtime_payloads_proto", ":error_codes_proto_impl", diff --git a/third_party/xla/third_party/tsl/workspace2.bzl b/third_party/xla/third_party/tsl/workspace2.bzl index 0a2993f3542ba4..7b85e735b1f880 100644 --- a/third_party/xla/third_party/tsl/workspace2.bzl +++ b/third_party/xla/third_party/tsl/workspace2.bzl @@ -17,14 +17,12 @@ load("//third_party/eigen3:workspace.bzl", eigen3 = "repo") load("//third_party/farmhash:workspace.bzl", farmhash = "repo") load("//third_party/gemmlowp:workspace.bzl", gemmlowp = "repo") load("//third_party/git:git_configure.bzl", "git_configure") -load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") load("//third_party/gpus:sycl_configure.bzl", "sycl_configure") load("//third_party/hwloc:workspace.bzl", hwloc = "repo") load("//third_party/implib_so:workspace.bzl", implib_so = "repo") load("//third_party/llvm:setup.bzl", "llvm_setup") load("//third_party/nasm:workspace.bzl", nasm = "repo") -load("//third_party/nccl:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "python_configure") load("//third_party/py/ml_dtypes:workspace.bzl", ml_dtypes = "repo") load("//third_party/pybind11_abseil:workspace.bzl", pybind11_abseil = "repo") @@ -69,9 +67,7 @@ def _tf_toolchains(): # Note that we check the minimum bazel version in WORKSPACE. clang6_configure(name = "local_config_clang6") cc_download_clang_toolchain(name = "local_config_download_clang") - cuda_configure(name = "local_config_cuda") tensorrt_configure(name = "local_config_tensorrt") - nccl_configure(name = "local_config_nccl") git_configure(name = "local_config_git") syslibs_configure(name = "local_config_syslibs") python_configure(name = "local_config_python") @@ -160,13 +156,13 @@ def _tf_repositories(): tf_http_archive( name = "mkl_dnn_acl_compatible", - build_file = "//tensorflow/third_party/mkl_dnn:mkldnn_acl.BUILD", + build_file = "//third_party/mkl_dnn:mkldnn_acl.BUILD", patch_file = [ - "//tensorflow/third_party/mkl_dnn:onednn_acl_threadcap.patch", - "//tensorflow/third_party/mkl_dnn:onednn_acl_reorder.patch", - "//tensorflow/third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch", - "//tensorflow/third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch", - "//tensorflow/third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch", + "//third_party/mkl_dnn:onednn_acl_threadcap.patch", + "//third_party/mkl_dnn:onednn_acl_reorder.patch", + "//third_party/mkl_dnn:onednn_acl_thread_local_scheduler.patch", + "//third_party/mkl_dnn:onednn_acl_fp32_bf16_reorder.patch", + "//third_party/mkl_dnn:onednn_acl_bf16_capability_detection_for_ubuntu20.04.patch", ], sha256 = "2f76b407ef8893cca71340f88cd800019a1f14f8ac1bbdbb89a84be1370b52e3", strip_prefix = "oneDNN-3.2.1", @@ -560,9 +556,9 @@ def _tf_repositories(): tf_http_archive( name = "pybind11", - urls = tf_mirror_urls("https://github.com/pybind/pybind11/archive/v2.10.0.tar.gz"), - sha256 = "eacf582fa8f696227988d08cfc46121770823839fe9e301a20fbce67e7cd70ec", - strip_prefix = "pybind11-2.10.0", + urls = tf_mirror_urls("https://github.com/pybind/pybind11/archive/v2.13.4.tar.gz"), + sha256 = "efc901aa0aab439a3fea6efeaf930b5a349fb06394bf845c64ce15a9cf8f0240", + strip_prefix = "pybind11-2.13.4", build_file = "//third_party:pybind11.BUILD", system_build_file = "//third_party/systemlibs:pybind11.BUILD", ) @@ -591,6 +587,22 @@ def _tf_repositories(): urls = tf_mirror_urls("https://github.com/google/glog/archive/refs/tags/v0.4.0.tar.gz"), ) + tf_http_archive( + name = "spirv_headers", + sha256 = "11d835c60297b26532c05c3f3b581ba7a2787b5ae7399e94f72c392169216f11", + strip_prefix = "SPIRV-Headers-b73e168ca5e123dcf3dea8a34b19a5130f421ae1", + urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-Headers/archive/b73e168ca5e123dcf3dea8a34b19a5130f421ae1.tar.gz"), + ) + + tf_http_archive( + name = "spirv_llvm_translator", + sha256 = "d499769f4fd1e0ce9d4dbd3622ee7e3e641b5623dcdf811521e3e7c0bdb1e6c2", + strip_prefix = "SPIRV-LLVM-Translator-dad1f0eaab8047a4f73c50ed5f3d1694b78aae97", + build_file = "//third_party/spirv_llvm_translator:spirv_llvm_translator.BUILD", + patch_file = ["//third_party/spirv_llvm_translator:spirv_llvm_translator.patch"], + urls = tf_mirror_urls("https://github.com/KhronosGroup/SPIRV-LLVM-Translator/archive/dad1f0eaab8047a4f73c50ed5f3d1694b78aae97.tar.gz"), + ) + # buildifier: disable=unnamed-macro def workspace(): # Check the bazel version before executing any repository rules, in case diff --git a/third_party/xla/third_party/uv/BUILD b/third_party/xla/third_party/uv/BUILD new file mode 100644 index 00000000000000..3c413807167aeb --- /dev/null +++ b/third_party/xla/third_party/uv/BUILD @@ -0,0 +1 @@ +# copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) diff --git a/third_party/xla/third_party/uv/uv.BUILD b/third_party/xla/third_party/uv/uv.BUILD new file mode 100644 index 00000000000000..43c194a53ea516 --- /dev/null +++ b/third_party/xla/third_party/uv/uv.BUILD @@ -0,0 +1,82 @@ +# Description: +# libuv is a cross-platform asynchronous I/O library. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) + +exports_files(["LICENSE"]) + +cc_library( + name = "uv", + srcs = [ + "src/fs-poll.c", + "src/idna.c", + "src/inet.c", + "src/random.c", + "src/strscpy.c", + "src/threadpool.c", + "src/timer.c", + "src/uv-common.c", + "src/uv-data-getter-setters.c", + "src/version.c", + ] + [ + "src/unix/async.c", + "src/unix/core.c", + "src/unix/dl.c", + "src/unix/fs.c", + "src/unix/getaddrinfo.c", + "src/unix/getnameinfo.c", + "src/unix/loop.c", + "src/unix/loop-watcher.c", + "src/unix/pipe.c", + "src/unix/poll.c", + "src/unix/process.c", + "src/unix/random-devurandom.c", + "src/unix/signal.c", + "src/unix/stream.c", + "src/unix/tcp.c", + "src/unix/thread.c", + "src/unix/tty.c", + "src/unix/udp.c", + ] + select({ + "@platforms//os:osx": [ + "src/unix/bsd-ifaddrs.c", + "src/unix/darwin.c", + "src/unix/darwin-proctitle.c", + "src/unix/fsevents.c", + "src/unix/kqueue.c", + "src/unix/proctitle.c", + "src/unix/random-getentropy.c", + ], + }), + # TODO: Add Linux, etc. as in https://github.com/libuv/libuv/blob/v1.38.0/CMakeLists.txt. + hdrs = [ + "include/uv.h", + "src/heap-inl.h", + "src/idna.h", + "src/queue.h", + "src/strscpy.h", + "src/unix/atomic-ops.h", + "src/unix/internal.h", + "src/unix/spinlock.h", + "src/uv-common.h", + ] + select({ + "@platforms//os:osx": [ + "src/unix/darwin-stub.h", + ], + }) + glob(["include/uv/*.h"]), + copts = [ + "-fexceptions", + "-Wno-unused-variable", + ], + includes = [ + "include", + "src", + ], + textual_hdrs = [ + "include/uv.h", + ], +) diff --git a/third_party/xla/third_party/uv/workspace.bzl b/third_party/xla/third_party/uv/workspace.bzl new file mode 100644 index 00000000000000..8d26ab4dcd41b5 --- /dev/null +++ b/third_party/xla/third_party/uv/workspace.bzl @@ -0,0 +1,17 @@ +"""Provides the repository macro to import libuv.""" + +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") + +def repo(): + """Imports libuv.""" + + UV_VERSION = "v1.38.0" + UV_SHA256 = "71344f62c5020ed3643ad0bcba98ae4d7d6037285923c5416844d7c141a3ff93" + + tf_http_archive( + name = "uv", + sha256 = UV_SHA256, + strip_prefix = "libuv-{version}".format(version = UV_VERSION), + urls = tf_mirror_urls("https://dist.libuv.org/dist/{version}/libuv-{version}.tar.gz".format(version = UV_VERSION)), + build_file = "//third_party/uv:uv.BUILD", + ) diff --git a/third_party/xla/tools/toolchains/remote_config/configs.bzl b/third_party/xla/tools/toolchains/remote_config/configs.bzl index 0c28198f980b95..9a4dfa2aafdc51 100644 --- a/third_party/xla/tools/toolchains/remote_config/configs.bzl +++ b/third_party/xla/tools/toolchains/remote_config/configs.bzl @@ -225,8 +225,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -236,8 +236,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "9.1", + cuda_version = "12.3.2", + cudnn_version = "9.1.1", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -248,8 +248,8 @@ def initialize_rbe_configs(): name = "ubuntu20.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu20.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -258,8 +258,8 @@ def initialize_rbe_configs(): tensorflow_rbe_config( name = "ubuntu22.04-clang_manylinux2014-cuda12.3-cudnn8.9", compiler = "/usr/lib/llvm-18/bin/clang", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], sysroot = "/dt9", @@ -270,8 +270,8 @@ def initialize_rbe_configs(): name = "ubuntu22.04-gcc9_manylinux2014-cuda12.3-cudnn8.9", compiler = "/dt9/usr/bin/gcc", compiler_prefix = "/usr/bin", - cuda_version = "12.3", - cudnn_version = "8.9", + cuda_version = "12.3.2", + cudnn_version = "8.9.7.29", os = "ubuntu22.04-manylinux2014-multipython", python_versions = ["3.9", "3.10", "3.11", "3.12"], python_install_path = "/usr/local", @@ -479,7 +479,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -558,7 +558,7 @@ def initialize_rbe_configs(): "TF_CUDNN_VERSION": "8.6", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.4", }, @@ -710,11 +710,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -749,11 +749,11 @@ def initialize_rbe_configs(): "TENSORRT_INSTALL_PATH": "/usr/lib/x86_64-linux-gnu", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_SYSROOT": "/dt9", "TF_TENSORRT_VERSION": "8.6", }, @@ -788,12 +788,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "0", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -826,12 +826,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "8.9", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "8.9.7.29", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "8.6", }, ) @@ -864,12 +864,12 @@ def initialize_rbe_configs(): "PYTHON_BIN_PATH": "/usr/bin/python3", "TF_CUDA_CLANG": "1", "TF_CUDA_COMPUTE_CAPABILITIES": "3.5,6.0", - "TF_CUDA_VERSION": "12.3", - "TF_CUDNN_VERSION": "9.1", + "TF_CUDA_VERSION": "12.3.2", + "TF_CUDNN_VERSION": "9.1.1", "TF_ENABLE_XLA": "1", "TF_NEED_CUDA": "1", "TF_SYSROOT": "/dt9", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": "10.0", }, ) diff --git a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl index 18a84d96c39f82..ec2ac4cc8ea430 100644 --- a/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl +++ b/third_party/xla/tools/toolchains/remote_config/rbe_config.bzl @@ -1,8 +1,8 @@ """Macro that creates external repositories for remote config.""" -load("//third_party/gpus:cuda_configure.bzl", "remote_cuda_configure") +load("@local_config_cuda//cuda/hermetic:cuda_configure.bzl", "cuda_configure") load("//third_party/gpus:rocm_configure.bzl", "remote_rocm_configure") -load("//third_party/nccl:nccl_configure.bzl", "remote_nccl_configure") +load("//third_party/nccl/hermetic:nccl_configure.bzl", "nccl_configure") load("//third_party/py:python_configure.bzl", "local_python_configure", "remote_python_configure") load("//third_party/remote_config:remote_platform_configure.bzl", "remote_platform_configure") load("//third_party/tensorrt:tensorrt_configure.bzl", "remote_tensorrt_configure") @@ -42,7 +42,7 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_CUDNN_VERSION": cudnn_version, "TF_CUDA_VERSION": cuda_version, "CUDNN_INSTALL_PATH": cudnn_install_path if cudnn_install_path != None else "/usr/lib/x86_64-linux-gnu", - "TF_NEED_TENSORRT": "1", + "TF_NEED_TENSORRT": "0", "TF_TENSORRT_VERSION": tensorrt_version if tensorrt_version != None else "", "TENSORRT_INSTALL_PATH": tensorrt_install_path if tensorrt_install_path != None else "/usr/lib/x86_64-linux-gnu", "GCC_HOST_COMPILER_PATH": compiler if not compiler.endswith("clang") else "", @@ -51,20 +51,26 @@ def _tensorflow_rbe_config(name, compiler, python_versions, os, rocm_version = N "TF_SYSROOT": sysroot if sysroot else "", }) - container_name = "cuda%s-cudnn%s-%s" % (cuda_version, cudnn_version, os) + cuda_version_in_container = ".".join(cuda_version.split(".")[:2]) + cudnn_version_in_container = ".".join(cudnn_version.split(".")[:2]) + container_name = "cuda%s-cudnn%s-%s" % ( + cuda_version_in_container, + cudnn_version_in_container, + os, + ) container_image = _container_image_uri(container_name) exec_properties = { "container-image": container_image, "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, @@ -175,13 +181,13 @@ def sigbuild_tf_configs(name_container_map, env): "Pool": "default", } - remote_cuda_configure( + cuda_configure( name = "%s_config_cuda" % name, environ = env, exec_properties = exec_properties, ) - remote_nccl_configure( + nccl_configure( name = "%s_config_nccl" % name, environ = env, exec_properties = exec_properties, diff --git a/third_party/xla/warnings.bazelrc b/third_party/xla/warnings.bazelrc index a5711c9d6c6394..ae92c8c9db8472 100644 --- a/third_party/xla/warnings.bazelrc +++ b/third_party/xla/warnings.bazelrc @@ -4,13 +4,8 @@ build:warnings --copt=-Werror --host_copt=-Werror # ...and silence them outside of the workspace. build:warnings --per_file_copt=external/.*@-w -# ...and silence them on host builds. There is no host_per_file_copt and -# everything we build in the host configuration we either also build in the -# target configuration or is external, so we can't control it. -# If/when Bazel supports --host_per_file_copt, we could use that instead: -# https://github.com/bazelbuild/bazel/issues/12406. -# Would need to then make all the --copt below duplicated with --host_copt. -build:warnings --host_copt=-w +# ...and silence them on host builds. +build:warnings --host_per_file_copt=external/.*@-w build:warnings --copt=-Wall build:warnings --copt=-Werror @@ -93,7 +88,5 @@ build:warnings --copt=-Wno-final-dtor-non-final-class build:warnings --copt=-Wnon-virtual-dtor build:warnings --copt=-Wimplicit-fallthrough build:warnings --copt=-Wthread-safety-analysis -build:warnings --copt=-Wno-tautological-type-limit-compare -build:warnings --copt=-Wno-nullability-completeness build:warnings --copt=-Wno-builtin-macro-redefined build:warnings --copt=-Wno-macro-redefined diff --git a/third_party/xla/workspace0.bzl b/third_party/xla/workspace0.bzl index 76b8ed2bbae1f2..f0b37ee94921f4 100644 --- a/third_party/xla/workspace0.bzl +++ b/third_party/xla/workspace0.bzl @@ -5,6 +5,7 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load("@build_bazel_apple_support//lib:repositories.bzl", "apple_support_dependencies") load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies") load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies") +load("@com_google_benchmark//:bazel/benchmark_deps.bzl", "benchmark_deps") load("@local_tsl//:workspace0.bzl", "tsl_workspace0") def _tf_bind(): @@ -125,6 +126,9 @@ def workspace(): swift_rules_dependencies() apple_support_dependencies() + # We only need `benchmark_deps` to be able to have bazel query to work and not complain about missing `@libpfm`. + benchmark_deps() + # If a target is bound twice, the later one wins, so we have to do tf bindings # at the end of the WORKSPACE file. _tf_bind() diff --git a/third_party/xla/workspace2.bzl b/third_party/xla/workspace2.bzl index e2244c1ae9d216..dea8d378e31806 100644 --- a/third_party/xla/workspace2.bzl +++ b/third_party/xla/workspace2.bzl @@ -16,6 +16,7 @@ load("//third_party/robin_map:workspace.bzl", robin_map = "repo") load("//third_party/shardy:workspace.bzl", shardy = "repo") load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") load("//third_party/triton:workspace.bzl", triton = "repo") +load("//third_party/uv:workspace.bzl", uv = "repo") def _initialize_third_party(): """ Load third party repositories. See above load() statements. """ @@ -27,6 +28,7 @@ def _initialize_third_party(): shardy() stablehlo() triton() + uv() # Define all external repositories required by TensorFlow def _tf_repositories(): diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index b3abffcac9c390..bd161a52b8757e 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -516,6 +516,7 @@ xla_cc_test( ":test", ":util", ":xla_data_proto_cc", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -632,13 +633,13 @@ xla_cc_test( ":types", ":util", ":xla_data_proto_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base", "@com_google_absl//absl/hash", "@com_google_absl//absl/random", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:macros", @@ -692,8 +693,8 @@ xla_cc_test( ":literal_comparison", ":literal_util", ":test_helpers", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test_main", ], @@ -960,7 +961,7 @@ xla_cc_test( ":test_helpers", ":text_literal_writer", ":types", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test_main", ], @@ -1048,7 +1049,9 @@ cc_library( ":array2d", ":array3d", ":array4d", + ":literal", ":literal_util", + ":shape_util", ":util", ":window_util", ":xla_data_proto_cc", @@ -1056,9 +1059,11 @@ cc_library( "//xla/client:xla_builder", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:shape_inference", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", @@ -1072,7 +1077,9 @@ xla_cc_test( ":array2d", ":array3d", ":array4d", + ":error_spec", ":literal", + ":literal_util", ":reference_util", ":test", ":xla_data_proto_cc", @@ -1127,6 +1134,7 @@ cc_library( [ ":parse_flags_from_env", ":xla_proto_cc", + "//xla/stream_executor/cuda:nvjitlink_support", "//xla/stream_executor/cuda:ptx_compiler_support", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/algorithm:container", @@ -1254,10 +1262,10 @@ cc_library( deps = [ ":autotune_results_proto_cc", ":autotuning_proto_cc", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/strings:proto_serialization", ], ) @@ -1312,6 +1320,33 @@ cc_library( visibility = ["//visibility:public"], ) +cc_library( + name = "sort_json", + srcs = ["sort_json.cc"], + hdrs = ["sort_json.h"], + visibility = ["//visibility:public"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "sort_json_test", + srcs = ["sort_json_test.cc"], + deps = [ + ":sort_json", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + # Needed to workaround https://github.com/bazelbuild/bazel/issues/21519 alias( name = "bazel_issue_21519", diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h index 6a6f50574e1d9e..03c5f3b9760c4b 100644 --- a/third_party/xla/xla/array.h +++ b/third_party/xla/xla/array.h @@ -603,12 +603,12 @@ class Array { std::fill(data.get(), data.get() + size, init); } - OwnedBuffer(OwnedBuffer&& other) + OwnedBuffer(OwnedBuffer&& other) noexcept : data(std::move(other.data)), size(other.size) { other.size = 0; } - OwnedBuffer& operator=(OwnedBuffer&& other) { + OwnedBuffer& operator=(OwnedBuffer&& other) noexcept { data = std::move(other.data); size = other.size; other.size = 0; diff --git a/third_party/xla/xla/autotune_result_wrapper.cc b/third_party/xla/xla/autotune_result_wrapper.cc index 855c8aaeb13f5d..ee92f173d6a4d4 100644 --- a/third_party/xla/xla/autotune_result_wrapper.cc +++ b/third_party/xla/xla/autotune_result_wrapper.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/runtime/BUILD b/third_party/xla/xla/backends/cpu/runtime/BUILD similarity index 95% rename from third_party/xla/xla/service/cpu/runtime/BUILD rename to third_party/xla/xla/backends/cpu/runtime/BUILD index 03d54f6adf1305..56ded99d3af407 100644 --- a/third_party/xla/xla/service/cpu/runtime/BUILD +++ b/third_party/xla/xla/backends/cpu/runtime/BUILD @@ -1,7 +1,7 @@ load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") load("//xla:xla.bzl", "xla_cc_test") load("//xla/service/cpu:build_defs.bzl", "runtime_copts") -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_windows", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup") package( @@ -19,13 +19,16 @@ package_group( filegroup( name = "runtime_srcs", - srcs = ["conv_impl.cc"], + srcs = [ + "convolution_thunk_f16.cc", + "convolution_thunk_f32.cc", + ], visibility = internal_visibility([":friends"]), ) filegroup( name = "runtime_hdrs", - srcs = ["conv_impl.h"], + srcs = ["convolution_thunk_internal.h"], visibility = internal_visibility([":friends"]), ) @@ -53,8 +56,8 @@ xla_cc_test( "//xla/service:buffer_assignment", "//xla/service:maybe_owning_device_memory", "//xla/stream_executor:device_memory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -112,8 +115,8 @@ xla_cc_test( "//xla:executable_run_options", "//xla/service/cpu:collectives_interface", "//xla/service/cpu:cpu_executable_run_options", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -124,6 +127,7 @@ cc_library( name = "thunk_executor", srcs = ["thunk_executor.cc"], hdrs = ["thunk_executor.h"], + defines = if_windows(["_ENABLE_EXTENDED_ALIGNED_STORAGE"]), deps = [ ":resource_use", ":thunk", @@ -159,11 +163,12 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor:device_memory", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -229,8 +234,8 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -270,16 +275,18 @@ cc_library( ) cc_library( - name = "conv_impl", - srcs = ["conv_impl.cc"], - hdrs = ["conv_impl.h"], + name = "convolution_thunk_internal", + srcs = [ + "convolution_thunk_f16.cc", + "convolution_thunk_f32.cc", + ], + hdrs = ["convolution_thunk_internal.h"], copts = runtime_copts(), visibility = internal_visibility([":friends"]), deps = [ "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:mutex", # build_cleaner: keep ], ) @@ -289,7 +296,7 @@ cc_library( hdrs = ["convolution_thunk.h"], copts = runtime_copts(), deps = [ - ":conv_impl", + ":convolution_thunk_internal", ":thunk", "//xla:executable_run_options", "//xla:shape_util", @@ -538,7 +545,7 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -555,7 +562,9 @@ cc_library( "//xla:util", "//xla/ffi:attribute_map", "//xla/ffi:call_frame", + "//xla/ffi:execution_state", "//xla/ffi:ffi_api", + "//xla/ffi/api:c_api", "//xla/runtime:buffer_use", "//xla/service:buffer_assignment", "//xla/service:custom_call_status", @@ -566,6 +575,7 @@ cc_library( "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", @@ -663,8 +673,8 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -754,8 +764,8 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -776,7 +786,10 @@ cc_library( "//xla/stream_executor/host:host_kernel", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/numeric:bits", "@com_google_absl//absl/status", @@ -785,7 +798,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", - "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:numbers", @@ -806,9 +818,10 @@ xla_cc_test( "//xla/stream_executor", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -908,9 +921,9 @@ xla_cc_test( "//xla/stream_executor", "//xla/stream_executor/host:host_kernel_c_api", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -923,6 +936,7 @@ cc_library( srcs = ["while_thunk.cc"], hdrs = ["while_thunk.h"], deps = [ + ":buffer_allocations", ":thunk", ":thunk_executor", "//xla/runtime:buffer_use", @@ -933,6 +947,8 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -955,9 +971,9 @@ xla_cc_test( "//xla/service:maybe_owning_device_memory", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@eigen_archive//:eigen3", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/backends/cpu/runtime/README.md b/third_party/xla/xla/backends/cpu/runtime/README.md new file mode 100644 index 00000000000000..84d313e5a2afe4 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/README.md @@ -0,0 +1,16 @@ +# XLA:CPU Runtime + +XLA:CPU runtime is implemented as a collection of `Thunks` that are responsible +for executing individual operations. XLA fusions, for example are jit-compiled +to executables using LLVM, and executed at run time by `KernelThunk`. Operations +that are not compiled have corresponding thunks, i.e., `FFT` operations is +executed as `FftThunk` and relies on DUCC FFT implementation. + +Thunks are executed concurrently using `ThunkExecutor`, which launches thunks +when all data dependencies are ready. We rely on buffer assignment to track read +and write conflicts, and compute a directed acyclic graph that defines execution +order. + +Conceptually, XLA:CPU runtime is similar to XLA:GPU, which also has thunks. +However, for CPU backend we do a lot more multi-threading to be able to +efficiently use all available cores on the host CPU. diff --git a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc similarity index 95% rename from third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc index 3bb705ebf9fcd2..fa55bbc48dbffc 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/all_gather_thunk.h" +#include "xla/backends/cpu/runtime/all_gather_thunk.h" #include #include @@ -24,11 +24,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.h similarity index 85% rename from third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.h index 28ba6c6ace84a1..2d2dca9a7eac9d 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_gather_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/all_gather_thunk.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_ #include #include "absl/status/statusor.h" -#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -40,4 +40,4 @@ class AllGatherThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_ALL_GATHER_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_ALL_GATHER_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc index 923d03ce7fd464..a5d9d283867c2d 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/all_reduce_thunk.h" +#include "xla/backends/cpu/runtime/all_reduce_thunk.h" #include #include @@ -26,12 +26,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.h similarity index 87% rename from third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.h index f4580b0f63be45..77866382353e02 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_reduce_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/all_reduce_thunk.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -45,4 +45,4 @@ class AllReduceThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_ALL_REDUCE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc index d55486602d6546..8badd0c4e7e232 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/all_to_all_thunk.h" +#include "xla/backends/cpu/runtime/all_to_all_thunk.h" #include #include @@ -23,11 +23,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.h similarity index 85% rename from third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.h index 0c24627354829b..b58afe94394572 100644 --- a/third_party/xla/xla/service/cpu/runtime/all_to_all_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/all_to_all_thunk.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ #include #include "absl/status/statusor.h" -#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -40,4 +40,4 @@ class AllToAllThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_ALL_TO_ALL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h similarity index 85% rename from third_party/xla/xla/service/cpu/runtime/buffer_allocations.h rename to third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h index fe26d441359b76..44d71712a9c19c 100644 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations.h +++ b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ -#define XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#define XLA_BACKENDS_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ #include #include @@ -45,7 +45,7 @@ class BufferAllocations { // Same as above, but also adjusts the returned address for the offset and // size contained in the given slice. absl::StatusOr GetDeviceAddress( - const BufferAllocation::Slice& slice) const; + BufferAllocation::Slice slice) const; // Unchecked version of `GetDeviceAddress` that does not check the buffer // index and assumes it is valid. @@ -55,16 +55,19 @@ class BufferAllocations { // Unchecked version of `GetDeviceAddress` that does not check the slice // buffer index, offset and size and assumes they all are valid. se::DeviceMemoryBase GetDeviceAddressUnchecked( - const BufferAllocation::Slice& slice) const; + BufferAllocation::Slice slice) const; private: std::vector buffers_; + se::DeviceMemoryBase* buffers_data_; // buffers_.data() size_t num_buffers_; }; inline BufferAllocations::BufferAllocations( absl::Span buffers) - : buffers_(buffers.size()), num_buffers_(buffers_.size()) { + : buffers_(buffers.size()), + buffers_data_(buffers_.data()), + num_buffers_(buffers_.size()) { for (size_t i = 0; i < buffers.size(); ++i) { buffers_[i] = buffers[i].AsDeviceMemoryBase(); } @@ -82,8 +85,7 @@ BufferAllocations::GetDeviceAddress(BufferAllocation::Index index) const { } inline ABSL_ATTRIBUTE_ALWAYS_INLINE absl::StatusOr -BufferAllocations::GetDeviceAddress( - const BufferAllocation::Slice& slice) const { +BufferAllocations::GetDeviceAddress(BufferAllocation::Slice slice) const { // Handle empty slices explicitly and return a null pointer device memory to // guarantee that we do not accidentally write through the empty slice which // would hide a real bug in the code. @@ -97,7 +99,7 @@ BufferAllocations::GetDeviceAddress( "Invalid buffer index %d. It must be in the range [0, %d)", index, num_buffers_); } - const se::DeviceMemoryBase& base = buffers_[index]; + const se::DeviceMemoryBase& base = buffers_data_[index]; int64_t offset = slice.offset(); int64_t extent = offset + slice.size(); @@ -125,17 +127,18 @@ BufferAllocations::GetDeviceAddress( inline ABSL_ATTRIBUTE_ALWAYS_INLINE se::DeviceMemoryBase BufferAllocations::GetDeviceAddressUnchecked( BufferAllocation::Index buffer_index) const { - return buffers_[buffer_index]; + return buffers_data_[buffer_index]; } // Unchecked version of `GetDeviceAddress` that does not check the slice // buffer index, offset and size and assumes they are valid. inline ABSL_ATTRIBUTE_ALWAYS_INLINE se::DeviceMemoryBase BufferAllocations::GetDeviceAddressUnchecked( - const BufferAllocation::Slice& slice) const { - return buffers_[slice.index()].GetByteSlice(slice.offset(), slice.size()); + BufferAllocation::Slice slice) const { + return buffers_data_[slice.index()].GetByteSlice(slice.offset(), + slice.size()); } } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_BUFFER_ALLOCATIONS_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc rename to third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc index 9fd7d447825de7..c92be6205ac910 100644 --- a/third_party/xla/xla/service/cpu/runtime/buffer_allocations_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/buffer_allocations_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" #include #include diff --git a/third_party/xla/xla/service/cpu/runtime/call_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/call_thunk.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/call_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/call_thunk.cc index a0a4d2bf5c9673..0473ad78e40f49 100644 --- a/third_party/xla/xla/service/cpu/runtime/call_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/call_thunk.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/call_thunk.h" +#include "xla/backends/cpu/runtime/call_thunk.h" #include #include #include "absl/memory/memory.h" #include "absl/status/statusor.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" diff --git a/third_party/xla/xla/service/cpu/runtime/call_thunk.h b/third_party/xla/xla/backends/cpu/runtime/call_thunk.h similarity index 84% rename from third_party/xla/xla/service/cpu/runtime/call_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/call_thunk.h index e6c9ecbd3544e8..b7addf7297c392 100644 --- a/third_party/xla/xla/service/cpu/runtime/call_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/call_thunk.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CALL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_CALL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_CALL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CALL_THUNK_H_ #include #include "absl/status/statusor.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -45,4 +45,4 @@ class CallThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_CALL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CALL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc index 1908c3ff66e40c..a830c0f7fd4ea1 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/collective_permute_thunk.h" +#include "xla/backends/cpu/runtime/collective_permute_thunk.h" #include #include @@ -28,12 +28,12 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.h similarity index 86% rename from third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.h index 6478ced6f1e939..702b2f2b15f3dd 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_permute_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/collective_permute_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ #include #include @@ -23,7 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -51,4 +51,4 @@ class CollectivePermuteThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_PERMUTE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc similarity index 95% rename from third_party/xla/xla/service/cpu/runtime/collective_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc index a0cd9f4936cb33..4bebdd09cd31c1 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include #include @@ -32,13 +32,13 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/computation_placer.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/status_macros.h" @@ -205,10 +205,6 @@ const Shape& CollectiveThunk::source_shape(int64_t index) const { return op_buffers_.source_shapes[index]; } -absl::Span CollectiveThunk::source_shapes() const { - return op_buffers_.source_shapes; -} - const BufferAllocation::Slice& CollectiveThunk::destination_buffer( int64_t index) const { return op_buffers_.destination_buffers[index]; @@ -223,8 +219,4 @@ const Shape& CollectiveThunk::destination_shape(int64_t index) const { return op_buffers_.destination_shapes[index]; } -absl::Span CollectiveThunk::destination_shapes() const { - return op_buffers_.destination_shapes; -} - } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h similarity index 90% rename from third_party/xla/xla/service/cpu/runtime/collective_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/collective_thunk.h index 5bcf16b4e10d5c..8efc767838806d 100644 --- a/third_party/xla/xla/service/cpu/runtime/collective_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/collective_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_THUNK_H_ #include #include @@ -27,11 +27,11 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" #include "xla/shape.h" #include "xla/stream_executor/device_memory.h" @@ -77,7 +77,6 @@ class CollectiveThunk : public Thunk { OpBuffers op_buffers, OpResources op_resources); const OpParams& op_params() const { return op_params_; } - const OpBuffers& op_buffers() const { return op_buffers_; } // Resolves operation's device memory from the buffers and buffer allocations. absl::StatusOr GetOpDeviceMemory(const ExecuteParams& params); @@ -109,13 +108,11 @@ class CollectiveThunk : public Thunk { absl::Span source_buffers() const; const Shape& source_shape(int64_t index) const; - absl::Span source_shapes() const; const BufferAllocation::Slice& destination_buffer(int64_t index) const; absl::Span destination_buffers() const; const Shape& destination_shape(int64_t index) const; - absl::Span destination_shapes() const; private: OpParams op_params_; @@ -125,4 +122,4 @@ class CollectiveThunk : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_COLLECTIVE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_COLLECTIVE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/conditional_thunk.cc index 4ee46a975e6217..42246dd1d3df51 100644 --- a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/conditional_thunk.h" +#include "xla/backends/cpu/runtime/conditional_thunk.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.h b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk.h similarity index 85% rename from third_party/xla/xla/service/cpu/runtime/conditional_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/conditional_thunk.h index 6185b6dad9b27b..0b01d8517a6ff4 100644 --- a/third_party/xla/xla/service/cpu/runtime/conditional_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CONDITIONAL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONDITIONAL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_CONDITIONAL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CONDITIONAL_THUNK_H_ #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -48,4 +48,4 @@ class ConditionalThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_CONDITIONAL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CONDITIONAL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc index d24a58dec3edcc..a5222a8de6bb3d 100644 --- a/third_party/xla/xla/service/cpu/runtime/conditional_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/conditional_thunk_test.cc @@ -13,18 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/conditional_thunk.h" +#include "xla/backends/cpu/runtime/conditional_thunk.h" #include #include #include #include +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_testlib.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk.cc index c7bdd0a2ccf18e..e4dd0ef3f98ce2 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/convolution_thunk.h" +#include "xla/backends/cpu/runtime/convolution_thunk.h" #define EIGEN_USE_THREADS @@ -31,10 +31,10 @@ limitations under the License. #include "absl/types/span.h" #include "Eigen/Core" #include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/executable_run_options.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/conv_impl.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime_conv2d_acl.h" #include "xla/shape.h" #include "xla/status_macros.h" @@ -328,7 +328,7 @@ ConvolutionThunk::HandleEigen2DConvolution(const ExecuteParams& params, std::optional> done_callback = std::nullopt) { using scalar_type = decltype(type_tag); - ::tensorflow::xla::EigenConv2DImpl( + internal::EigenConv2D( eigen_device, static_cast(output.opaque()), static_cast(input.opaque()), static_cast(kernel.opaque()), input_batch_, input_dims_.x, @@ -368,7 +368,7 @@ ConvolutionThunk::HandleEigen3DConvolution(const ExecuteParams& params, std::optional> done_callback = std::nullopt) { using scalar_type = decltype(type_tag); - ::tensorflow::xla::EigenConv3DImpl( + internal::EigenConv3D( eigen_device, static_cast(output.opaque()), static_cast(input.opaque()), static_cast(kernel.opaque()), input_batch_, input_dims_.x, diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.h b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk.h similarity index 94% rename from third_party/xla/xla/service/cpu/runtime/convolution_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk.h index d3ba1173369827..de4f7629ae48dd 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk.h @@ -13,18 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_H_ #include #include #include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" +#include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -123,4 +124,4 @@ class ConvolutionThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_CONVOLUTION_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_H_ diff --git a/third_party/xla/xla/stream_executor/rocm/add_i32_kernel.cu.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f16.cc similarity index 62% rename from third_party/xla/xla/stream_executor/rocm/add_i32_kernel.cu.cc rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f16.cc index 8a6406fe05e5f6..7b6e2ae17d1855 100644 --- a/third_party/xla/xla/stream_executor/rocm/add_i32_kernel.cu.cc +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f16.cc @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" -extern "C" __global__ void add(int32_t* a, int32_t* b, int32_t* c) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - c[index] = a[index] + b[index]; -} +CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); +CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); + +CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); +CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); diff --git a/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f32.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f32.cc new file mode 100644 index 00000000000000..b93314b8474444 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_f32.cc @@ -0,0 +1,27 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" +#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" // IWYU pragma: keep + +#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) +#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" // IWYU pragma: keep +#endif + +CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); +CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); + +CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); +CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); diff --git a/third_party/xla/xla/service/cpu/runtime/conv_impl.h b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h similarity index 62% rename from third_party/xla/xla/service/cpu/runtime/conv_impl.h rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h index c6b9747bc0ed51..3275f9d8fa8455 100644 --- a/third_party/xla/xla/service/cpu/runtime/conv_impl.h +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_internal.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,36 +12,38 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CONV_IMPL_H_ -#define XLA_SERVICE_CPU_RUNTIME_CONV_IMPL_H_ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_INTERNAL_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_INTERNAL_H_ + +#define EIGEN_USE_THREADS #include #include +#include "Eigen/Core" #include "unsupported/Eigen/CXX11/Tensor" -#include "xla/tsl/framework/convolution/eigen_spatial_convolutions.h" -#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) -#include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" -#endif +namespace xla::cpu::internal { -// 'tensorflow' namespace is used so that types don't require qualification. -namespace tensorflow { -namespace xla { +// TODO(ezhulenev): Make internal implementation a private static method of +// ConvolutionThunk (for consistency with DotThunk). Today we keep it as a free +// function to use it in the legacy XLA CPU runtime. template -void EigenConv2DImpl( - const EigenDevice& device, ScalarType* out, ScalarType* lhs, - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, - Eigen::Index input_y, Eigen::Index input_channels, Eigen::Index kernel_x, - Eigen::Index kernel_y, Eigen::Index kernel_channels, - Eigen::Index kernel_filters, Eigen::Index output_x, Eigen::Index output_y, - Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index padding_x_before, - Eigen::Index padding_x_after, Eigen::Index padding_y_before, - Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, - Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, - Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, - std::optional> done_callback = std::nullopt) { +void EigenConv2D(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, Eigen::Index input_batch, + Eigen::Index input_x, Eigen::Index input_y, + Eigen::Index input_channels, Eigen::Index kernel_x, + Eigen::Index kernel_y, Eigen::Index kernel_channels, + Eigen::Index kernel_filters, Eigen::Index output_x, + Eigen::Index output_y, Eigen::Index x_stride, + Eigen::Index y_stride, Eigen::Index padding_x_before, + Eigen::Index padding_x_after, Eigen::Index padding_y_before, + Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, + Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, + Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, + std::optional> done_callback) { const Eigen::TensorMap, Eigen::Aligned> input(lhs, input_batch, input_x, input_y, input_channels); @@ -114,22 +116,23 @@ void EigenConv2DImpl( } template -void EigenConv3DImpl( - const EigenDevice& device, ScalarType* out, ScalarType* lhs, - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, - Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, - Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, - Eigen::Index kernel_channels, Eigen::Index kernel_filters, - Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, - Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, - Eigen::Index padding_x_before, Eigen::Index padding_x_after, - Eigen::Index padding_y_before, Eigen::Index padding_y_after, - Eigen::Index padding_z_before, Eigen::Index padding_z_after, - Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, - Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, - Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, - Eigen::Index feature_group_count, - std::optional> done_callback = std::nullopt) { +void EigenConv3D(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, Eigen::Index input_batch, + Eigen::Index input_x, Eigen::Index input_y, + Eigen::Index input_z, Eigen::Index input_channels, + Eigen::Index kernel_x, Eigen::Index kernel_y, + Eigen::Index kernel_z, Eigen::Index kernel_channels, + Eigen::Index kernel_filters, Eigen::Index output_x, + Eigen::Index output_y, Eigen::Index output_z, + Eigen::Index x_stride, Eigen::Index y_stride, + Eigen::Index z_stride, Eigen::Index padding_x_before, + Eigen::Index padding_x_after, Eigen::Index padding_y_before, + Eigen::Index padding_y_after, Eigen::Index padding_z_before, + Eigen::Index padding_z_after, Eigen::Index lhs_x_dilation, + Eigen::Index lhs_y_dilation, Eigen::Index lhs_z_dilation, + Eigen::Index rhs_x_dilation, Eigen::Index rhs_y_dilation, + Eigen::Index rhs_z_dilation, Eigen::Index feature_group_count, + std::optional> done_callback) { using ConstTType = Eigen::TensorMap, Eigen::Aligned>; @@ -210,10 +213,10 @@ void EigenConv3DImpl( } // Extern Conv2D template for all supported devices and data types. -#define CONV2D_EXTERN_TEMPLATE(EigenDevice, ScalarType) \ - extern template void EigenConv2DImpl( \ - const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ +#define CONV2D_EXTERN_TEMPLATE(DEVICE, SCALAR_TYPE) \ + extern template void EigenConv2D( \ + const DEVICE& device, SCALAR_TYPE* out, SCALAR_TYPE* lhs, \ + SCALAR_TYPE* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ Eigen::Index input_y, Eigen::Index input_channels, \ Eigen::Index kernel_x, Eigen::Index kernel_y, \ Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ @@ -223,7 +226,7 @@ void EigenConv3DImpl( Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \ Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \ Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \ - std::optional> done_callback = std::nullopt) + std::optional> done_callback) CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half); CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float); @@ -233,10 +236,10 @@ CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float); #undef CONV2D_EXTERN_TEMPLATE // Extern Conv3D template for all supported devices and data types. -#define CONV3D_EXTERN_TEMPLATE(EigenDevice, ScalarType) \ - extern template void EigenConv3DImpl( \ - const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ +#define CONV3D_EXTERN_TEMPLATE(DEVICE, SCALAR_TYPE) \ + extern template void EigenConv3D( \ + const DEVICE& device, SCALAR_TYPE* out, SCALAR_TYPE* lhs, \ + SCALAR_TYPE* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \ Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \ Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ @@ -249,7 +252,7 @@ CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float); Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \ Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \ Eigen::Index feature_group_count, \ - std::optional> done_callback = std::nullopt) + std::optional> done_callback) CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half); CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float); @@ -258,7 +261,39 @@ CONV3D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float); #undef CONV3D_EXTERN_TEMPLATE -} // namespace xla -} // namespace tensorflow +} // namespace xla::cpu::internal + +#define CONV2D_INSTANTIATE_TEMPLATE(DEVICE, SCALAR_TYPE) \ + template void xla::cpu::internal::EigenConv2D( \ + const DEVICE& device, SCALAR_TYPE* out, SCALAR_TYPE* lhs, \ + SCALAR_TYPE* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ + Eigen::Index input_y, Eigen::Index input_channels, \ + Eigen::Index kernel_x, Eigen::Index kernel_y, \ + Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ + Eigen::Index output_x, Eigen::Index output_y, Eigen::Index x_stride, \ + Eigen::Index y_stride, Eigen::Index padding_x_before, \ + Eigen::Index padding_x_after, Eigen::Index padding_y_before, \ + Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \ + Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \ + Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \ + std::optional> done_callback) + +#define CONV3D_INSTANTIATE_TEMPLATE(DEVICE, SCALAR_TYPE) \ + template void xla::cpu::internal::EigenConv3D( \ + const DEVICE& device, SCALAR_TYPE* out, SCALAR_TYPE* lhs, \ + SCALAR_TYPE* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ + Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \ + Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \ + Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ + Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, \ + Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, \ + Eigen::Index padding_x_before, Eigen::Index padding_x_after, \ + Eigen::Index padding_y_before, Eigen::Index padding_y_after, \ + Eigen::Index padding_z_before, Eigen::Index padding_z_after, \ + Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, \ + Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \ + Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \ + Eigen::Index feature_group_count, \ + std::optional> done_callback) -#endif // XLA_SERVICE_CPU_RUNTIME_CONV_IMPL_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CONVOLUTION_THUNK_INTERNAL_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc index 3671431333d595..20a75d1f97ebcc 100644 --- a/third_party/xla/xla/service/cpu/runtime/convolution_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/convolution_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/convolution_thunk.h" +#include "xla/backends/cpu/runtime/convolution_thunk.h" #include #include @@ -25,10 +25,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/status/status.h" #include "Eigen/Core" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/copy_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/copy_thunk.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/copy_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/copy_thunk.cc index 1ea16dbdbf4d53..67b4d557256950 100644 --- a/third_party/xla/xla/service/cpu/runtime/copy_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/copy_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/copy_thunk.h" +#include "xla/backends/cpu/runtime/copy_thunk.h" #define EIGEN_USE_THREADS @@ -34,10 +34,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/pjrt/transpose.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/service/cpu/runtime/copy_thunk.h b/third_party/xla/xla/backends/cpu/runtime/copy_thunk.h similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/copy_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/copy_thunk.h index a65425c7f5427d..ed2cd68df5137a 100644 --- a/third_party/xla/xla/service/cpu/runtime/copy_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/copy_thunk.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_COPY_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_COPY_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_COPY_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_COPY_THUNK_H_ #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/pjrt/transpose.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -69,4 +69,4 @@ class CopyThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_COPY_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_COPY_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/copy_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/copy_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc index 406d6b1a8aa7dc..8a8e4fb4debd27 100644 --- a/third_party/xla/xla/service/cpu/runtime/copy_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/copy_thunk_test.cc @@ -13,15 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/copy_thunk.h" +#include "xla/backends/cpu/runtime/copy_thunk.h" #include #include +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc similarity index 86% rename from third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc index 1161673db8764b..8c774ba7759c35 100644 --- a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/custom_call_thunk.h" +#include "xla/backends/cpu/runtime/custom_call_thunk.h" #include #include @@ -26,6 +26,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -35,18 +36,21 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Support/LLVM.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/ffi/api/c_api.h" #include "xla/ffi/attribute_map.h" #include "xla/ffi/call_frame.h" +#include "xla/ffi/execution_state.h" #include "xla/ffi/ffi_api.h" #include "xla/primitive_util.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" #include "xla/service/custom_call_target_registry.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/lib/traceme.h" @@ -111,6 +115,36 @@ absl::StatusOr BuildCallFrameForTypedFFI( return builder.Build(); } +absl::Status InstantiateHandlerState(absl::string_view target_name, + ffi::ExecutionState* execution_state) { + // Find the registered FFI handler for this target. + auto handler = ffi::FindHandler(target_name, "Host"); + if (!handler.ok()) { + return NotFound( + "No registered implementation for FFI custom call to %s for Host", + target_name); + } + + // Initialize FFI handler state if it has an instantiate callback. + if (handler->bundle.instantiate) { + // At FFI handler instantiation time, we don't have any arguments or + // results or access to the underlying device (stream, etc.) + ffi::CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + + // TODO(abanas): Add attributes support. All attributes should be accessible + // at all phases, namely instantiation and execution. Also add tests for CPU + // and GPU backends (GPU supports it, but tests are missing there). + ffi::CallFrame instantiate_call_frame = builder.Build(); + + ffi::CallOptions options; + options.execution_state = execution_state; + TF_RETURN_IF_ERROR(Call(handler->bundle.instantiate, instantiate_call_frame, + options, XLA_FFI_ExecutionStage_INSTANTIATE)); + } + + return absl::OkStatus(); +} + } // namespace absl::StatusOr> CustomCallThunk::Create( @@ -121,7 +155,13 @@ absl::StatusOr> CustomCallThunk::Create( TF_ASSIGN_OR_RETURN( call_frame, BuildCallFrameForTypedFFI(api_version, op_buffers, backend_config)); + + // TODO(abanas): Pass execution state to thunk. + auto execution_state = std::make_unique(); + TF_RETURN_IF_ERROR( + InstantiateHandlerState(target_name, execution_state.get())); } + return absl::WrapUnique(new CustomCallThunk( std::move(info), target_name, std::move(op_buffers), api_version, std::move(backend_config), std::move(call_frame))); @@ -196,11 +236,11 @@ tsl::AsyncValueRef CustomCallThunk::CallTypedFFI( // Forward ExecutableRunOptions to the FFI handlers via the call options. CustomCallExecuteParams* custom_call_params = params.custom_call_params; - ffi::CallOptions call_options = {custom_call_params->device_ordinal, - custom_call_params->stream, - custom_call_params->allocator, - /*called_computation=*/nullptr, - custom_call_params->ffi_execution_context}; + ffi::CallOptions call_options = { + custom_call_params->device_ordinal, + ffi::CallOptions::CpuOptions{custom_call_params->intra_op_thread_pool}, + /*called_computation=*/nullptr, + custom_call_params->ffi_execution_context}; // Call the function and check execution status. auto status = ffi::Call(handler->bundle.execute, call_frame, call_options); diff --git a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h index 901545fa9f5d1f..bfea5368f7cb9b 100644 --- a/third_party/xla/xla/service/cpu/runtime/custom_call_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/custom_call_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ #include #include @@ -25,9 +25,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/ffi/call_frame.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/custom_call_status.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -80,4 +80,4 @@ class CustomCallThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_CUSTOM_CALL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc index c92307c52f064c..418ed65ce1cbb2 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" #include #include @@ -30,10 +30,10 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk.h b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.h similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk.h index acaa94d5bf7779..61bcb8194e1150 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_DOT_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_DOT_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_DOT_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_DOT_THUNK_H_ #define EIGEN_USE_THREADS @@ -29,9 +29,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "Eigen/Core" #include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -175,4 +175,4 @@ DOT_THUNK_EXTERN_MATMUL_TEMPLATE(std::complex); } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_DOT_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_DOT_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_c128.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_c128.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_c128.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_c128.cc index cd2852e26aa980..1c791bd6fac78c 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_c128.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_c128.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul>( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_c64.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_c64.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_c64.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_c64.cc index 55f21cceb344bf..957e2d6d855630 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_c64.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_c64.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul>( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f16.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f16.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_f16.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_f16.cc index df04b0d1272a1a..35d85c89154187 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f16.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f16.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f32.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f32.cc similarity index 93% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_f32.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_f32.cc index d98c5d940ed3b1..f3aee5501ac413 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f32.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f32.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) #include "xla/tsl/framework/contraction/eigen_contraction_kernel.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f64.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f64.cc similarity index 91% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_f64.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_f64.cc index f782cc7045ff7e..bcb8bd676af8db 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_f64.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_f64.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/dot_thunk_s32.cc b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_s32.cc similarity index 91% rename from third_party/xla/xla/service/cpu/runtime/dot_thunk_s32.cc rename to third_party/xla/xla/backends/cpu/runtime/dot_thunk_s32.cc index 59186ec8a5669a..0851e01b539c0a 100644 --- a/third_party/xla/xla/service/cpu/runtime/dot_thunk_s32.cc +++ b/third_party/xla/xla/backends/cpu/runtime/dot_thunk_s32.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" // NOLINT IWYU pragma: keep template void ::xla::cpu::DotThunk::TypedMatMul( const Eigen::ThreadPoolDevice* device, void* out, void* lhs, void* rhs, diff --git a/third_party/xla/xla/service/cpu/runtime/fft_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/fft_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc index 5d792c2fc8c163..b7c898b26d177c 100644 --- a/third_party/xla/xla/service/cpu/runtime/fft_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/fft_thunk.h" +#include "xla/backends/cpu/runtime/fft_thunk.h" #include #include @@ -21,10 +21,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime_fft.h" #include "xla/service/cpu/runtime_single_threaded_fft.h" #include "xla/shape.h" diff --git a/third_party/xla/xla/service/cpu/runtime/fft_thunk.h b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.h similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/fft_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/fft_thunk.h index b63ed5e9b744e7..64d4063d828cf7 100644 --- a/third_party/xla/xla/service/cpu/runtime/fft_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/fft_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_FFT_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_FFT_THUNK_H_ #include #include @@ -22,8 +22,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -68,4 +68,4 @@ class FftThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_FFT_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_FFT_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/infeed_thunk.cc index 9e8acff4ecb271..e1a601565c69d3 100644 --- a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/infeed_thunk.h" +#include "xla/backends/cpu/runtime/infeed_thunk.h" #include #include @@ -24,10 +24,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.h b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk.h similarity index 87% rename from third_party/xla/xla/service/cpu/runtime/infeed_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/infeed_thunk.h index 622046f2e3785d..1d4225d1ddd008 100644 --- a/third_party/xla/xla/service/cpu/runtime/infeed_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_INFEED_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_INFEED_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_INFEED_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_INFEED_THUNK_H_ #include #include #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -61,4 +61,4 @@ class InfeedThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_INFEED_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_INFEED_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk_test.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/infeed_thunk_test.cc index 53394e242c56a0..3bbb4272f22834 100644 --- a/third_party/xla/xla/service/cpu/runtime/infeed_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/infeed_thunk_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/infeed_thunk.h" +#include "xla/backends/cpu/runtime/infeed_thunk.h" #include +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc new file mode 100644 index 00000000000000..4656bf8ef73a39 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.cc @@ -0,0 +1,412 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/kernel_thunk.h" + +#define EIGEN_USE_THREADS + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/base/optimization.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/runtime/buffer_use.h" +#include "xla/service/buffer_assignment.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/host/host_kernel.h" +#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" + +namespace xla::cpu { +namespace internal { + +// Checks that all buffers are aligned to the minimum alignment. We codegen +// with the assumption that all buffers are aligned, and if they are not, we +// will crash with a segmentation fault, or worse, produce incorrect results. +static absl::Status CheckBufferAlignment( + const Thunk::Info& info, uint64_t min_alignment, + absl::Span kernel_args) { + if (min_alignment == 0) return absl::OkStatus(); + + for (int64_t i = 0; i < kernel_args.size(); ++i) { + auto ptr = reinterpret_cast(kernel_args[i].data); + if (ABSL_PREDICT_FALSE((ptr & (min_alignment - 1)) != 0)) { + return Internal( + "Host kernel %s buffer argument #%d (%p) is not aligned to a " + "required minimum alignment of %d bytes", + info.op_name, i, kernel_args[i].data, min_alignment); + } + } + + return absl::OkStatus(); +} + +// VLOGs kernel arguments resolved from the buffer allocations. +static void VlogKernelArgs( + absl::Span arguments_buffers, + absl::Span results_buffers, + absl::Span kernel_args) { + for (int64_t i = 0; i < arguments_buffers.size(); ++i) { + VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", i, + arguments_buffers[i].ToString(), + kernel_args[i].data); + } + for (int64_t i = 0; i < results_buffers.size(); ++i) { + VLOG(3) << absl::StreamFormat( + " res #%d: %s (%p)", i, results_buffers[i].ToString(), + kernel_args[arguments_buffers.size() + i].data); + } +} + +// Returns kernel buffer uses for a given arguments and results buffers. +static Thunk::BufferUses KernelBufferUses( + absl::Span arguments_buffers, + absl::Span results_buffers) { + Thunk::BufferUses buffer_uses; + for (const BufferAllocation::Slice& buffer : arguments_buffers) { + buffer_uses.emplace_back(buffer, BufferUse::kRead); + } + for (const BufferAllocation::Slice& buffer : results_buffers) { + buffer_uses.emplace_back(buffer, BufferUse::kWrite); + } + return buffer_uses; +} + +template +KernelThunk::KernelThunk( + Info info, absl::Span arguments_buffers, + absl::Span results_buffers, + absl::flat_hash_set invariant_buffers, + std::string kernel_name, se::ThreadDim thread_dim, + std::optional min_alignment) + : Thunk(Kind::kKernel, std::move(info)), + invariant_buffers_(std::move(invariant_buffers)), + num_kernel_args_(arguments_buffers.size() + results_buffers.size()), + kernel_name_(std::move(kernel_name)), + thread_dim_(thread_dim), + min_alignment_(min_alignment), + call_once_(thread_dim_ == se::ThreadDim()), + kernel_ptr_(nullptr) { + // Resize storage for arguments and results buffers if it is dynamic. + if constexpr (IsDynamic(num_arguments)) { + arguments_buffers_.resize(arguments_buffers.size()); + } + if constexpr (IsDynamic(num_results)) { + results_buffers_.resize(results_buffers.size()); + } + + // Copy buffers from the arguments and results. + for (size_t i = 0; i < arguments_buffers.size(); ++i) { + arguments_buffers_[i] = arguments_buffers[i]; + } + for (size_t i = 0; i < results_buffers.size(); ++i) { + results_buffers_[i] = results_buffers[i]; + } + + // Resize storage for kernel arguments if it is dynamic. + if constexpr (IsDynamic(num_arguments) || IsDynamic(num_results)) { + kernel_args_.resize(num_kernel_args_); + } + + // Initialize kernel arguments with null pointers and known buffer sizes. + // We'll use them as a template to resolve buffer addresses at run time. + for (size_t i = 0; i < arguments_buffers.size(); ++i) { + kernel_args_[i] = SE_HOST_KernelArg{ + nullptr, static_cast(arguments_buffers_[i].size())}; + } + for (size_t i = 0; i < results_buffers.size(); ++i) { + kernel_args_[arguments_buffers_.size() + i] = SE_HOST_KernelArg{ + nullptr, static_cast(results_buffers_[i].size())}; + } +} + +template +ABSL_ATTRIBUTE_ALWAYS_INLINE tsl::AsyncValueRef +KernelThunk::ExecuteInternal( + const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + VLOG(3) << absl::StreamFormat( + "Launch host kernel %s with %d arguments buffers and %d results buffers: " + "#threads=%s", + kernel_name_, arguments_buffers_.size(), results_buffers_.size(), + thread_dim_.ToString()); + + KernelArgs kernel_args = kernel_args_; + SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data(); + + const BufferAllocations* allocations = params.buffer_allocations; + + for (BufferAllocation::Slice& buffer : arguments_buffers_) { + if constexpr (ShouldCheckBufferSlices()) { + TF_ASSIGN_OR_RETURN(auto mem, allocations->GetDeviceAddress(buffer)); + kernel_args_ptr++->data = mem.opaque(); + } else { + auto mem = allocations->GetDeviceAddressUnchecked(buffer); + kernel_args_ptr++->data = mem.opaque(); + } + } + + for (BufferAllocation::Slice& buffer : results_buffers_) { + if constexpr (ShouldCheckBufferSlices()) { + TF_ASSIGN_OR_RETURN(auto mem, allocations->GetDeviceAddress(buffer)); + kernel_args_ptr++->data = mem.opaque(); + } else { + auto mem = allocations->GetDeviceAddressUnchecked(buffer); + kernel_args_ptr++->data = mem.opaque(); + } + } + + if (ABSL_PREDICT_FALSE(VLOG_IS_ON(3))) { + VlogKernelArgs(arguments_buffers_, results_buffers_, kernel_args); + } + + // Сheck that all resolved buffers are properly aligned, and that invariant + // property holds. + if constexpr (ShouldCheckBufferSlices()) { + TF_RETURN_IF_ERROR( + CheckBufferAlignment(info(), min_alignment_.value_or(0), kernel_args)); + TF_RETURN_IF_ERROR(CheckInvariantBufferSlices()); + TF_RETURN_IF_ERROR(CheckInvariantBuffersMemory(*allocations)); + } + + // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk + // initialization stage. + se::host::HostKernel* kernel = kernel_ptr_.load(std::memory_order_acquire); + + // Because thunks are owned by a parent CpuExecutable, we can safely assume + // that kernel pointer will not change after we find it the first time. + if (ABSL_PREDICT_FALSE(kernel == nullptr)) { + TF_ASSIGN_OR_RETURN(SE_HOST_Kernel * kernel_fn, + params.function_registry->FindKernel(kernel_name_)); + + absl::MutexLock lock(&mutex_); + if ((kernel = kernel_ptr_.load(std::memory_order_relaxed)) == nullptr) { + kernel = &kernel_.emplace(num_kernel_args_, kernel_fn, nullptr); + kernel_ptr_.store(kernel, std::memory_order_release); + } + } + + // Use a fast path if kernel called just once. + if (ABSL_PREDICT_TRUE(call_once_)) { + TF_RETURN_IF_ERROR(kernel->CallOnce(kernel_args)); + return OkExecuteEvent(); + } + + // If intra-op thread pool is not nullptr, we launch HostKernel in async mode + // by scheduling tasks into it. HostKernel launch completion will + // automatically signal KernelThunk execute completion. + if (ABSL_PREDICT_TRUE(params.intra_op_threadpool)) { + return kernel->Launch( + thread_dim_, kernel_args, [¶ms](se::host::HostKernel::Task task) { + params.intra_op_threadpool->getPool()->Schedule(std::move(task)); + }); + } + + TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); + return OkExecuteEvent(); +} + +template +absl::Status +KernelThunk::CheckInvariantBufferSlices() const { + // We can use absl::c_contains here when we have C++20 support. + // TODO(abanas): Check for overlapping buffers. + auto contains = [](const auto& container, + const BufferAllocation::Slice& buffer) { + return absl::c_find(container, buffer) != container.end(); + }; + + // Verify all argument buffers. + for (const BufferAllocation::Slice& buffer : arguments_buffers_) { + if (invariant_buffers_.contains(buffer)) { + // This argument should be read only, i.e. not one of the results. + if (contains(results_buffers_, buffer)) { + return Internal( + "Mismatch in invariant buffers metadata, invariant buffer %s " + "should not be one of the results", + buffer.ToString()); + } + } else { + // For completeness, we check that a read write buffer is one of the + // results. + if (!contains(results_buffers_, buffer)) { + return Internal( + "Mismatch in invariant buffers metadata, read-write buffer %s " + "is not one of the results", + buffer.ToString()); + } + } + } + + // Verify that there are no extra buffers in invariant buffers set. + for (auto& buffer : invariant_buffers_) { + if (!contains(arguments_buffers_, buffer)) { + return Internal( + "Mismatch in invariant buffers metadata, unknown buffer found: %s", + buffer.ToString()); + } + } + return absl::OkStatus(); +} + +// TODO(abanas): Return absl::flat_hash_set. This requires implementing a hash +// function for DeviceMemoryBase. +template +static absl::StatusOr> ToDeviceMemorySet( + const Iterable& buffers, const BufferAllocations& allocations) { + std::vector result; + for (const BufferAllocation::Slice& slice : buffers) { + TF_ASSIGN_OR_RETURN(auto memory, allocations.GetDeviceAddress(slice)); + result.push_back(std::move(memory)); + } + return result; +} + +// The logic here is similar to CheckInvariantBufferSlices, but we check +// memory addresses instead of buffer slices. +template +absl::Status +KernelThunk::CheckInvariantBuffersMemory( + const BufferAllocations& allocations) const { + // We can use absl::c_contains here when we have C++20 support. + auto contains = [](const std::vector& container, + const se::DeviceMemoryBase& memory) { + return absl::c_find(container, memory) != container.end(); + }; + + TF_ASSIGN_OR_RETURN(auto results_memory_set, + ToDeviceMemorySet(results_buffers_, allocations)); + TF_ASSIGN_OR_RETURN(auto invariant_memory_set, + ToDeviceMemorySet(invariant_buffers_, allocations)); + + // Verify all argument buffers. + for (const BufferAllocation::Slice& argument_slice : arguments_buffers_) { + TF_ASSIGN_OR_RETURN(auto argument_memory, + allocations.GetDeviceAddress(argument_slice)); + if (contains(invariant_memory_set, argument_memory)) { + // This argument should be read only, i.e. not one of the results. + if (contains(results_memory_set, argument_memory)) { + return Internal( + "Mismatch in invariant buffers metadata, device memory of " + "invariant buffer %s should not be one of the results", + argument_slice.ToString()); + } + } else { + // For completeness, we check that a read write buffer is one of the + // results. + if (!contains(results_memory_set, argument_memory)) { + return Internal( + "Mismatch in invariant buffers metadata, device memory of " + "read-write buffer %s is not one of the results", + argument_slice.ToString()); + } + } + } + + return absl::OkStatus(); +} + +template +Thunk::BufferUses KernelThunk::buffer_uses() const { + return KernelBufferUses(arguments_buffers_, results_buffers_); +} + +} // namespace internal + +tsl::AsyncValueRef KernelThunk::Execute( + const Thunk::ExecuteParams& params) { + return Base::ExecuteInternal(params); +} + +template +tsl::AsyncValueRef +SmallKernelThunk::Execute( + const Thunk::ExecuteParams& params) { + return Base::ExecuteInternal(params); +} + +absl::StatusOr> KernelThunk::Create( + Thunk::Info info, + absl::Span arguments_buffers, + absl::Span results_buffers, + std::string kernel_name, se::ThreadDim thread_dim, + absl::flat_hash_set invariant_buffers, + std::optional min_alignment) { + if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) { + return Internal("Host kernel %s minimum alignment %d is not a power of 2", + info.op_name, *min_alignment); + } + + auto small_kernel_thunk = [&](auto num_arguments, auto num_results) { + return absl::WrapUnique( + new SmallKernelThunk( + std::move(info), arguments_buffers, results_buffers, + std::move(invariant_buffers), std::move(kernel_name), thread_dim, + min_alignment)); + }; + + static constexpr auto _0 = std::integral_constant{}; + static constexpr auto _1 = std::integral_constant{}; + static constexpr auto _2 = std::integral_constant{}; + static constexpr auto _3 = std::integral_constant{}; + static constexpr auto _4 = std::integral_constant{}; + static constexpr auto _5 = std::integral_constant{}; + static constexpr auto _6 = std::integral_constant{}; + + std::pair params(arguments_buffers.size(), + results_buffers.size()); + + // Return SmallKernelThunk specializations for the most common cases. + if (params == std::make_pair(_0(), _1())) return small_kernel_thunk(_0, _1); + if (params == std::make_pair(_1(), _1())) return small_kernel_thunk(_1, _1); + if (params == std::make_pair(_2(), _1())) return small_kernel_thunk(_2, _1); + if (params == std::make_pair(_3(), _1())) return small_kernel_thunk(_3, _1); + if (params == std::make_pair(_4(), _1())) return small_kernel_thunk(_4, _1); + if (params == std::make_pair(_5(), _1())) return small_kernel_thunk(_5, _1); + if (params == std::make_pair(_6(), _1())) return small_kernel_thunk(_6, _1); + + // Return a generic KernelThunk for dynamic numbers of arguments and results. + return absl::WrapUnique( + new KernelThunk(std::move(info), arguments_buffers, results_buffers, + std::move(invariant_buffers), std::move(kernel_name), + thread_dim, min_alignment)); +} + +} // namespace xla::cpu diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h new file mode 100644 index 00000000000000..fd0567ae1e62e9 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk.h @@ -0,0 +1,171 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_BACKENDS_CPU_RUNTIME_KERNEL_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_KERNEL_THUNK_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/service/buffer_assignment.h" +#include "xla/stream_executor/host/host_kernel.h" +#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/concurrency/async_value_ref.h" + +namespace xla::cpu { + +// Forward declare thunk defined below. +class KernelThunk; + +namespace internal { + +// If the number of kernel parameters (arguments and results) is unknown at +// compile time, we use this value to indicate that the parameter is dynamic. +inline constexpr int64_t kDynamicKernelParameter = -1; + +// A base template for a KernelThunk that can be specialized for a statically +// known number of arguments and results. We go extra mile here to optimize +// host kernel dispatching on the hot execution path to minimize the XLA runtime +// overheads for the smallest HLO modules. +template +class KernelThunk : public Thunk { + public: + BufferUses buffer_uses() const final; + + protected: + tsl::AsyncValueRef ExecuteInternal(const ExecuteParams& params); + + private: + friend class ::xla::cpu::KernelThunk; + + static constexpr bool IsDynamic(size_t n) { + return n == kDynamicKernelParameter; + } + + static constexpr size_t Size(int64_t size) { + return std::max(size, 0); + } + + // If we know the number of arguments and results at compile time, we use + // std::array with a fixed size, which allows compiler to automatically unroll + // all the loops on a hot path. + + using ArgumentsBuffers = std::conditional_t< + IsDynamic(num_arguments), std::vector, + std::array>; + + using ResultsBuffers = std::conditional_t< + IsDynamic(num_results), std::vector, + std::array>; + + using KernelArgs = std::conditional_t< + IsDynamic(num_arguments) || IsDynamic(num_results), + absl::InlinedVector, + std::array>; + + KernelThunk(Info info, + absl::Span arguments_buffers, + absl::Span results_buffers, + absl::flat_hash_set invariant_buffers, + std::string kernel_name, se::ThreadDim thread_dim, + std::optional min_alignment); + + absl::Status CheckInvariantBufferSlices() const; + + absl::Status CheckInvariantBuffersMemory( + const BufferAllocations& buffer_allocations) const; + + ArgumentsBuffers arguments_buffers_; + ResultsBuffers results_buffers_; + + absl::flat_hash_set invariant_buffers_; + + size_t num_kernel_args_; + + std::string kernel_name_; + se::ThreadDim thread_dim_; + std::optional min_alignment_; + + // If `true`, host kernel will be called just once for a logical thread dim + // (1,1,1). This is a fast path for small host kernels that have just one + // logical thread dim. + bool call_once_; + + // Lazily loaded host kernel corresponding to `kernel_name_`. + absl::Mutex mutex_; + std::optional kernel_ ABSL_GUARDED_BY(mutex_); + std::atomic kernel_ptr_; // pointer to `kernel_` + + // Pre-initialized kernel arguments that are updated with memory addresses + // before the kernel launch. + KernelArgs kernel_args_; +}; + +} // namespace internal + +// Kernel thunk specialization for a small kernel with a statically known number +// of arguments and results. +template +class SmallKernelThunk final + : public internal::KernelThunk { + using Base = internal::KernelThunk; + + public: + using Base::Base; + + tsl::AsyncValueRef Execute( + const Thunk::ExecuteParams& params) final; +}; + +// Kernel thunk specialization for dynamic number of arguments and results. +class KernelThunk final : public internal::KernelThunk<> { + using Base = internal::KernelThunk<>; + + public: + using Base::Base; + + static absl::StatusOr> Create( + Thunk::Info info, + absl::Span arguments_buffers, + absl::Span results_buffers, + std::string kernel_name, se::ThreadDim thread_dim, + absl::flat_hash_set invariant_buffers, + std::optional min_alignment = std::nullopt); + + tsl::AsyncValueRef Execute( + const Thunk::ExecuteParams& params) final; +}; + +} // namespace xla::cpu + +#endif // XLA_BACKENDS_CPU_RUNTIME_KERNEL_THUNK_H_ diff --git a/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc new file mode 100644 index 00000000000000..1599694f8c7896 --- /dev/null +++ b/third_party/xla/xla/backends/cpu/runtime/kernel_thunk_test.cc @@ -0,0 +1,294 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/backends/cpu/runtime/kernel_thunk.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/maybe_owning_device_memory.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/host/host_kernel_c_api.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +class AddF32HostKernel : public Thunk::FunctionRegistry { + public: + absl::StatusOr FindKernel(std::string_view name) override { + return +[](const SE_HOST_KernelCallFrame* call_frame) { + const SE_HOST_KernelArg& in = call_frame->args[0]; + const SE_HOST_KernelArg& out = call_frame->args[1]; + + float* in_ptr = reinterpret_cast(in.data); + float* out_ptr = reinterpret_cast(out.data); + + uint64_t i = call_frame->thread->x; + *(out_ptr + i) = *(in_ptr + i) + *(in_ptr + i); + + return static_cast(nullptr); + }; + } +}; + +TEST(KernelThunkTest, CheckAlignment) { + auto thunk = + KernelThunk::Create({"test"}, {}, {}, "test", se::ThreadDim(), {}, + /*min_alignment=*/3); + EXPECT_TRUE(absl::StrContains(thunk.status().message(), + "minimum alignment 3 is not a power of 2")); +} + +TEST(KernelThunkTest, AddF32) { + std::vector buffers; + std::vector in = {1.0, 2.0, 3.0, 4.0}; + std::vector out(4, 0.0); + + size_t size_in_bytes = in.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation in_alloc(0, size_in_bytes, 0); + BufferAllocation out_alloc(1, size_in_bytes, 0); + + BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); + BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, + "add_f32", se::ThreadDim(4), {in_slice})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()) << execute_event.GetError(); + + std::vector expected = {2.0, 4.0, 6.0, 8.0}; + EXPECT_EQ(out, expected); +} + +TEST(KernelThunkTest, AddF32Inline) { + std::vector buffers; + std::vector in_out = {1.0, 2.0, 3.0, 4.0}; + + size_t size_in_bytes = in_out.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + BufferAllocation in_out_alloc(0, size_in_bytes, 0); + BufferAllocation::Slice in_out_slice(&in_out_alloc, 0, size_in_bytes); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + KernelThunk::Create({"add_f32"}, {in_out_slice}, {in_out_slice}, + "add_f32", se::ThreadDim(4), {})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + std::vector expected = {2.0, 4.0, 6.0, 8.0}; + EXPECT_EQ(in_out, expected); +} + +TEST(KernelThunkInvariantBuffersTest, MissingBufferSlice) { +#ifdef NDEBUG + GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; +#endif + + std::vector buffers; + std::vector in = {1.0, 2.0, 3.0, 4.0}; + std::vector out(4, 0.0); + + size_t size_in_bytes = in.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation in_alloc(0, size_in_bytes, 0); + BufferAllocation out_alloc(1, size_in_bytes, 0); + + BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); + BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); + + // Invariant buffer set is incorrect - should include in_slice, but is empty. + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, "add_f32", + se::ThreadDim(4), /*invariant_buffers=*/{})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsError()); + + auto status = execute_event.GetError(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status.message(), + "Mismatch in invariant buffers metadata")); +} + +TEST(KernelThunkInvariantBuffersTest, ExtraInputOutputBufferSlice) { +#ifdef NDEBUG + GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; +#endif + + std::vector buffers; + std::vector in_out = {1.0, 2.0, 3.0, 4.0}; + + size_t size_in_bytes = in_out.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + BufferAllocation in_out_alloc(0, size_in_bytes, 0); + BufferAllocation::Slice in_out_slice(&in_out_alloc, 0, size_in_bytes); + + // Invariant buffer set is incorrect - should be empty, but contains input + // buffer that's not invariant. + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, KernelThunk::Create( + {"add_f32"}, {in_out_slice}, {in_out_slice}, "add_f32", + se::ThreadDim(4), /*invariant_buffers=*/{in_out_slice})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsError()); + + auto status = execute_event.GetError(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status.message(), + "Mismatch in invariant buffers metadata")); +} + +TEST(KernelThunkInvariantBuffersTest, ExtraIncorrectBufferSlice) { +#ifdef NDEBUG + GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; +#endif + + std::vector buffers; + std::vector in = {1.0, 2.0, 3.0, 4.0}; + std::vector out(4, 0.0); + std::vector unrelated(4, 0.0); + + size_t size_in_bytes = in.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(unrelated.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation in_alloc(0, size_in_bytes, 0); + BufferAllocation out_alloc(1, size_in_bytes, 0); + BufferAllocation unrelated_alloc(2, size_in_bytes, 0); + + BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); + BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); + BufferAllocation::Slice unrelated_slice(&unrelated_alloc, 0, size_in_bytes); + + // Invariant buffer set contains all invariant buffers, but still it is + // incorrect - it contains a buffer that's unrelated to the kernel. + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, + KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, "add_f32", + se::ThreadDim(4), + /*invariant_buffers=*/{in_slice, unrelated_slice})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsError()); + + auto status = execute_event.GetError(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status.message(), + "Mismatch in invariant buffers metadata")); +} + +// This case should never happen in practice, it simulates a bug in the code +// that incorrectly sets up aliases. +TEST(KernelThunkInvariantBuffersTest, + MemorySectionIncorrectlyMarkedAsInvariant) { +#ifdef NDEBUG + GTEST_SKIP() << "Invariant buffers check is disabled in optimized build."; +#endif + + // We've got only one memory section + std::vector buffers; + std::vector in_out = {1.0, 2.0, 3.0, 4.0}; + + // We've got two buffer slices with different indexes, but both pointing to + // the same memory section. + size_t size_in_bytes = in_out.size() * sizeof(float); + buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + buffers.emplace_back(se::DeviceMemoryBase(in_out.data(), size_in_bytes)); + + BufferAllocations allocations(buffers); + + BufferAllocation in_0_alloc(0, size_in_bytes, 0); + BufferAllocation in_1_alloc(1, size_in_bytes, 0); + + BufferAllocation::Slice in_0_slice(&in_0_alloc, 0, size_in_bytes); + BufferAllocation::Slice in_1_slice(&in_1_alloc, 0, size_in_bytes); + + // Invariant buffer set is incorrect. in_1_slice is not aliased to any output, + // but it points to the same memory section as in_0_slice (which is not + // invariant, because is aliased with the output). + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, KernelThunk::Create({"add_f32"}, {in_0_slice, in_1_slice}, + {in_0_slice}, "add_f32", se::ThreadDim(4), + /*invariant_buffers=*/{in_1_slice})); + + AddF32HostKernel host_kernels; + Thunk::ExecuteParams params = {&host_kernels, &allocations}; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_TRUE(execute_event.IsError()); + + auto status = execute_event.GetError(); + EXPECT_EQ(status.code(), absl::StatusCode::kInternal); + EXPECT_TRUE(absl::StrContains(status.message(), + "Mismatch in invariant buffers metadata")); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.cc index 61c8f4c801db32..ace52302dc953d 100644 --- a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/logical_id_thunk.h" +#include "xla/backends/cpu/runtime/logical_id_thunk.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" diff --git a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.h similarity index 90% rename from third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.h index bb4d2fd12840ff..6a42fe69963d1a 100644 --- a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" #include "xla/service/computation_placer.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/global_device_id.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -68,4 +68,4 @@ class PartitionIdThunk final } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_LOGICAL_ID_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc index 72ce59f85dad5c..c8dd0a60782fed 100644 --- a/third_party/xla/xla/service/cpu/runtime/logical_id_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/logical_id_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/logical_id_thunk.h" +#include "xla/backends/cpu/runtime/logical_id_thunk.h" #include #include @@ -22,10 +22,10 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/executable_run_options.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.cc index a56ae0c437a7ec..b541953a403dee 100644 --- a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/outfeed_thunk.h" +#include "xla/backends/cpu/runtime/outfeed_thunk.h" #include #include @@ -23,10 +23,10 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.h b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.h similarity index 87% rename from third_party/xla/xla/service/cpu/runtime/outfeed_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.h index ff05339002ffc5..74920899255d46 100644 --- a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk.h @@ -13,17 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_OUTFEED_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_OUTFEED_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_OUTFEED_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_OUTFEED_THUNK_H_ #include #include #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -60,4 +60,4 @@ class OutfeedThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_OUTFEED_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_OUTFEED_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk_test.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/outfeed_thunk_test.cc index 2c6b9b9a91123f..0139a95f777e47 100644 --- a/third_party/xla/xla/service/cpu/runtime/outfeed_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/outfeed_thunk_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/outfeed_thunk.h" +#include "xla/backends/cpu/runtime/outfeed_thunk.h" #include +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc index 701ac3243ebd90..920aa3dc545b19 100644 --- a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/reduce_scatter_thunk.h" +#include "xla/backends/cpu/runtime/reduce_scatter_thunk.h" #include #include @@ -24,12 +24,12 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/primitive_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.h similarity index 86% rename from third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.h index d37e1b22db5566..104d6c354dfa88 100644 --- a/third_party/xla/xla/service/cpu/runtime/reduce_scatter_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/reduce_scatter_thunk.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" #include "xla/service/collective_ops_utils.h" -#include "xla/service/cpu/runtime/collective_thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/xla_data.pb.h" @@ -44,4 +44,4 @@ class ReduceScatterThunk final : public CollectiveThunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_REDUCE_SCATTER_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/resource_use.cc b/third_party/xla/xla/backends/cpu/runtime/resource_use.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/resource_use.cc rename to third_party/xla/xla/backends/cpu/runtime/resource_use.cc index 3e5ceabb9ac53a..a3c03849b5178a 100644 --- a/third_party/xla/xla/service/cpu/runtime/resource_use.cc +++ b/third_party/xla/xla/backends/cpu/runtime/resource_use.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/resource_use.h" #include diff --git a/third_party/xla/xla/service/cpu/runtime/resource_use.h b/third_party/xla/xla/backends/cpu/runtime/resource_use.h similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/resource_use.h rename to third_party/xla/xla/backends/cpu/runtime/resource_use.h index 6ee1f1bfd6ac95..1442a2895a02bf 100644 --- a/third_party/xla/xla/service/cpu/runtime/resource_use.h +++ b/third_party/xla/xla/backends/cpu/runtime/resource_use.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ -#define XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_RESOURCE_USE_H_ +#define XLA_BACKENDS_CPU_RUNTIME_RESOURCE_USE_H_ #include #include @@ -111,4 +111,4 @@ class ResourceUse { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_RESOURCE_USE_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/resource_use_test.cc b/third_party/xla/xla/backends/cpu/runtime/resource_use_test.cc similarity index 97% rename from third_party/xla/xla/service/cpu/runtime/resource_use_test.cc rename to third_party/xla/xla/backends/cpu/runtime/resource_use_test.cc index 4d3c9bbaf4cecc..dd5115bcaf2ae5 100644 --- a/third_party/xla/xla/service/cpu/runtime/resource_use_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/resource_use_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/resource_use.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu/runtime/rng_state_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/rng_state_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.cc index df611bd5fe169f..39a3de9b9429dc 100644 --- a/third_party/xla/xla/service/cpu/runtime/rng_state_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/rng_state_thunk.h" +#include "xla/backends/cpu/runtime/rng_state_thunk.h" #include #include @@ -26,8 +26,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/rng_state_thunk.h b/third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.h similarity index 89% rename from third_party/xla/xla/service/cpu/runtime/rng_state_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.h index 9798ed7c105f4b..d00bf4523e5dea 100644 --- a/third_party/xla/xla/service/cpu/runtime/rng_state_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/rng_state_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_RNG_STATE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_RNG_STATE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_RNG_STATE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_RNG_STATE_THUNK_H_ #include #include @@ -23,9 +23,9 @@ limitations under the License. #include "absl/numeric/int128.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" namespace xla::cpu { @@ -56,4 +56,4 @@ class RngGetAndUpdateStateThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_RNG_STATE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_RNG_STATE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc similarity index 99% rename from third_party/xla/xla/service/cpu/runtime/sort_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc index a24a2272587b25..8d2df6f298cbcf 100644 --- a/third_party/xla/xla/service/cpu/runtime/sort_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/sort_thunk.h" +#include "xla/backends/cpu/runtime/sort_thunk.h" #include #include @@ -38,11 +38,11 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout_util.h" #include "xla/primitive_util.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_memory.h" @@ -471,6 +471,9 @@ static absl::Status SortInplace(absl::Span data, case 25: sort(std::integral_constant{}); break; + case 29: + sort(std::integral_constant{}); + break; default: return Internal("Unsupported number of sorted inputs: %d", data.size()); } diff --git a/third_party/xla/xla/service/cpu/runtime/sort_thunk.h b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h similarity index 93% rename from third_party/xla/xla/service/cpu/runtime/sort_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/sort_thunk.h index 049fa062cff603..a1c2b5eda242ee 100644 --- a/third_party/xla/xla/service/cpu/runtime/sort_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_SORT_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_SORT_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_SORT_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_SORT_THUNK_H_ #include #include @@ -28,8 +28,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/shape.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -81,4 +81,4 @@ class SortThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_SORT_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_SORT_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/sort_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc similarity index 98% rename from third_party/xla/xla/service/cpu/runtime/sort_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc index 4c7b2514a1c709..1f450f77548d70 100644 --- a/third_party/xla/xla/service/cpu/runtime/sort_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/sort_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/sort_thunk.h" +#include "xla/backends/cpu/runtime/sort_thunk.h" #include #include @@ -21,11 +21,11 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/shape.h" #include "xla/shape_util.h" diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.cc b/third_party/xla/xla/backends/cpu/runtime/thunk.cc similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/thunk.cc index 455c940e264f3b..41a02a5ca3a413 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include #include @@ -85,6 +85,10 @@ std::string_view Thunk::KindToString(Kind kind) { return "while"; } } +Thunk::Thunk(Kind kind, Info info) + : kind_(kind), + info_(std::move(info)), + ok_event_(OkExecuteEventSingleton()) {} absl::StatusOr Thunk::CollectiveExecuteParams::Create( @@ -136,27 +140,25 @@ Thunk::CustomCallExecuteParams::Create( ? run_options->device_ordinal() : run_options->stream()->parent()->device_ordinal(); - return CustomCallExecuteParams{device_ordinal, run_options->stream(), - run_options->allocator(), + return CustomCallExecuteParams{device_ordinal, + run_options->intra_op_thread_pool(), run_options->ffi_execution_context()}; } Thunk::CustomCallExecuteParams::CustomCallExecuteParams( - int32_t device_ordinal, stream_executor::Stream* stream, - stream_executor::DeviceMemoryAllocator* allocator, + int32_t device_ordinal, const Eigen::ThreadPoolDevice* intra_op_thread_pool, const ffi::ExecutionContext* ffi_execution_context) : device_ordinal(device_ordinal), - stream(stream), - allocator(allocator), + intra_op_thread_pool(intra_op_thread_pool), ffi_execution_context(ffi_execution_context) {} -tsl::AsyncValueRef Thunk::OkExecuteEvent() { - static tsl::AsyncValueOwningRef* event = [] { +tsl::AsyncValueRef Thunk::OkExecuteEventSingleton() { + static tsl::AsyncValueOwningRef* singleton = [] { auto* storage = new tsl::internal::AsyncValueStorage(); return new tsl::AsyncValueOwningRef( tsl::MakeAvailableAsyncValueRef(*storage)); }(); - return event->AsRef(); + return singleton->AsRef(); } Thunk::ExecuteState::ExecuteState(int64_t num_tasks) diff --git a/third_party/xla/xla/service/cpu/runtime/thunk.h b/third_party/xla/xla/backends/cpu/runtime/thunk.h similarity index 92% rename from third_party/xla/xla/service/cpu/runtime/thunk.h rename to third_party/xla/xla/backends/cpu/runtime/thunk.h index 210d19937b2173..cfc60597e6ac6d 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_THUNK_H_ #include #include @@ -31,12 +31,12 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/resource_use.h" #include "xla/executable_run_options.h" #include "xla/ffi/execution_context.h" #include "xla/runtime/buffer_use.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/resource_use.h" #include "xla/service/cpu/xfeed_manager.h" #include "xla/service/global_device_id.h" #include "xla/stream_executor/host/host_kernel_c_api.h" @@ -110,7 +110,7 @@ class Thunk { using Task = std::function; using TaskRunner = absl::AnyInvocable; - Thunk(Kind kind, Info info) : kind_(kind), info_(std::move(info)) {} + Thunk(Kind kind, Info info); Thunk(const Thunk&) = delete; Thunk& operator=(const Thunk&) = delete; @@ -208,14 +208,12 @@ class Thunk { const ExecutableRunOptions* run_options); int32_t device_ordinal; - stream_executor::Stream* stream = nullptr; - stream_executor::DeviceMemoryAllocator* allocator = nullptr; + const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr; const ffi::ExecutionContext* ffi_execution_context = nullptr; private: CustomCallExecuteParams(int32_t device_ordinal, - stream_executor::Stream* stream, - stream_executor::DeviceMemoryAllocator* allocator, + const Eigen::ThreadPoolDevice* intra_op_thread_pool, const ffi::ExecutionContext* ffi_execution_context); }; @@ -286,9 +284,21 @@ class Thunk { // An execute event that becomes ready when all tasks are completed. using ExecuteEvent = tsl::Chain; - // Returns non-reference-counted async value ref for thunks executed in the - // caller thread to avoid reference counting overhead. - static tsl::AsyncValueRef OkExecuteEvent(); + // Returns non-reference-counted async value ref in constructed state. + // Returned async value is a per-process singleton stored in a storage with a + // static duration, and can be safely compared using pointer equality. + static tsl::AsyncValueRef OkExecuteEventSingleton(); + + // Returns `OkExecuteEventSingleton()` cached by this thunk instance. + tsl::AsyncValueRef OkExecuteEvent() const { return ok_event_; } + + bool IsOkExecuteEvent(const tsl::AsyncValueRef& event) const { + return event == ok_event_; + } + + bool IsOkExecuteEvent(tsl::AsyncValuePtr event) const { + return event == ok_event_.AsPtr(); + } // Thunk execution must be asynchronous and never block the caller thread, // especially waiting for work submitted into the `intra_op_threadpool`, @@ -331,6 +341,8 @@ class Thunk { private: Kind kind_; Info info_; + + tsl::AsyncValueRef ok_event_; }; std::ostream& operator<<(std::ostream& os, Thunk::Kind kind); @@ -367,4 +379,4 @@ class ThunkSequence : public std::vector> { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc similarity index 68% rename from third_party/xla/xla/service/cpu/runtime/thunk_executor.cc rename to third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc index 4281442d5c4305..eb32b508b3a1b1 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.cc @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/thunk_executor.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include +#include #include #include #include @@ -31,9 +32,9 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "tsl/platform/logging.h" #include "tsl/profiler/lib/traceme.h" @@ -41,8 +42,11 @@ limitations under the License. namespace xla::cpu { ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, - std::vector nodes_defs) + std::vector nodes_defs, + const ThunkExecutor::Options& options) : thunk_sequence_(std::move(thunk_sequence)), + options_(options), + num_thunks_(thunk_sequence_.size()), nodes_defs_(std::move(nodes_defs)), is_sequential_(true) { for (NodeId i = 0; i < nodes_defs_.size(); ++i) { @@ -58,7 +62,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, } // Erase redundant edges between nodes. - int64_t num_erased_edges = TransitiveReduction(); + int64_t num_erased_edges = RunTransitiveReductionAndUpdatePriorities(); // Check if constructed execution DAG is sequential: every node depends on the // completion of the previous node. @@ -66,11 +70,21 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, is_sequential_ &= (absl::c_count(nodes_defs_[i].in_edges, i - 1) != 0); } + // Maybe mark execution as sequential if all thunks use small buffers. + auto uses_small_buffers = [&](const std::unique_ptr& thunk) { + return absl::c_all_of(thunk->buffer_uses(), [&](const BufferUse& use) { + return use.slice().size() <= options.execute_sequential_buffer_threshold; + }); + }; + + bool small_buffers = absl::c_all_of(thunk_sequence_, uses_small_buffers); + is_sequential_ |= small_buffers; + VLOG(2) << absl::StreamFormat( "Constructed ThunkExecutor with %d nodes: #source_nodes=%d " - "#sink_nodes=%d, #erased_edges=%d, is_sequential=%v", + "#sink_nodes=%d, #erased_edges=%d, is_sequential=%v, small_buffers=%v", nodes_defs_.size(), source_.size(), sink_.size(), num_erased_edges, - is_sequential_); + is_sequential_, small_buffers); // Sanity check that all vectors are empty or all vectors are non-empty. DCHECK((!source_.empty() && !sink_.empty() && !thunk_sequence_.empty()) || @@ -78,7 +92,7 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence, } absl::StatusOr ThunkExecutor::Create( - ThunkSequence thunk_sequence) { + ThunkSequence thunk_sequence, const ThunkExecutor::Options& options) { std::vector defs(thunk_sequence.size()); std::vector buffer_rwsets(thunk_sequence.size()); @@ -106,9 +120,12 @@ absl::StatusOr ThunkExecutor::Create( } } - return ThunkExecutor(std::move(thunk_sequence), std::move(defs)); + return ThunkExecutor(std::move(thunk_sequence), std::move(defs), options); } +ThunkExecutor::ExecuteState::Node::Node(const NodeDef& node_def) + : counter(node_def.in_edges.size()), out_edges(&node_def.out_edges) {} + ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor, Thunk::TaskRunner* runner) : executor(executor), @@ -120,21 +137,19 @@ ThunkExecutor::ExecuteState::ExecuteState(ThunkExecutor* executor, DCHECK(runner == nullptr || static_cast(*runner)) << "`runner` must be nullptr or a valid TaskRunner"; - Node* node = nodes.data(); + NodeStorage* node = nodes.data(); for (const NodeDef& node_def : executor->nodes_defs()) { - node->counter.store(node_def.in_edges.size(), std::memory_order_release); - node->out_edges = &node_def.out_edges; - ++node; + new (node++) Node(node_def); } } tsl::AsyncValueRef ThunkExecutor::Execute( const Thunk::ExecuteParams& params) { // Short-circuit execution of trivial thunk sequences. - if (ABSL_PREDICT_FALSE(thunk_sequence_.empty())) { - return Thunk::OkExecuteEvent(); + if (ABSL_PREDICT_FALSE(num_thunks_ == 0)) { + return Thunk::OkExecuteEventSingleton(); } - if (ABSL_PREDICT_FALSE(thunk_sequence_.size() == 1)) { + if (ABSL_PREDICT_FALSE(num_thunks_ == 1)) { return thunk_sequence_[0]->Execute(params); } @@ -146,8 +161,20 @@ tsl::AsyncValueRef ThunkExecutor::Execute( // Create async execution state on heap and kick-off execution. auto state = std::make_unique(this, params.task_runner); - Execute(state.get(), params, ReadyQueue(source_.begin(), source_.end()), - /*lock=*/params.session.Join()); + + if (options_.use_priority_ready_queue) { + Execute(state.get(), params, PriorityReadyQueue(nodes_defs_, source_), + /*lock=*/params.session.Join()); + } else { + Execute(state.get(), params, FifoReadyQueue(source_), + /*lock=*/params.session.Join()); + } + + // If execution already completed (all kernels executed in the caller thread), + // immediately return the result to avoid wasteful reference counting below. + if (ABSL_PREDICT_TRUE(state->execute_event.IsAvailable())) { + return std::move(state->execute_event); + } // Move execute state to the execute event callback to ensure that it is kept // alive while thunk executor has pending tasks. @@ -164,19 +191,24 @@ tsl::AsyncValueRef ThunkExecutor::Execute( tsl::AsyncValueRef ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { - for (int64_t i = 0; i < thunk_sequence_.size(); ++i) { - Thunk& thunk = *thunk_sequence_[i]; + for (auto it = thunk_sequence_.begin(); it != thunk_sequence_.end(); ++it) { + Thunk& thunk = **it; auto execute_event = thunk.Execute(params); + // Fast path for thunks executed inline and returned OkExecuteEvent. + if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) { + continue; + } + // If thunk execution is not completed yet, attach a continuation to // resume sequential execution starting from the next thunk. if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { auto event = tsl::MakeConstructedAsyncValueRef(); - execute_event.AndThen([this, ¶ms, i, event](absl::Status status) { + execute_event.AndThen([this, ¶ms, it, event](absl::Status status) { if (ABSL_PREDICT_FALSE(!status.ok())) { event.SetError(std::move(status)); } else { - ResumeExecuteSequential(i + 1, params, std::move(event)); + ResumeExecuteSequential(it + 1, params, std::move(event)); } }); return event; @@ -190,25 +222,30 @@ ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) { // If we got to the end of the sequence it means that all thunks have // succeeded. - return Thunk::OkExecuteEvent(); + return Thunk::OkExecuteEventSingleton(); } void ThunkExecutor::ResumeExecuteSequential( - int64_t index, const Thunk::ExecuteParams& params, + ThunkIterator it, const Thunk::ExecuteParams& params, tsl::AsyncValueRef event) { - for (int64_t i = index; i < thunk_sequence_.size(); ++i) { - Thunk& thunk = *thunk_sequence_[i]; + for (; it != thunk_sequence_.end(); ++it) { + Thunk& thunk = **it; auto execute_event = thunk.Execute(params); + // Fast path for thunks executed inline and returned OkExecuteEvent. + if (ABSL_PREDICT_TRUE(thunk.IsOkExecuteEvent(execute_event))) { + continue; + } + // If thunk execution is not completed yet, attach a continuation to // resume sequential execution starting from the next thunk. if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) { execute_event.AndThen( - [this, ¶ms, i, event = std::move(event)](absl::Status status) { + [this, ¶ms, it, event = std::move(event)](absl::Status status) { if (ABSL_PREDICT_FALSE(!status.ok())) { event.SetError(std::move(status)); } else { - ResumeExecuteSequential(i + 1, params, std::move(event)); + ResumeExecuteSequential(it + 1, params, std::move(event)); } }); return; @@ -226,13 +263,14 @@ void ThunkExecutor::ResumeExecuteSequential( event.SetStateConcrete(); } +template void ThunkExecutor::Execute(ExecuteState* state, const Thunk::ExecuteParams& params, ReadyQueue ready_queue, Thunk::ExecuteSession::Lock lock) { tsl::profiler::TraceMe trace("ThunkExecutor::Execute"); - DCHECK(!ready_queue.empty()) << "Ready queue must not be empty"; + DCHECK(!ready_queue.Empty()) << "Ready queue must not be empty"; DCHECK(lock) << "Execute session lock must be set"; bool has_runner = state->runner != nullptr; @@ -240,18 +278,18 @@ void ThunkExecutor::Execute(ExecuteState* state, // Threshold for splitting ready queue into separate thunk executor tasks. int64_t split_threshold = params.session.split_threshold(); - for (int64_t i = 0; i < ready_queue.size(); ++i) { - NodeId id = ready_queue[i]; - ExecuteState::Node& node = state->nodes[id]; + while (!ready_queue.Empty()) { + NodeId id = ready_queue.Pop(); + ExecuteState::Node& node = state->node(id); int64_t cnt = node.counter.load(std::memory_order_acquire); DCHECK_EQ(cnt, 0) << "Node counter must be 0"; // Crash Ok // If we have multiple ready thunks, split the ready queue and offload // thunks processing to the task runner. - int64_t num_ready_thunks = ready_queue.size() - i; + int64_t num_ready_thunks = ready_queue.Size(); if (ABSL_PREDICT_FALSE(has_runner && num_ready_thunks > split_threshold)) { - SplitReadyQueue(state, params, /*start_index=*/i + 1, ready_queue); + SplitReadyQueue(state, params, ready_queue, split_threshold); } // Execute thunk for the given node id. If execution is aborted, we keep @@ -259,7 +297,7 @@ void ThunkExecutor::Execute(ExecuteState* state, Thunk& thunk = *state->executor->thunk_sequence_[id]; tsl::AsyncValueRef execute_event = ABSL_PREDICT_FALSE(state->abort.load(std::memory_order_relaxed)) - ? Thunk::OkExecuteEvent() + ? Thunk::OkExecuteEventSingleton() : thunk.Execute(params); if (ABSL_PREDICT_TRUE(execute_event.IsAvailable())) { @@ -277,13 +315,13 @@ void ThunkExecutor::Execute(ExecuteState* state, // the same execute session. execute_event.AndThen([¶ms, &node, state, execute_event = execute_event.AsPtr(), + ready_queue = ready_queue.CreateEmptyReadyQueue(), lock = params.session.Join()]() mutable { - ReadyQueue ready_queue; state->executor->ProcessOutEdges(state, execute_event, node, ready_queue); // If ready queue is empty it might mean that we have completed an // execution and destroyed the `state`. - if (ABSL_PREDICT_TRUE(!ready_queue.empty())) { + if (ABSL_PREDICT_TRUE(!ready_queue.Empty())) { state->executor->Execute(state, params, std::move(ready_queue), std::move(lock)); } @@ -292,17 +330,17 @@ void ThunkExecutor::Execute(ExecuteState* state, } } +template inline ABSL_ATTRIBUTE_ALWAYS_INLINE void ThunkExecutor::SplitReadyQueue( ExecuteState* state, const Thunk::ExecuteParams& params, - int64_t start_index, ReadyQueue& ready_queue) { + ReadyQueue& ready_queue, int64_t split_threshold) { DCHECK(state->runner) << "TaskRunner must be set"; - int64_t end_index = ready_queue.size(); // We use recursive work splitting to push the tail of the ready queue to // the task runner. Recursive work splitting creates a more uniform work // distribution across the task runner threads and avoids a situation when // we have a long tail of work that is processed by a single thread. - while (end_index > start_index) { + while (ready_queue.Size() > split_threshold) { // Try to acquire a lock to offload ready thunks to the task runner. If // we can't get a lock, we will keep processing the ready queue in the // current thread as it means that we have enough concurrent workers @@ -312,22 +350,16 @@ inline ABSL_ATTRIBUTE_ALWAYS_INLINE void ThunkExecutor::SplitReadyQueue( break; } - // Execute [mid_index, end_index) nodes in the task runner. - int64_t mid_index = (start_index + end_index) / 2; - (*state->runner)([¶ms, state, - ready_queue = ReadyQueue(ready_queue.begin() + mid_index, - ready_queue.begin() + end_index), + // Execute half of the ready queue nodes in the task runner. + (*state->runner)([¶ms, state, ready_queue = ready_queue.PopHalf(), lock = std::move(task_runner_lock)]() mutable { state->executor->Execute(state, params, std::move(ready_queue), std::move(lock)); }); - end_index = mid_index; } - - // Erase ready nodes passed to the task runner. - ready_queue.erase(ready_queue.begin() + end_index, ready_queue.end()); } +template void ThunkExecutor::ProcessOutEdges( ExecuteState* state, tsl::AsyncValuePtr node_event, ExecuteState::Node& node, ReadyQueue& ready_queue) { @@ -346,11 +378,11 @@ void ThunkExecutor::ProcessOutEdges( // Append ready nodes to the back of the ready queue. for (NodeId out_edge : *node.out_edges) { - ExecuteState::Node& out_node = state->nodes[out_edge]; + ExecuteState::Node& out_node = state->node(out_edge); int64_t cnt = out_node.counter.fetch_sub(1, std::memory_order_release); - CHECK_GE(cnt, 1) << "Node counter can't drop below 0"; // Crash Ok - if (cnt == 1) ready_queue.push_back(out_edge); + DCHECK_GE(cnt, 1) << "Node counter can't drop below 0"; + if (cnt == 1) ready_queue.Push(out_edge); } // Drop the pending sink nodes counter if the node is a sink. @@ -367,7 +399,7 @@ void ThunkExecutor::ProcessOutEdges( if (ABSL_PREDICT_FALSE(state->abort.load(std::memory_order_relaxed))) { auto take_error = [&] { absl::MutexLock lock(&state->abort_mutex); - CHECK(!state->abort_status.ok()) // Crash Ok + DCHECK(!state->abort_status.ok()) << "Abort status must be set if execution is aborted"; return std::move(state->abort_status); }; @@ -401,7 +433,7 @@ static int64_t EraseEdge(ThunkExecutor::NodeDef& from, return 0; } -int64_t ThunkExecutor::TransitiveReduction() { +int64_t ThunkExecutor::RunTransitiveReductionAndUpdatePriorities() { int64_t num_erased_edges = 0; // Keep workspace for DFS traversal between iterations. @@ -424,11 +456,11 @@ int64_t ThunkExecutor::TransitiveReduction() { stack.clear(); visited.assign(nodes_defs_.size(), false); - // Initialize stack with nodes reachable via immediate out nodes. We don't - // need to add source node and immediate out nodes to the visited set - // because graph is acyclic and we don't visit them again. + // Initialize stack with nodes reachable via immediate out nodes. We mark + // immediate out nodes as visited to correctly compute node priority below. for (int64_t out_id : source_node.out_edges) { NodeDef& out_node = nodes_defs_[out_id]; + visited[out_id] = true; for (int64_t start_id : out_node.out_edges) add_to_stack(start_id); } @@ -442,6 +474,9 @@ int64_t ThunkExecutor::TransitiveReduction() { for (int64_t out_id : node.out_edges) add_to_stack(out_id); } + + // Set node priority to the number of visited nodes in the DFS traversal. + source_node.priority = absl::c_count(visited, true); } return num_erased_edges; @@ -449,11 +484,11 @@ int64_t ThunkExecutor::TransitiveReduction() { std::string ThunkExecutor::ToString() const { std::string str = absl::StrFormat( - "ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d", - thunk_sequence_.size(), source_.size(), sink_.size()); + "ThunkExecutor: #thunks=%d #source_nodes=%d #sink_nodes=%d", num_thunks_, + source_.size(), sink_.size()); // Collect names of `in_edges`. - std::vector> in_edges(thunk_sequence_.size()); + std::vector> in_edges(num_thunks_); for (const auto& node_def : nodes_defs_) { for (NodeId in_edge : node_def.in_edges) { in_edges[node_def.id].push_back(thunk_sequence_[in_edge]->info().op_name); @@ -461,18 +496,92 @@ std::string ThunkExecutor::ToString() const { } // Print thunks with a list of their dependencies; - for (NodeId i = 0; i < thunk_sequence_.size(); ++i) { + for (NodeId i = 0; i < num_thunks_; ++i) { const Thunk& thunk = *thunk_sequence_[i]; bool is_source = absl::c_find(source_, i) != source_.end(); bool is_sink = absl::c_find(sink_, i) != sink_.end(); - absl::StrAppendFormat( - &str, - "\n thunk #%05d: op_name=%s, dependencies=[%s], source=%v, sink=%v", i, - thunk.info().op_name, absl::StrJoin(in_edges[i], ", "), is_source, - is_sink); + absl::StrAppendFormat(&str, + "\n thunk #%05d: op_name=%s, dependencies=[%s], " + "source=%v, sink=%v, priority=%d", + i, thunk.info().op_name, + absl::StrJoin(in_edges[i], ", "), is_source, is_sink, + nodes_defs_[i].priority); } return str; } +ThunkExecutor::FifoReadyQueue::FifoReadyQueue( + absl::Span ready_nodes) + : queue_(ready_nodes.begin(), ready_nodes.end()) {} + +void ThunkExecutor::FifoReadyQueue::Push(NodeId id) { queue_.push_back(id); } + +ThunkExecutor::NodeId ThunkExecutor::FifoReadyQueue::Pop() { + DCHECK(!Empty()) << "Queue must not be empty"; + return queue_[head_++]; +} + +ThunkExecutor::FifoReadyQueue ThunkExecutor::FifoReadyQueue::PopHalf() { + DCHECK(!Empty()) << "Queue must not be empty"; + auto mid = queue_.begin() + head_ + Size() / 2; + FifoReadyQueue popped(absl::MakeConstSpan(&*mid, queue_.end() - mid)); + queue_.resize(mid - queue_.begin()); + return popped; +} + +size_t ThunkExecutor::FifoReadyQueue::Size() const { + return queue_.size() - head_; +} + +bool ThunkExecutor::FifoReadyQueue::Empty() const { + return head_ == queue_.size(); +} + +ThunkExecutor::FifoReadyQueue +ThunkExecutor::FifoReadyQueue::CreateEmptyReadyQueue() const { + return FifoReadyQueue(absl::Span()); +} + +ThunkExecutor::PriorityReadyQueue::PriorityReadyQueue( + absl::Span nodes_defs, absl::Span ready_nodes) + : nodes_defs_(nodes_defs), + queue_(ready_nodes.begin(), ready_nodes.end(), Compare{nodes_defs}) {} + +void ThunkExecutor::PriorityReadyQueue::Push(NodeId id) { queue_.push(id); } + +ThunkExecutor::NodeId ThunkExecutor::PriorityReadyQueue::Pop() { + DCHECK(!Empty()) << "Queue must not be empty"; + NodeId id = queue_.top(); + queue_.pop(); + return id; +} + +ThunkExecutor::PriorityReadyQueue ThunkExecutor::PriorityReadyQueue::PopHalf() { + DCHECK(!Empty()) << "Queue must not be empty"; + int64_t keep_top_nodes = queue_.size() / 2; + + // First pop nodes with highest priority from the queue. + PriorityReadyQueue popped(nodes_defs_, {}); + while (keep_top_nodes-- > 0) { + popped.queue_.push(queue_.top()); + queue_.pop(); + } + + // Swap popped nodes with remaining nodes, to return to the caller nodes with + // smaller priorities, and keep higher priority nodes in the queue. + popped.queue_.swap(queue_); + + return popped; +} + +size_t ThunkExecutor::PriorityReadyQueue::Size() const { return queue_.size(); } + +bool ThunkExecutor::PriorityReadyQueue::Empty() const { return queue_.empty(); } + +ThunkExecutor::PriorityReadyQueue +ThunkExecutor::PriorityReadyQueue::CreateEmptyReadyQueue() const { + return PriorityReadyQueue(nodes_defs_, {}); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h similarity index 64% rename from third_party/xla/xla/service/cpu/runtime/thunk_executor.h rename to third_party/xla/xla/backends/cpu/runtime/thunk_executor.h index a48dd843871d4c..5ba15b0432b504 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor.h @@ -13,15 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_ -#define XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_THUNK_EXECUTOR_H_ +#define XLA_BACKENDS_CPU_RUNTIME_THUNK_EXECUTOR_H_ #include #include #include #include #include +#include #include +#include #include #include "absl/base/thread_annotations.h" @@ -31,11 +33,26 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" -#include "xla/service/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { +namespace internal { +// Clang does not allow defining a nested struct with member initializer, as +// a workaround we define a struct in internal namespace and create an alias. +struct ThunkExecutorOptions { + // If all thunks in a sequence use buffers of size less than or equal to + // `execute_sequential_buffer_threshold`, we mark execution as sequential, as + // concurrency overheads will likely dominate the overall execution time. + size_t execute_sequential_buffer_threshold = 512; + + // Use priority ready queue to execute nodes according to their priority. By + // default we use FIFO ready queue. + bool use_priority_ready_queue = false; +}; +} // namespace internal + // A dataflow-style (run when ready) executor for a ThunkSequence that depends // on buffer uses to build a DAG defining execution order. At run time executes // thunks concurrently in a given thread pool. @@ -44,6 +61,7 @@ class ThunkExecutor { using BufferUses = Thunk::BufferUses; using ResourceUses = Thunk::ResourceUses; using ExecuteEvent = Thunk::ExecuteEvent; + using Options = internal::ThunkExecutorOptions; // Nodes identified by their index in the captured ThunkSequence. using NodeId = int64_t; @@ -53,11 +71,13 @@ class ThunkExecutor { ThunkExecutor(ThunkExecutor&&) = default; ThunkExecutor& operator=(ThunkExecutor&&) = default; - static absl::StatusOr Create(ThunkSequence thunk_sequence); + static absl::StatusOr Create( + ThunkSequence thunk_sequence, const Options& options = Options()); // NodeDef defines an execution order for all thunks in a sequence. struct NodeDef { NodeId id = kInvalidNodeId; + int64_t priority = 0; std::vector in_edges; std::vector out_edges; }; @@ -83,6 +103,57 @@ class ThunkExecutor { bool is_sequential() const { return is_sequential_; } + // A ready queue that executes nodes in FIFO order. + class FifoReadyQueue { + public: + explicit FifoReadyQueue(absl::Span ready_nodes); + + void Push(NodeId id); + + NodeId Pop(); + FifoReadyQueue PopHalf(); + + size_t Size() const; + bool Empty() const; + + FifoReadyQueue CreateEmptyReadyQueue() const; + + private: + absl::InlinedVector queue_; + size_t head_ = 0; + }; + + // A ready queue that executes nodes sorted by NodeDef priority. + class PriorityReadyQueue { + public: + PriorityReadyQueue(absl::Span nodes_defs, + absl::Span ready_nodes); + + void Push(NodeId id); + + NodeId Pop(); + PriorityReadyQueue PopHalf(); + + size_t Size() const; + bool Empty() const; + + PriorityReadyQueue CreateEmptyReadyQueue() const; + + private: + struct Compare { + bool operator()(NodeId a, NodeId b) const { + return nodes_defs[a].priority < nodes_defs[b].priority; + } + absl::Span nodes_defs; + }; + + using InlinedPriorityQueue = + std::priority_queue, Compare>; + + absl::Span nodes_defs_; + InlinedPriorityQueue queue_; + }; + private: // Align all atomic counters to a cache line boundary to avoid false // sharing between multiple worker threads. @@ -93,23 +164,32 @@ class ThunkExecutor { 64; #endif - using ReadyQueue = absl::InlinedVector; - // A struct to keep the state of a running ThunkExecutor. struct ExecuteState { // At run time NodeDef instantiated as a Node with an atomic counter that // drops to zero when all `in_edges` are ready. struct Node { + explicit Node(const NodeDef& node_def); + alignas(kAtomicAlignment) std::atomic counter; const std::vector* out_edges; }; + static_assert(std::is_trivially_destructible_v, + "Node must be trivially destructible"); + + // We use indirection via NodeStorage to be able to allocate uninitialized + // memory and do not pay the cost of default initializing all nodes. + using NodeStorage = std::aligned_storage_t; + ExecuteState(ThunkExecutor* executor, Thunk::TaskRunner* runner); + Node& node(NodeId id) { return *reinterpret_cast(&nodes[id]); } + ThunkExecutor* executor; Thunk::TaskRunner* runner; - absl::FixedArray nodes; + absl::FixedArray nodes; tsl::AsyncValueRef execute_event; // Once the number of pending sink nodes drops to zero, the execution is @@ -123,40 +203,49 @@ class ThunkExecutor { absl::Status abort_status ABSL_GUARDED_BY(abort_mutex); }; - ThunkExecutor(ThunkSequence thunk_sequence, std::vector nodes_defs); + ThunkExecutor(ThunkSequence thunk_sequence, std::vector nodes_defs, + const Options& options); // Executes thunks sequentially starting from the first thunk in the sequence. tsl::AsyncValueRef ExecuteSequential( const Thunk::ExecuteParams& params); // Resumes sequential thunk execution starting from the given index. - void ResumeExecuteSequential(int64_t index, + using ThunkIterator = typename ThunkSequence::iterator; + void ResumeExecuteSequential(ThunkIterator it, const Thunk::ExecuteParams& params, tsl::AsyncValueRef event); // Executes nodes in the ready queue with given thunk parameters. + template void Execute(ExecuteState* state, const Thunk::ExecuteParams& params, ReadyQueue ready_queue, Thunk::ExecuteSession::Lock lock); // Splits ready queue starting from `start_index` into ThunkExecutor tasks and // offloads them to the task runner. + template void SplitReadyQueue(ExecuteState* state, const Thunk::ExecuteParams& params, - int64_t start_index, ReadyQueue& ready_queue); + ReadyQueue& ready_queue, int64_t split_threshold); // Processes out edges of a completed `node` and updates `ready_queue` with // nodes that are ready to execute. If `event` is in error state, aborts the // execution and records the error status to forward it to the caller. + template void ProcessOutEdges(ExecuteState* state, tsl::AsyncValuePtr node_event, ExecuteState::Node& node, ReadyQueue& ready_queue); - // Runs a transitive reduction on the NodeDef graph to remove redundant edges. - // Returns the number of removed edges. + // Runs a transitive reduction on the NodeDef graph to remove redundant edges, + // and updates nodes priorities. Returns the number of removed edges. // // See: https://en.wikipedia.org/wiki/Transitive_reduction - int64_t TransitiveReduction(); + int64_t RunTransitiveReductionAndUpdatePriorities(); ThunkSequence thunk_sequence_; + Options options_; + + int64_t num_thunks_; + std::vector nodes_defs_; std::vector source_; @@ -170,4 +259,4 @@ class ThunkExecutor { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_THUNK_EXECUTOR_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_THUNK_EXECUTOR_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc similarity index 73% rename from third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc rename to third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc index 697a47d83a4e8d..ebe98304b9f6f4 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_executor_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_executor_test.cc @@ -13,9 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/thunk_executor.h" - -#define EIGEN_USE_THREADS +#include "xla/backends/cpu/runtime/thunk_executor.h" #include #include @@ -28,15 +26,15 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/types/span.h" -#include "unsupported/Eigen/CXX11/Tensor" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -48,6 +46,10 @@ limitations under the License. #include "tsl/platform/test_benchmark.h" #include "tsl/platform/threadpool.h" +#define EIGEN_USE_THREADS + +#include "unsupported/Eigen/CXX11/Tensor" + namespace xla::cpu { namespace { @@ -210,6 +212,130 @@ AddI32Thunk::ResourceUses AddI32Thunk::resource_uses() const { : ResourceUses{}; } +static ThunkExecutor::Options OptionsForTest() { + // Override small buffers threshold to make sure that we test all execution + // paths, because in test we always use small buffers below the default + // threshold of `512`. + return ThunkExecutor::Options{/*execute_sequential_buffer_threshold=*/0}; +} + +TEST(ThunkExecutorTest, FifoReadyQueueTest) { + ThunkExecutor::FifoReadyQueue queue({}); + + // Check basic queue properties. + EXPECT_TRUE(queue.Empty()); + EXPECT_EQ(queue.Size(), 0); + + queue.Push(1); + queue.Push(2); + queue.Push(3); + + EXPECT_EQ(queue.Size(), 3); + + EXPECT_EQ(queue.Pop(), 1); + EXPECT_EQ(queue.Pop(), 2); + EXPECT_EQ(queue.Pop(), 3); + + EXPECT_TRUE(queue.Empty()); + EXPECT_EQ(queue.Size(), 0); + + // Prepare queue for PopHalf test case. + queue.Push(1); + queue.Push(2); + queue.Push(3); + + // Pop half of the queue. + ThunkExecutor::FifoReadyQueue half0 = queue.PopHalf(); + EXPECT_EQ(half0.Size(), 2); + EXPECT_EQ(half0.Pop(), 2); + EXPECT_EQ(half0.Pop(), 3); + + // Check that the rest is still in the queue. + EXPECT_EQ(queue.Size(), 1); + + // Pop the rest of the queue. + ThunkExecutor::FifoReadyQueue half1 = queue.PopHalf(); + EXPECT_EQ(half1.Size(), 1); + + // Check that all nodes were returned from PopHalf. + EXPECT_EQ(queue.Size(), 0); + + // Add 5 elements to test Pop followed by PopHalf. + queue.Push(1); + queue.Push(2); + queue.Push(3); + queue.Push(4); + queue.Push(5); + + EXPECT_EQ(queue.Pop(), 1); + + // Check that PopHalf returns 2 last nodes. + ThunkExecutor::FifoReadyQueue half2 = queue.PopHalf(); + EXPECT_EQ(half2.Size(), 2); + EXPECT_EQ(half2.Pop(), 4); + EXPECT_EQ(half2.Pop(), 5); +} + +TEST(ThunkExecutorTest, PriorityReadyQueueTest) { + std::vector nodes_defs(16); + for (size_t i = 0; i < nodes_defs.size(); ++i) { + nodes_defs[i].priority = i; + } + + ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {}); + // Check basic queue properties. + EXPECT_TRUE(queue.Empty()); + EXPECT_EQ(queue.Size(), 0); + + queue.Push(1); + queue.Push(3); + queue.Push(2); + + EXPECT_EQ(queue.Pop(), 3); + EXPECT_EQ(queue.Pop(), 2); + EXPECT_EQ(queue.Pop(), 1); + + EXPECT_TRUE(queue.Empty()); + EXPECT_EQ(queue.Size(), 0); + + // Prepare queue for PopHalf test case. + queue.Push(2); + queue.Push(1); + queue.Push(3); + + // Pop half of the queue. + ThunkExecutor::PriorityReadyQueue half0 = queue.PopHalf(); + EXPECT_EQ(half0.Size(), 2); + EXPECT_EQ(half0.Pop(), 2); + EXPECT_EQ(half0.Pop(), 1); + + // Check that the rest is still in the queue. + EXPECT_EQ(queue.Size(), 1); + + // Pop the rest of the queue. + ThunkExecutor::PriorityReadyQueue half1 = queue.PopHalf(); + EXPECT_EQ(half1.Size(), 1); + EXPECT_EQ(half1.Pop(), 3); + + // Check that all nodes were returned from PopHalf. + EXPECT_EQ(queue.Size(), 0); + + // Add 5 elements to test Pop followed by PopHalf. + queue.Push(4); + queue.Push(3); + queue.Push(5); + queue.Push(1); + queue.Push(2); + + EXPECT_EQ(queue.Pop(), 5); + + // Check that PopHalf returns 2 last nodes. + ThunkExecutor::PriorityReadyQueue half2 = queue.PopHalf(); + EXPECT_EQ(half2.Size(), 2); + EXPECT_EQ(half2.Pop(), 2); + EXPECT_EQ(half2.Pop(), 1); +} + TEST(ThunkExecutorTest, DependencyOrdering) { BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0); @@ -222,12 +348,17 @@ TEST(ThunkExecutorTest, DependencyOrdering) { sequence.push_back(AddI32Thunk::Create("b", {slice1}, {slice1})); sequence.push_back(AddI32Thunk::Create("c", {slice2}, {slice2})); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); EXPECT_FALSE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0, 1)); EXPECT_THAT(executor.sink(), ElementsAre(2)); + + EXPECT_EQ(executor.node_def(0).priority, 1); + EXPECT_EQ(executor.node_def(1).priority, 1); + EXPECT_EQ(executor.node_def(2).priority, 0); } TEST(ThunkExecutorTest, SequentialOrdering) { @@ -239,12 +370,17 @@ TEST(ThunkExecutorTest, SequentialOrdering) { sequence.push_back(AddI32Thunk::Create("b", {slice}, {slice})); sequence.push_back(AddI32Thunk::Create("c", {slice}, {slice})); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); EXPECT_TRUE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0)); EXPECT_THAT(executor.sink(), ElementsAre(2)); + + EXPECT_EQ(executor.node_def(0).priority, 2); + EXPECT_EQ(executor.node_def(1).priority, 1); + EXPECT_EQ(executor.node_def(2).priority, 0); } TEST(ThunkExecutorTest, ResourceOrdering) { @@ -261,12 +397,16 @@ TEST(ThunkExecutorTest, ResourceOrdering) { /*trace=*/nullptr, /*use_shared_resource=*/true)); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); EXPECT_TRUE(executor.is_sequential()); EXPECT_THAT(executor.source(), ElementsAre(0)); EXPECT_THAT(executor.sink(), ElementsAre(1)); + + EXPECT_EQ(executor.node_def(0).priority, 1); + EXPECT_EQ(executor.node_def(1).priority, 0); } TEST(ThunkExecutorTest, TransitiveReduction) { @@ -278,8 +418,9 @@ TEST(ThunkExecutorTest, TransitiveReduction) { sequence.push_back(AddI32Thunk::Create("b", {slice}, {slice})); sequence.push_back(AddI32Thunk::Create("c", {slice}, {slice})); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); EXPECT_THAT(executor.source(), ElementsAre(0)); EXPECT_THAT(executor.sink(), ElementsAre(2)); @@ -288,6 +429,10 @@ TEST(ThunkExecutorTest, TransitiveReduction) { EXPECT_THAT(executor.node_def(1).in_edges, ElementsAre(0)); EXPECT_THAT(executor.node_def(1).out_edges, ElementsAre(2)); EXPECT_THAT(executor.node_def(2).in_edges, ElementsAre(1)); + + EXPECT_EQ(executor.node_def(0).priority, 2); + EXPECT_EQ(executor.node_def(1).priority, 1); + EXPECT_EQ(executor.node_def(2).priority, 0); } TEST(ThunkExecutorTest, Execute) { @@ -304,8 +449,9 @@ TEST(ThunkExecutorTest, Execute) { sequence.push_back(AddI32Thunk::Create("b", {slice1}, {slice1}, &trace)); sequence.push_back(AddI32Thunk::Create("c", {slice2}, {slice2}, &trace)); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(sequence))); + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(sequence), OptionsForTest())); std::vector data(20, 1); // shared src and dst allocation @@ -320,7 +466,7 @@ TEST(ThunkExecutorTest, Execute) { Thunk::ExecuteParams params = {nullptr, &allocations}; params.task_runner = &task_runner; params.session = - Thunk::ExecuteSession(/*max_workers=*/8, /*split_threshold=*/1); + Thunk::ExecuteSession(/*max_workers=*/8, /*split_threshold=*/0); auto execute_event = executor.Execute(params); @@ -420,11 +566,11 @@ GenerateThunkSequence(size_t num_elements, size_t num_thunks, // and optionally uses a thread pool to execute thunk executor tasks. class ThunkExecutorStressTest : public testing::TestWithParam< - std::tuple> { + std::tuple> { public: void SetUp() override { auto& [num_thunks, use_task_runner, use_device, shared_resource_use, - inject_errors] = GetParam(); + inject_errors, use_priority_ready_queue] = GetParam(); use_task_runner_ = use_task_runner; use_device_ = use_device; @@ -464,15 +610,21 @@ class ThunkExecutorStressTest TEST_P(ThunkExecutorStressTest, Execute) { auto [num_thunks, use_task_runner, use_device, shared_resource_use, - inject_errors] = GetParam(); + inject_errors, use_priority_ready_queue] = GetParam(); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr g, GenerateThunkSequence(/*num_elements=*/1024, num_thunks, shared_resource_use, inject_errors)); - TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor, - ThunkExecutor::Create(std::move(g->sequence))); + ThunkExecutor::Options executor_options = { + /*execute_sequential_buffer_threshold=*/0, + /*use_priority_ready_queue=*/use_priority_ready_queue, + }; + + TF_ASSERT_OK_AND_ASSIGN( + ThunkExecutor executor, + ThunkExecutor::Create(std::move(g->sequence), executor_options)); BufferAllocations allocations(g->buffers); Thunk::ExecuteParams params = {nullptr, &allocations, nullptr, device(), @@ -502,12 +654,95 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(SharedResourceUse::kNo, SharedResourceUse::kAll, SharedResourceUse::kRandom), - /*inject_errors=*/testing::Bool())); + /*inject_errors=*/testing::Bool(), + /*use_priority_ready_queue=*/testing::Bool())); //===----------------------------------------------------------------------===// // Performance benchmarks below //===----------------------------------------------------------------------===// +static void BM_FifoReadyQueuePushPop(benchmark::State& state) { + ThunkExecutor::FifoReadyQueue queue({}); + const size_t num_push_pop = state.range(0); + + for (auto _ : state) { + for (int i = 0; i < num_push_pop; ++i) { + queue.Push(i); + } + for (int i = 0; i < num_push_pop; ++i) { + benchmark::DoNotOptimize(queue.Pop()); + } + } +} + +static void BM_FifoReadyQueuePushPopHalf(benchmark::State& state) { + ThunkExecutor::FifoReadyQueue queue({}); + const size_t num_push_pop = state.range(0); + + for (auto _ : state) { + for (int i = 0; i < num_push_pop; ++i) { + queue.Push(i); + } + benchmark::DoNotOptimize(queue.PopHalf()); + } +} + +static void BM_PriorityReadyQueuePushPop(benchmark::State& state) { + std::vector nodes_defs(16); + for (size_t i = 0; i < nodes_defs.size(); ++i) { + nodes_defs[i].priority = i; + } + + std::default_random_engine rng; + absl::c_shuffle(nodes_defs, rng); + + ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {}); + const size_t num_push_pop = state.range(0); + + for (auto _ : state) { + for (int i = 0; i < num_push_pop; ++i) { + queue.Push(i); + } + for (int i = 0; i < num_push_pop; ++i) { + benchmark::DoNotOptimize(queue.Pop()); + } + } +} + +static void BM_PriorityReadyQueuePushPopHalf(benchmark::State& state) { + std::vector nodes_defs(16); + for (size_t i = 0; i < nodes_defs.size(); ++i) { + nodes_defs[i].priority = i; + } + + std::default_random_engine rng; + absl::c_shuffle(nodes_defs, rng); + + ThunkExecutor::PriorityReadyQueue queue(nodes_defs, {}); + const size_t num_push_pop = state.range(0); + + for (auto _ : state) { + for (int i = 0; i < num_push_pop; ++i) { + queue.Push(i); + } + benchmark::DoNotOptimize(queue.PopHalf()); + } +} + +#define BENCHMARK_READY_QUEUE(name) \ + BENCHMARK(name) \ + ->MeasureProcessCPUTime() \ + ->Arg(1) \ + ->Arg(2) \ + ->Arg(4) \ + ->Arg(8) \ + ->Arg(16) + +BENCHMARK_READY_QUEUE(BM_FifoReadyQueuePushPop); +BENCHMARK_READY_QUEUE(BM_FifoReadyQueuePushPopHalf); +BENCHMARK_READY_QUEUE(BM_PriorityReadyQueuePushPop); +BENCHMARK_READY_QUEUE(BM_PriorityReadyQueuePushPopHalf); + static void BM_SequentialThunkExecutor(benchmark::State& state) { const size_t num_thunks = state.range(0); @@ -516,7 +751,8 @@ static void BM_SequentialThunkExecutor(benchmark::State& state) { /*shared_resource_use=*/SharedResourceUse::kAll, /*inject_errors=*/false) .value(); - auto e = ThunkExecutor::Create(std::move(g->sequence)).value(); + auto e = + ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); BufferAllocations allocations(g->buffers); Thunk::ExecuteParams params = {nullptr, &allocations}; @@ -535,7 +771,8 @@ static void BM_SyncThunkExecutor(benchmark::State& state) { /*shared_resource_use=*/SharedResourceUse::kNo, /*inject_errors=*/false) .value(); - auto e = ThunkExecutor::Create(std::move(g->sequence)).value(); + auto e = + ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); BufferAllocations allocations(g->buffers); Thunk::ExecuteParams params = {nullptr, &allocations}; @@ -558,7 +795,8 @@ static void BM_AsyncThunkExecutor(benchmark::State& state) { /*shared_resource_use=*/SharedResourceUse::kNo, /*inject_errors=*/false) .value(); - auto e = ThunkExecutor::Create(std::move(g->sequence)).value(); + auto e = + ThunkExecutor::Create(std::move(g->sequence), OptionsForTest()).value(); BufferAllocations allocations(g->buffers); diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/thunk_test.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/thunk_test.cc index 510d2c2f44025a..1b20de023d91f8 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" #include #include @@ -34,8 +34,8 @@ class ThunkExecuteStateTestHelper : public Thunk { } }; -TEST(ThunkTest, OkExecuteEvent) { - auto event = Thunk::OkExecuteEvent(); +TEST(ThunkTest, OkExecuteEventSingleton) { + auto event = Thunk::OkExecuteEventSingleton(); ASSERT_TRUE(event.IsConcrete()); } diff --git a/third_party/xla/xla/service/cpu/runtime/thunk_testlib.h b/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h similarity index 88% rename from third_party/xla/xla/service/cpu/runtime/thunk_testlib.h rename to third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h index 154c2b28972701..4da0650efee7c4 100644 --- a/third_party/xla/xla/service/cpu/runtime/thunk_testlib.h +++ b/third_party/xla/xla/backends/cpu/runtime/thunk_testlib.h @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ -#define XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_THUNK_TESTLIB_H_ +#define XLA_BACKENDS_CPU_RUNTIME_THUNK_TESTLIB_H_ #include "absl/status/status.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -59,4 +59,4 @@ class ResourceUseThunk : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_THUNK_TESTLIB_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_THUNK_TESTLIB_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/topk_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/topk_thunk.cc similarity index 96% rename from third_party/xla/xla/service/cpu/runtime/topk_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/topk_thunk.cc index 6c238224166a52..0c72933dc1a3aa 100644 --- a/third_party/xla/xla/service/cpu/runtime/topk_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/topk_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/topk_thunk.h" +#include "xla/backends/cpu/runtime/topk_thunk.h" #include #include @@ -21,8 +21,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/runtime_topk.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" diff --git a/third_party/xla/xla/service/cpu/runtime/topk_thunk.h b/third_party/xla/xla/backends/cpu/runtime/topk_thunk.h similarity index 90% rename from third_party/xla/xla/service/cpu/runtime/topk_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/topk_thunk.h index 7b2bfb63502bfe..7e7fadb03852e7 100644 --- a/third_party/xla/xla/service/cpu/runtime/topk_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/topk_thunk.h @@ -13,16 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_TOPK_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_TOPK_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_TOPK_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_TOPK_THUNK_H_ #include #include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -56,4 +56,4 @@ class TopKThunk final : public Thunk { } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_TOPK_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_TOPK_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/while_thunk.cc b/third_party/xla/xla/backends/cpu/runtime/while_thunk.cc similarity index 54% rename from third_party/xla/xla/service/cpu/runtime/while_thunk.cc rename to third_party/xla/xla/backends/cpu/runtime/while_thunk.cc index 4e326b63a91706..6c1e81f5dee0d6 100644 --- a/third_party/xla/xla/service/cpu/runtime/while_thunk.cc +++ b/third_party/xla/xla/backends/cpu/runtime/while_thunk.cc @@ -13,19 +13,24 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/while_thunk.h" +#include "xla/backends/cpu/runtime/while_thunk.h" +#include #include #include +#include #include #include "absl/base/optimization.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "tsl/platform/logging.h" @@ -36,70 +41,165 @@ namespace xla::cpu { absl::StatusOr> WhileThunk::Create( Info info, BufferAllocation::Slice cond_buffer, ThunkSequence cond_sequence, - ThunkSequence body_sequence) { + ThunkSequence body_sequence, std::optional trip_count) { TF_ASSIGN_OR_RETURN(ThunkExecutor cond_executor, ThunkExecutor::Create(std::move(cond_sequence))); TF_ASSIGN_OR_RETURN(ThunkExecutor body_executor, ThunkExecutor::Create(std::move(body_sequence))); return absl::WrapUnique(new WhileThunk(std::move(info), cond_buffer, std::move(cond_executor), - std::move(body_executor))); + std::move(body_executor), trip_count)); } WhileThunk::WhileThunk(Info info, BufferAllocation::Slice cond_buffer, - ThunkExecutor cond_executor, ThunkExecutor body_executor) + ThunkExecutor cond_executor, ThunkExecutor body_executor, + std::optional trip_count) : Thunk(Kind::kWhile, std::move(info)), cond_buffer_(cond_buffer), cond_executor_(std::move(cond_executor)), - body_executor_(std::move(body_executor)) {} + body_executor_(std::move(body_executor)), + trip_count_(trip_count) {} -tsl::AsyncValueRef WhileThunk::ExecuteAsync( +tsl::AsyncValueRef WhileThunk::Execute( + const ExecuteParams& params) { + tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); + + VLOG(3) << absl::StreamFormat( + "While: #trip_count=%s", + trip_count_.has_value() ? absl::StrCat(*trip_count_) : "unknown"); + + // Most of the while loops in XLA have statically known trip count. + if (ABSL_PREDICT_TRUE(trip_count_.has_value())) { + return ExecuteForLoop(params, *trip_count_); + } + + const BufferAllocations* allocations = params.buffer_allocations; + + se::DeviceMemoryBase cond_data; + if (ShouldCheckBufferSlices()) { + TF_ASSIGN_OR_RETURN(cond_data, allocations->GetDeviceAddress(cond_buffer_)); + } else { + cond_data = allocations->GetDeviceAddressUnchecked(cond_buffer_); + } + + bool* condition = reinterpret_cast(cond_data.opaque()); + return ExecuteWhileLoop(params, condition); +} + +tsl::AsyncValueRef WhileThunk::ExecuteForLoop( + const ExecuteParams& params, int64_t trip_count) { + for (int64_t loop_counter = 0; loop_counter < trip_count; ++loop_counter) { + auto body_event = body_executor_.Execute(params); + + // If loop iteration has not completed yet, switch to async execution mode + // using `body_event` as a dependency and continue the loop iteration + // starting from `loop_counter + 1`. + if (ABSL_PREDICT_FALSE(!body_event.IsAvailable())) { + return ExecuteAsyncForLoop(params, std::move(body_event), + loop_counter + 1, trip_count); + } + + if (ABSL_PREDICT_FALSE(body_event.IsError())) { + return body_event.GetError(); + } + + DCHECK(body_event.IsConcrete()); + } + + // Successfully completed `trip_count` while loop iterations. + return OkExecuteEvent(); +} + +tsl::AsyncValueRef WhileThunk::ExecuteWhileLoop( + const ExecuteParams& params, bool* condition) { + // Execute `cond` thunk sequence to initialize the loop condition. + auto init_event = cond_executor_.Execute(params); + + // If we don't know if we should continue or not, switch to async execution + // mode using `init_event` as a dependency. + if (ABSL_PREDICT_FALSE(!init_event.IsAvailable())) { + return ExecuteAsyncWhileLoop(params, std::move(init_event), condition); + } + + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(init_event.IsError())) { + return init_event.GetError(); + } + + DCHECK(init_event.IsConcrete()); + + while (*condition) { + auto body_event = body_executor_.Execute(params); + auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { + return cond_executor_.Execute(params); + }); + + // If loop iteration has not completed yet, switch to async execution mode + // using `cond_event` as a dependency and maybe continue the loop + // iteration (if `condition` is still true). + if (ABSL_PREDICT_FALSE(!cond_event.IsAvailable())) { + return ExecuteAsyncWhileLoop(params, std::move(cond_event), condition); + } + + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(cond_event.IsError())) { + return cond_event.GetError(); + } + + // At this point `*condition` should have been updated and we may continue + // executing the while loop in the current thread. + DCHECK(cond_event.IsConcrete()); + } + + // Successfully completed while loop iterations. + return OkExecuteEvent(); +} + +tsl::AsyncValueRef WhileThunk::ExecuteAsyncForLoop( const ExecuteParams& params, tsl::AsyncValueRef dependency, - bool* condition) { + int64_t loop_counter, int64_t trip_count) { auto event = tsl::MakeConstructedAsyncValueRef(); // Allocate while loop iteration function on heap so we can detach its life // time from the caller stack. - auto loop_fn = std::make_shared>(); - *loop_fn = [this, condition, ¶ms, event, - loop = loop_fn.get()](absl::Status status) { + auto loop_fn = std::make_shared>(); + *loop_fn = [this, trip_count, ¶ms, event, loop = loop_fn.get()]( + int64_t loop_counter, absl::Status status) { // Dependency completed with an error. Forward it to the result event. if (ABSL_PREDICT_FALSE(!status.ok())) { event.SetError(std::move(status)); return; } - while (*condition) { + for (; loop_counter < trip_count; ++loop_counter) { auto body_event = body_executor_.Execute(params); - auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { - return cond_executor_.Execute(params); - }); - // Immediately forward error to the caller. - if (ABSL_PREDICT_FALSE(cond_event.IsError())) { - event.SetError(cond_event.GetError()); + // If loop iteration has not completed yet, continue execution + // asynchronously starting from `loop_counter + 1`. + if (!body_event.IsAvailable()) { + body_event.AndThen([loop, loop_counter](absl::Status status) { + (*loop)(loop_counter + 1, std::move(status)); + }); return; } - // If we don't know yet wether we should execute the next iteration or - // not, attach `AndThen` continuation to the `cond_event`. - if (!cond_event.IsAvailable()) { - cond_event.AndThen( - [loop](absl::Status status) { (*loop)(std::move(status)); }); + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(body_event.IsError())) { + event.SetError(body_event.GetError()); return; } - // At this point `*condition` should have been updated and we may continue - // executing the while loop in the current thread. - DCHECK(cond_event.IsAvailable()); + DCHECK(body_event.IsConcrete()); } - // Successfully completed while loop iterations. + // Successfully completed `trip_count` while loop iterations. event.SetStateConcrete(); }; // Kick-off loop execution once dependency event is available. - dependency.AndThen(*loop_fn); + dependency.AndThen([loop_counter, loop = loop_fn.get()](absl::Status status) { + (*loop)(loop_counter, std::move(status)); + }); // Keep `loop_fn` alive until the end of the while loop execution. event.AndThen([loop_fn = std::move(loop_fn)]() {}); @@ -107,54 +207,60 @@ tsl::AsyncValueRef WhileThunk::ExecuteAsync( return event; } -tsl::AsyncValueRef WhileThunk::Execute( - const ExecuteParams& params) { - tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); - - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase cond_data, - params.buffer_allocations->GetDeviceAddress(cond_buffer_)); - - bool* condition = reinterpret_cast(cond_data.opaque()); +tsl::AsyncValueRef WhileThunk::ExecuteAsyncWhileLoop( + const ExecuteParams& params, tsl::AsyncValueRef dependency, + bool* condition) { + auto event = tsl::MakeConstructedAsyncValueRef(); - // Execute `cond` thunk sequence to initialize the loop condition. - auto init_event = cond_executor_.Execute(params); + // Allocate while loop iteration function on heap so we can detach its life + // time from the caller stack. + auto loop_fn = std::make_shared>(); + *loop_fn = [this, condition, ¶ms, event, + loop = loop_fn.get()](absl::Status status) { + // Dependency completed with an error. Forward it to the result event. + if (ABSL_PREDICT_FALSE(!status.ok())) { + event.SetError(std::move(status)); + return; + } - // Immediately forward error to the caller. - if (ABSL_PREDICT_FALSE(init_event.IsError())) { - return init_event.GetError(); - } + while (*condition) { + auto body_event = body_executor_.Execute(params); + auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { + return cond_executor_.Execute(params); + }); - // If we don't know if we should continue or not, switch to async execution - // mode using `init_event` as a dependency. - if (ABSL_PREDICT_FALSE(!init_event.IsAvailable())) { - return ExecuteAsync(params, std::move(init_event), condition); - } + // If loop iteration has not completed yet, continue execution + // asynchronously (if `condition` is still true when it becomes ready). + if (!cond_event.IsAvailable()) { + cond_event.AndThen( + [loop](absl::Status status) { (*loop)(std::move(status)); }); + return; + } - while (*condition) { - auto body_event = body_executor_.Execute(params); - auto cond_event = body_event.FlatMap([this, ¶ms](ExecuteEvent) { - return cond_executor_.Execute(params); - }); + // Immediately forward error to the caller. + if (ABSL_PREDICT_FALSE(cond_event.IsError())) { + event.SetError(cond_event.GetError()); + return; + } - // Immediately forward error to the caller. - if (ABSL_PREDICT_FALSE(cond_event.IsError())) { - return cond_event.GetError(); + // At this point `*condition` should have been updated and we may continue + // executing the while loop in the current thread. + DCHECK(cond_event.IsConcrete()); } - // If we don't know if we should continue or not, switch to async execution - // mode using `cond_event` as a dependency. - if (ABSL_PREDICT_FALSE(!cond_event.IsAvailable())) { - return ExecuteAsync(params, std::move(cond_event), condition); - } + // Successfully completed while loop iterations. + event.SetStateConcrete(); + }; - // At this point `*condition` should have been updated and we may continue - // executing the while loop in the current thread. - DCHECK(cond_event.IsAvailable()); - } + // Kick-off loop execution once dependency event is available. + dependency.AndThen([loop = loop_fn.get()](absl::Status status) { + (*loop)(std::move(status)); + }); - // Successfully completed while loop iterations. - return OkExecuteEvent(); + // Keep `loop_fn` alive until the end of the while loop execution. + event.AndThen([loop_fn = std::move(loop_fn)]() {}); + + return event; } WhileThunk::BufferUses WhileThunk::buffer_uses() const { diff --git a/third_party/xla/xla/service/cpu/runtime/while_thunk.h b/third_party/xla/xla/backends/cpu/runtime/while_thunk.h similarity index 61% rename from third_party/xla/xla/service/cpu/runtime/while_thunk.h rename to third_party/xla/xla/backends/cpu/runtime/while_thunk.h index 9c5a7af272468c..c1de07de86ad52 100644 --- a/third_party/xla/xla/service/cpu/runtime/while_thunk.h +++ b/third_party/xla/xla/backends/cpu/runtime/while_thunk.h @@ -13,15 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_RUNTIME_WHILE_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_WHILE_THUNK_H_ +#ifndef XLA_BACKENDS_CPU_RUNTIME_WHILE_THUNK_H_ +#define XLA_BACKENDS_CPU_RUNTIME_WHILE_THUNK_H_ +#include #include +#include #include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" namespace xla::cpu { @@ -37,7 +39,8 @@ class WhileThunk final : public Thunk { public: static absl::StatusOr> Create( Info info, BufferAllocation::Slice cond_buffer, - ThunkSequence cond_sequence, ThunkSequence body_sequence); + ThunkSequence cond_sequence, ThunkSequence body_sequence, + std::optional trip_count = std::nullopt); tsl::AsyncValueRef Execute(const ExecuteParams& params) final; @@ -46,21 +49,38 @@ class WhileThunk final : public Thunk { private: WhileThunk(Info info, BufferAllocation::Slice cond_buffer, - ThunkExecutor cond_executor, ThunkExecutor body_executor); + ThunkExecutor cond_executor, ThunkExecutor body_executor, + std::optional trip_count); + + tsl::AsyncValueRef ExecuteForLoop(const ExecuteParams& params, + int64_t trip_count); + + tsl::AsyncValueRef ExecuteWhileLoop(const ExecuteParams& params, + bool* condition); // If `cond` or `body` thunk sequence return unavailable async values, then // we execute the while loop asynchronously by chaining `Execute` calls via // `AndThen` callbacks. This execution mode adds significant overheads, so we // try to avoid it when possible and run everything in the caller thread. - tsl::AsyncValueRef ExecuteAsync( + + tsl::AsyncValueRef ExecuteAsyncForLoop( + const ExecuteParams& params, tsl::AsyncValueRef dependency, + int64_t loop_counter, int64_t trip_count); + + tsl::AsyncValueRef ExecuteAsyncWhileLoop( const ExecuteParams& params, tsl::AsyncValueRef dependency, bool* condition); BufferAllocation::Slice cond_buffer_; ThunkExecutor cond_executor_; ThunkExecutor body_executor_; + + // Statically known trip count. If available, WhileThunk::Execute will not + // execute `cond_executor_` and simply call `body_executor_` `trip_count` + // times (effectively converting while loop into a for loop). + std::optional trip_count_; }; } // namespace xla::cpu -#endif // XLA_SERVICE_CPU_RUNTIME_WHILE_THUNK_H_ +#endif // XLA_BACKENDS_CPU_RUNTIME_WHILE_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc b/third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc similarity index 77% rename from third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc rename to third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc index 5da7202f7d9b7f..d4b874a72b380f 100644 --- a/third_party/xla/xla/service/cpu/runtime/while_thunk_test.cc +++ b/third_party/xla/xla/backends/cpu/runtime/while_thunk_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/cpu/runtime/while_thunk.h" +#include "xla/backends/cpu/runtime/while_thunk.h" #include #include @@ -22,12 +22,12 @@ limitations under the License. #include #include +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_testlib.h" #include "xla/runtime/buffer_use.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_testlib.h" #include "xla/service/maybe_owning_device_memory.h" #include "xla/stream_executor/device_memory.h" #include "xla/tsl/concurrency/async_value_ref.h" @@ -203,5 +203,51 @@ TEST(WhileThunkTest, NonBlockingExecute) { EXPECT_EQ(counter[0], kNumIterations); } +TEST(WhileThunkTest, NonBlockingExecuteWithTripCount) { + static constexpr size_t kNumIterations = 100; + + BufferAllocation pred_alloc(0, sizeof(char), 0); + BufferAllocation cnt_alloc(1, sizeof(int32_t), 0); + + BufferAllocation::Slice pred_slice(&pred_alloc, 0, sizeof(char)); + BufferAllocation::Slice cnt_slice(&cnt_alloc, 0, sizeof(int32_t)); + + std::vector buffers; + std::vector predicate = {false}; + std::vector counter = {0}; + + buffers.emplace_back(se::DeviceMemoryBase(predicate.data(), sizeof(char))); + buffers.emplace_back(se::DeviceMemoryBase(counter.data(), sizeof(int32_t))); + + BufferAllocations allocations(buffers); + + // We pass empty cond sequence, because we know the trip count, and check that + // predicate value is ignored (it is initialized to false) and body executed + // `kNumIterations` times. + ThunkSequence cond_sequence; + + ThunkSequence body_sequence; + body_sequence.push_back(std::make_unique(cnt_slice)); + + TF_ASSERT_OK_AND_ASSIGN( + auto thunk, WhileThunk::Create( + {"while"}, pred_slice, std::move(cond_sequence), + std::move(body_sequence), /*trip_count=*/kNumIterations)); + + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "while-test", 8); + Eigen::ThreadPoolDevice device(thread_pool.AsEigenThreadPool(), + thread_pool.NumThreads()); + + Thunk::ExecuteParams params; + params.buffer_allocations = &allocations; + params.intra_op_threadpool = &device; + + auto execute_event = thunk->Execute(params); + tsl::BlockUntilReady(execute_event); + ASSERT_FALSE(execute_event.IsError()); + + EXPECT_EQ(counter[0], kNumIterations); +} + } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/backends/interpreter/executor.h b/third_party/xla/xla/backends/interpreter/executor.h index 8ca0cd9c357ef0..1228b3ba890055 100644 --- a/third_party/xla/xla/backends/interpreter/executor.h +++ b/third_party/xla/xla/backends/interpreter/executor.h @@ -86,15 +86,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { absl::Status Init() override { return absl::OkStatus(); } int device_ordinal() const override { return device_ordinal_; }; - absl::Status GetKernel(const MultiKernelLoaderSpec &spec, - Kernel *kernel) override { - return absl::UnimplementedError("Not Implemented"); - } - absl::Status Launch(Stream *stream, const ThreadDim &thread_dims, - const BlockDim &block_dims, const Kernel &kernel, - const KernelArgs &args) override { - return absl::UnimplementedError("Not Implemented"); - } DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; void Deallocate(DeviceMemoryBase *mem) override; @@ -107,11 +98,6 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - absl::Status Memset(Stream *stream, DeviceMemoryBase *location, - uint8_t pattern, uint64_t size) override { - return absl::InternalError("Interpreter can not memset"); - } - // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return true; } absl::Status SynchronousMemZero(DeviceMemoryBase *location, @@ -151,8 +137,7 @@ class XlaInterpreterExecutor : public StreamExecutorCommon { } absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) override { + std::optional> priority) override { return std::make_unique(this); } diff --git a/third_party/xla/xla/backends/interpreter/platform.cc b/third_party/xla/xla/backends/interpreter/platform.cc index 8b77eb1c801101..0b5756d4e3e175 100644 --- a/third_party/xla/xla/backends/interpreter/platform.cc +++ b/third_party/xla/xla/backends/interpreter/platform.cc @@ -47,31 +47,27 @@ XlaInterpreterPlatform::DescriptionForDevice(int ordinal) const { return XlaInterpreterExecutor::CreateDeviceDescription(ordinal); } -absl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( +absl::StatusOr XlaInterpreterPlatform::FindExisting( int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); + return executor_cache_.Get(ordinal); } -absl::StatusOr XlaInterpreterPlatform::GetExecutor( - const StreamExecutorConfig& config) { +absl::StatusOr XlaInterpreterPlatform::ExecutorForDevice( + int ordinal) { return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> -XlaInterpreterPlatform::GetUncachedExecutor( - const StreamExecutorConfig& config) { - auto executor = - std::make_unique(config.ordinal, this); +XlaInterpreterPlatform::GetUncachedExecutor(int ordinal) { + auto executor = std::make_unique(ordinal, this); auto init_status = executor->Init(); if (!init_status.ok()) { return absl::Status{ absl::StatusCode::kInternal, absl::StrFormat( "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString())}; + ordinal, init_status.ToString())}; } return std::move(executor); diff --git a/third_party/xla/xla/backends/interpreter/platform.h b/third_party/xla/xla/backends/interpreter/platform.h index da3d18e7f4b95f..50a69504ae0139 100644 --- a/third_party/xla/xla/backends/interpreter/platform.h +++ b/third_party/xla/xla/backends/interpreter/platform.h @@ -47,11 +47,13 @@ class XlaInterpreterPlatform : public Platform { absl::StatusOr ExecutorForDevice(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; + absl::StatusOr FindExisting(int ordinal) override; + // Returns a device constructed with ordinal without + // looking in or storing to the Platform's executor cache. + // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config) override; + int ordinal); private: // This platform's name. diff --git a/third_party/xla/xla/backends/profiler/cpu/BUILD b/third_party/xla/xla/backends/profiler/cpu/BUILD index 28b448ba142bf2..d5ef505cfe7b3b 100644 --- a/third_party/xla/xla/backends/profiler/cpu/BUILD +++ b/third_party/xla/xla/backends/profiler/cpu/BUILD @@ -128,9 +128,9 @@ xla_cc_test( srcs = ["host_tracer_test.cc"], deps = [ ":host_tracer_impl", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:optional", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc index 0db9f800b958ee..2fca882f9910d8 100644 --- a/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc +++ b/third_party/xla/xla/backends/profiler/cpu/host_tracer_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include "absl/types/optional.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/backends/profiler/gpu/BUILD b/third_party/xla/xla/backends/profiler/gpu/BUILD index 98ac9b38010be7..7a822f518e2921 100644 --- a/third_party/xla/xla/backends/profiler/gpu/BUILD +++ b/third_party/xla/xla/backends/profiler/gpu/BUILD @@ -311,7 +311,10 @@ tsl_gpu_library( "@local_tsl//tsl/profiler/utils:xplane_builder", "@local_tsl//tsl/profiler/utils:xplane_schema", "@local_tsl//tsl/profiler/utils:xplane_utils", - ] + if_cuda(["//xla/tsl/cuda:cupti"]), + ] + if_cuda([ + "//xla/tsl/cuda:cupti", + "//xla/tsl/cuda", + ]), ) tsl_gpu_library( diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc index ccda1b07902355..376b1809ad4b1a 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.cc @@ -186,18 +186,18 @@ void AddGraphTraceActivityEvent(CuptiEventCollectorDelegate &collector, AnnotationMap::AnnotationInfo info = collector.annotation_map.LookUp( graph_trace->deviceId, graph_trace->correlationId); collector.receive(CuptiTracerEvent{ - .type = CuptiTracerEventType::CudaGraph, - .source = CuptiTracerEventSource::Activity, - .name = absl::StrCat("CudaGraphExec:", graph_trace->graphId), - .annotation = info.annotation, - .nvtx_range = info.nvtx_range, - .start_time_ns = graph_trace->start, - .end_time_ns = graph_trace->end, - .device_id = graph_trace->deviceId, - .correlation_id = graph_trace->correlationId, - .context_id = graph_trace->contextId, - .stream_id = graph_trace->streamId, - .graph_id = graph_trace->graphId, + /* .type = */ CuptiTracerEventType::CudaGraph, + /* .source = */ CuptiTracerEventSource::Activity, + /* .name = */ absl::StrCat("CudaGraphExec:", graph_trace->graphId), + /* .annotation = */ info.annotation, + /* .nvtx_range = */ info.nvtx_range, + /* .start_time_ns = */ graph_trace->start, + /* .end_time_ns = */ graph_trace->end, + /* .device_id = */ graph_trace->deviceId, + /* .correlation_id = */ graph_trace->correlationId, + /* .context_id = */ graph_trace->contextId, + /* .stream_id = */ graph_trace->streamId, + /* .graph_id = */ graph_trace->graphId, }); } diff --git a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h index ac708ed94faeda..f58dda54e623c1 100644 --- a/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h +++ b/third_party/xla/xla/backends/profiler/gpu/cupti_buffer_events.h @@ -56,7 +56,7 @@ struct MemcpyDetails { int8_t dst_mem_kind; // ID of the hardware channel on which this operation ran. - uint32_t channel_id = -1; + uint32_t channel_id = static_cast(-1); // CUpti_ChannelType of the channel above. int8_t channel_type = 0; // CUPTI_CHANNEL_TYPE_INVALID }; diff --git a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc index e53b017195d717..65f8d5793e4064 100644 --- a/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc +++ b/third_party/xla/xla/backends/profiler/gpu/rocm_collector.cc @@ -618,7 +618,7 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { const std::vector ApiActivityInfoExchange() TF_EXCLUSIVE_LOCKS_REQUIRED(event_maps_mutex_); - absl::flat_hash_map per_device_collector_; + absl::node_hash_map per_device_collector_; }; //========== diff --git a/third_party/xla/xla/client/BUILD b/third_party/xla/xla/client/BUILD index 1e7f1bf6615628..d322cafd25f403 100644 --- a/third_party/xla/xla/client/BUILD +++ b/third_party/xla/xla/client/BUILD @@ -53,6 +53,7 @@ cc_library( "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/math:math_util", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", ], ) @@ -71,24 +72,23 @@ cc_library( srcs = ["client.cc"], hdrs = ["client.h"], deps = [ - ":global_data", ":xla_computation", - "//xla:debug_options_flags", "//xla:execution_options_util", "//xla:literal", + "//xla:shape_util", "//xla:status_macros", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/service", "//xla/service:hlo_proto_cc", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", ], ) @@ -111,9 +111,9 @@ cc_library( "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) @@ -126,14 +126,18 @@ cc_library( ":client", ":executable_build_options", ":xla_computation", + "//xla:debug_options_flags", "//xla:executable_run_options", + "//xla:literal", "//xla:shape_tree", + "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:backend", "//xla/service:compiler", + "//xla/service:computation_layout", "//xla/service:dump", "//xla/service:executable", - "//xla/service:hlo_proto_cc", "//xla/service:local_service", "//xla/service:maybe_owning_device_memory", "//xla/service:shaped_buffer", @@ -141,9 +145,14 @@ cc_library( "//xla/service:stream_pool", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -154,12 +163,18 @@ cc_library( deps = [ ":client", ":xla_computation", + "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/service:compile_only_service", "//xla/service:compiler", + "//xla/service:hlo_module_config", "//xla/stream_executor", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", ], @@ -174,18 +189,20 @@ cc_library( deps = [ ":compile_only_client", ":local_client", - "//xla:status_macros", "//xla:types", - "//xla:util", - "//xla/service:backend", + "//xla/service", "//xla/service:compile_only_service", "//xla/service:local_service", "//xla/service:platform_util", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", ], ) @@ -200,6 +217,7 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "@com_google_absl//absl/log:check", ], ) @@ -214,6 +232,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/status:statusor", ], ) @@ -233,8 +252,13 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/evaluator:hlo_evaluator", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", @@ -305,6 +329,7 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", + "//xla/service:hlo_proto_cc", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", "//xla/tests:xla_internal_test_main", diff --git a/third_party/xla/xla/client/client.cc b/third_party/xla/xla/client/client.cc index 6e89947f237ff7..f5e174df13d98a 100644 --- a/third_party/xla/xla/client/client.cc +++ b/third_party/xla/xla/client/client.cc @@ -22,16 +22,22 @@ limitations under the License. #include #include "absl/status/status.h" -#include "absl/strings/str_cat.h" +#include "absl/types/span.h" #include "xla/client/xla_computation.h" -#include "xla/debug_options_flags.h" #include "xla/execution_options_util.h" +#include "xla/layout.h" #include "xla/literal.h" +#include "xla/service/hlo.pb.h" +#include "xla/service/service.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/types.h" -#include "tsl/platform/errors.h" +#include "xla/util.h" +#include "xla/xla.pb.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" #include "tsl/platform/protobuf.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/client/client.h b/third_party/xla/xla/client/client.h index f3eacbcad3ec06..120156874869a3 100644 --- a/third_party/xla/xla/client/client.h +++ b/third_party/xla/xla/client/client.h @@ -21,12 +21,15 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/xla_computation.h" +#include "xla/layout.h" #include "xla/literal.h" #include "xla/service/hlo.pb.h" #include "xla/service/service.h" +#include "xla/shape.h" #include "xla/types.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/client/client_library.cc b/third_party/xla/xla/client/client_library.cc index b55691be31f7f8..476208d78b0bfb 100644 --- a/third_party/xla/xla/client/client_library.cc +++ b/third_party/xla/xla/client/client_library.cc @@ -20,11 +20,17 @@ limitations under the License. #include #include -#include "xla/service/backend.h" +#include "absl/synchronization/mutex.h" +#include "xla/client/compile_only_client.h" +#include "xla/client/local_client.h" +#include "xla/service/compile_only_service.h" +#include "xla/service/local_service.h" #include "xla/service/platform_util.h" -#include "xla/status_macros.h" -#include "xla/util.h" +#include "xla/service/service.h" +#include "xla/stream_executor/platform.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/client/client_library.h b/third_party/xla/xla/client/client_library.h index db867329be71ad..0e4f3a9a24dd22 100644 --- a/third_party/xla/xla/client/client_library.h +++ b/third_party/xla/xla/client/client_library.h @@ -28,13 +28,16 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "absl/synchronization/mutex.h" #include "xla/client/compile_only_client.h" #include "xla/client/local_client.h" #include "xla/service/compile_only_service.h" #include "xla/service/local_service.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "xla/types.h" diff --git a/third_party/xla/xla/client/compile_only_client.cc b/third_party/xla/xla/client/compile_only_client.cc index 23c07b3742ba0b..1aa6a4f1a8c54c 100644 --- a/third_party/xla/xla/client/compile_only_client.cc +++ b/third_party/xla/xla/client/compile_only_client.cc @@ -18,9 +18,19 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/TargetParser/Triple.h" +#include "xla/service/compile_only_service.h" +#include "xla/service/compiler.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" #include "xla/status_macros.h" +#include "xla/xla.pb.h" namespace xla { diff --git a/third_party/xla/xla/client/compile_only_client.h b/third_party/xla/xla/client/compile_only_client.h index 8dde8c884dd78c..2dcb9775725027 100644 --- a/third_party/xla/xla/client/compile_only_client.h +++ b/third_party/xla/xla/client/compile_only_client.h @@ -20,11 +20,16 @@ limitations under the License. #include #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/client/client.h" #include "xla/client/xla_computation.h" #include "xla/service/compile_only_service.h" #include "xla/service/compiler.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" namespace xla { diff --git a/third_party/xla/xla/client/executable_build_options.cc b/third_party/xla/xla/client/executable_build_options.cc index 46b810d5537a3a..e4843194012cee 100644 --- a/third_party/xla/xla/client/executable_build_options.cc +++ b/third_party/xla/xla/client/executable_build_options.cc @@ -27,14 +27,13 @@ limitations under the License. #include "xla/debug_options_flags.h" #include "xla/execution_options_util.h" #include "xla/layout_util.h" +#include "xla/pjrt/compile_options.pb.h" #include "xla/service/compilation_environments.h" #include "xla/service/computation_placer.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/client/executable_build_options.h b/third_party/xla/xla/client/executable_build_options.h index c849230ecad082..f1129d6ac5c1fe 100644 --- a/third_party/xla/xla/client/executable_build_options.h +++ b/third_party/xla/xla/client/executable_build_options.h @@ -24,6 +24,10 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/pjrt/compile_options.pb.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/compilation_environments.h" diff --git a/third_party/xla/xla/client/lib/BUILD b/third_party/xla/xla/client/lib/BUILD index 97791eaff387d0..1648ff37237c8f 100644 --- a/third_party/xla/xla/client/lib/BUILD +++ b/third_party/xla/xla/client/lib/BUILD @@ -232,6 +232,7 @@ cc_library( xla_test( name = "math_test", + timeout = "long", srcs = ["math_test.cc"], backend_tags = { # Times out. @@ -253,8 +254,8 @@ xla_test( "//xla/tests:client_library_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/xla/client/lib/math_test.cc b/third_party/xla/xla/client/lib/math_test.cc index 559302f6bb5977..0c5776f4bea333 100644 --- a/third_party/xla/xla/client/lib/math_test.cc +++ b/third_party/xla/xla/client/lib/math_test.cc @@ -38,9 +38,9 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/client/local_client.cc b/third_party/xla/xla/client/local_client.cc index c388dc478d7fde..05056ba76664f9 100644 --- a/third_party/xla/xla/client/local_client.cc +++ b/third_party/xla/xla/client/local_client.cc @@ -20,13 +20,37 @@ limitations under the License. #include #include +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xla/client/executable_build_options.h" #include "xla/client/xla_computation.h" +#include "xla/debug_options_flags.h" +#include "xla/executable_run_options.h" +#include "xla/literal.h" #include "xla/service/backend.h" +#include "xla/service/compiler.h" +#include "xla/service/computation_layout.h" #include "xla/service/dump.h" +#include "xla/service/executable.h" +#include "xla/service/maybe_owning_device_memory.h" #include "xla/service/service_executable_run_options.h" +#include "xla/service/shaped_buffer.h" #include "xla/service/source_map_util.h" #include "xla/service/stream_pool.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/shape_util.h" +#include "xla/stream_executor/device_memory_allocator.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" using xla::source_map_util::InvalidParameterArgument; diff --git a/third_party/xla/xla/client/local_client.h b/third_party/xla/xla/client/local_client.h index 236ebe0bfb2f3c..07c6e6e8b11978 100644 --- a/third_party/xla/xla/client/local_client.h +++ b/third_party/xla/xla/client/local_client.h @@ -21,21 +21,28 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/client/client.h" #include "xla/client/executable_build_options.h" #include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/literal.h" +#include "xla/service/backend.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" #include "xla/service/local_service.h" #include "xla/service/maybe_owning_device_memory.h" +#include "xla/service/service_executable_run_options.h" #include "xla/service/shaped_buffer.h" +#include "xla/service/stream_pool.h" #include "xla/shape_tree.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/stream_executor.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/client/padding.cc b/third_party/xla/xla/client/padding.cc index 37abc598216ded..daf26d5467ac7b 100644 --- a/third_party/xla/xla/client/padding.cc +++ b/third_party/xla/xla/client/padding.cc @@ -20,9 +20,11 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/types/span.h" #include "xla/util.h" #include "tsl/lib/math/math_util.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" namespace xla { diff --git a/third_party/xla/xla/client/padding.h b/third_party/xla/xla/client/padding.h index e71522616bf1ab..e717183ce2d6c8 100644 --- a/third_party/xla/xla/client/padding.h +++ b/third_party/xla/xla/client/padding.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/types.h" diff --git a/third_party/xla/xla/client/sharding_builder.cc b/third_party/xla/xla/client/sharding_builder.cc index e2324d68f92db7..7b179b8c91ee4a 100644 --- a/third_party/xla/xla/client/sharding_builder.cc +++ b/third_party/xla/xla/client/sharding_builder.cc @@ -17,6 +17,12 @@ limitations under the License. #include +#include "absl/log/check.h" +#include "xla/shape.h" +#include "xla/shape_tree.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + namespace xla { namespace sharding_builder { diff --git a/third_party/xla/xla/client/sharding_builder.h b/third_party/xla/xla/client/sharding_builder.h index 98d6512d59c28d..eef395e0b46368 100644 --- a/third_party/xla/xla/client/sharding_builder.h +++ b/third_party/xla/xla/client/sharding_builder.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "xla/array.h" +#include "xla/shape.h" #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/types.h" diff --git a/third_party/xla/xla/client/value_inference.cc b/third_party/xla/xla/client/value_inference.cc index 1ba694ad6154c9..2f0b6e20756bff 100644 --- a/third_party/xla/xla/client/value_inference.cc +++ b/third_party/xla/xla/client/value_inference.cc @@ -21,11 +21,17 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/hash/hash.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/client/xla_builder.h" #include "xla/comparison_util.h" +#include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_opcode.h" @@ -33,6 +39,8 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/hlo.pb.h" +#include "xla/service/hlo_module_config.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/util.h" diff --git a/third_party/xla/xla/client/value_inference.h b/third_party/xla/xla/client/value_inference.h index 6f1685f1a42e0a..84c1c99f53fd4d 100644 --- a/third_party/xla/xla/client/value_inference.h +++ b/third_party/xla/xla/client/value_inference.h @@ -19,6 +19,9 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xla/client/xla_builder.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/dfs_hlo_visitor.h" @@ -26,6 +29,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/client/xla_builder.cc b/third_party/xla/xla/client/xla_builder.cc index c869c43e160518..98e7dada978400 100644 --- a/third_party/xla/xla/client/xla_builder.cc +++ b/third_party/xla/xla/client/xla_builder.cc @@ -1450,6 +1450,42 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation, }); } +XlaOp XlaBuilder::CompositeCall(const XlaComputation& computation, + absl::Span operands, + const std::string& name, + std::optional attributes, + std::optional version) { + return ReportErrorOrReturn([&]() -> absl::StatusOr { + HloInstructionProto instr; + std::vector operand_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); + absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape, + computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape( + operand_shape_ptrs, + /*to_apply=*/called_program_shape)); + *instr.mutable_shape() = shape.ToProto(); + + AddCalledComputation(computation, &instr); + instr.set_is_composite(true); + + TF_ASSIGN_OR_RETURN( + XlaOp instruction, + AddInstruction(std::move(instr), HloOpcode::kCall, operands)); + TF_RETURN_IF_ERROR( + SetInstructionFrontendAttribute(instruction, "composite.name", name)); + TF_RETURN_IF_ERROR(SetInstructionFrontendAttribute( + instruction, "composite.attributes", + attributes.has_value() ? std::string(*attributes) : "{}")); + TF_RETURN_IF_ERROR(SetInstructionFrontendAttribute( + instruction, "composite.version", + version.has_value() ? std::to_string(*version) : "0")); + return instruction; + }); +} + XlaOp XlaBuilder::Parameter( int64_t parameter_number, const Shape& shape, const std::string& name, const std::vector& replicated_at_leaf_buffers) { @@ -3854,6 +3890,7 @@ XlaOp XlaBuilder::AllToAllArray( if (is_unbounded) { std::vector new_dimensions; + new_dimensions.reserve(operand_shape->rank()); for (int64_t i = 0; i < operand_shape->rank(); ++i) { new_dimensions.push_back(GetR1DimensionSizeOrConstant(operand, i)); } @@ -5195,6 +5232,14 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, return builder->Call(computation, operands); } +XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation, + absl::Span operands, const std::string& name, + std::optional attributes, + std::optional version) { + return builder->CompositeCall(computation, operands, name, attributes, + version); +} + XlaOp CustomCall( XlaBuilder* builder, const std::string& call_target_name, absl::Span operands, const Shape& shape, diff --git a/third_party/xla/xla/client/xla_builder.h b/third_party/xla/xla/client/xla_builder.h index c1192bf716a0d3..53683cfe3957cd 100644 --- a/third_party/xla/xla/client/xla_builder.h +++ b/third_party/xla/xla/client/xla_builder.h @@ -731,6 +731,12 @@ class XlaBuilder { XlaOp Call(const XlaComputation& computation, absl::Span operands); + XlaOp CompositeCall( + const XlaComputation& computation, absl::Span operands, + const std::string& name, + std::optional attributes = std::nullopt, + std::optional version = std::nullopt); + XlaOp CustomCall( const std::string& call_target_name, absl::Span operands, const Shape& shape_with_layout, const std::string& opaque, @@ -1378,6 +1384,14 @@ class XlaBuilder { const std::string& outfeed_config); friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span operands); + + friend XlaOp CompositeCall(XlaBuilder* builder, + const XlaComputation& computation, + absl::Span operands, + const std::string& name, + std::optional attributes, + std::optional version); + friend XlaOp CustomCall( XlaBuilder* builder, const std::string& call_target_name, absl::Span operands, const Shape& shape, @@ -2305,6 +2319,12 @@ XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span operands); +// Enqueues a composite call instruction onto the computation. +XlaOp CompositeCall(XlaBuilder* builder, const XlaComputation& computation, + absl::Span operands, const std::string& name, + std::optional attributes = std::nullopt, + std::optional version = std::nullopt); + // Enqueues a custom call instruction onto the computation. A custom call // invokes code external to XLA. The |operands| are passed to the external code, // and the external code is expected to produce a result of the given diff --git a/third_party/xla/xla/client/xla_builder_test.cc b/third_party/xla/xla/client/xla_builder_test.cc index 9828d500a7060d..8ecf2434fc1d3f 100644 --- a/third_party/xla/xla/client/xla_builder_test.cc +++ b/third_party/xla/xla/client/xla_builder_test.cc @@ -48,6 +48,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/layout_util.h" +#include "xla/service/hlo.pb.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" @@ -330,6 +331,176 @@ TEST(XlaBuilderTest, Call) { m::Call(m::Constant(), m::Constant())))); } +TEST(XlaBuilderTest, CompositeCall) { + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1")); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build()); + + std::vector operands = {Parameter(&b, 0, shape, "arg0"), + Parameter(&b, 1, shape, "arg1")}; + CompositeCall(&b, computation, absl::MakeSpan(operands), + /*name=*/"foo.bar", + /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor}", + /*version=*/1); + + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Call(m::Parameter(), m::Parameter()))); +} + +TEST(XlaBuilderTest, CompositeCallFrontendAttributesStayLocal) { + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1")); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build()); + + std::vector operands = {Parameter(&b, 0, shape, "arg0"), + Parameter(&b, 1, shape, "arg1")}; + CompositeCall(&b, computation, absl::MakeSpan(operands), + /*name=*/"foo.bar", + /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor}", + /*version=*/1); + Add(operands[0], operands[1]); + + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_TRUE(GetRoot(*module)->frontend_attributes().map().empty()); +} + +TEST(XlaBuilderTest, CompositeCallMissingName) { + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1")); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build()); + + std::vector operands = {Parameter(&b, 0, shape, "arg0"), + Parameter(&b, 1, shape, "arg1")}; + CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"", + /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor}", + /*version=*/1); + + auto statusor = BuildHloModule(b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().message(), + HasSubstr("A composite call op must have frontend attributes " + "with key composite.name whose value is non-empty")); +} + +TEST(XlaBuilderTest, CompositeCallMissingAttribute) { + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1")); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build()); + + std::vector operands = {Parameter(&b, 0, shape, "arg0"), + Parameter(&b, 1, shape, "arg1")}; + CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"foo.bar", + /*attributes=*/"", /*version=*/1); + + auto statusor = BuildHloModule(b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().message(), + HasSubstr( + "A composite call op must have frontend attributes with key " + "composite.attributes whose value is default: {} or non-empty")); +} + +TEST(XlaBuilderTest, CompositeCallNonNegativeVersion) { + XlaBuilder b(TestName()); + + FrontendAttributes frontend_attributes = b.frontend_attributes(); + frontend_attributes.mutable_map()->insert({"foo", "bar"}); + b.SetFrontendAttributes(frontend_attributes); + + const Shape shape = ShapeUtil::MakeShape(F32, {}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1")); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build()); + + std::vector operands = {Parameter(&b, 0, shape, "arg0"), + Parameter(&b, 1, shape, "arg1")}; + CompositeCall(&b, computation, absl::MakeSpan(operands), + /*name=*/"foo.bar", + /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor}", + /*version=*/-1); + + auto statusor = BuildHloModule(b); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().message(), + HasSubstr("A composite call op must have frontend attributes " + "with a composite.version whose value is a " + "non-negative integer but got: -1")); +} + +TEST(XlaBuilderTest, CompositeCallOptionalVersionAndAttribute) { + XlaBuilder b(TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1")); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build()); + + std::vector operands = {Parameter(&b, 0, shape, "arg0"), + Parameter(&b, 1, shape, "arg1")}; + CompositeCall(&b, computation, absl::MakeSpan(operands), /*name=*/"foo.bar"); + + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + ASSERT_THAT(GetRoot(*module), + GmockMatch(m::Call(m::Parameter(), m::Parameter()))); + ASSERT_TRUE(GetRoot(*module)->frontend_attributes().map().contains( + "composite.attributes")); + EXPECT_EQ( + GetRoot(*module)->frontend_attributes().map().at("composite.attributes"), + "{}"); + EXPECT_EQ( + GetRoot(*module)->frontend_attributes().map().at("composite.version"), + "0"); +} + +TEST(XlaBuilderTest, CompositeCallWithExtraFrontendAttributes) { + XlaBuilder b(TestName()); + + FrontendAttributes frontend_attributes = b.frontend_attributes(); + frontend_attributes.mutable_map()->insert({"foo", "bar"}); + b.SetFrontendAttributes(frontend_attributes); + + const Shape shape = ShapeUtil::MakeShape(F32, {}); + const Shape expected = ShapeUtil::MakeTupleShape({shape, shape, shape}); + + XlaBuilder bsum(TestName()); + Add(Parameter(&bsum, 0, shape, "arg0"), Parameter(&bsum, 1, shape, "arg1")); + TF_ASSERT_OK_AND_ASSIGN(const XlaComputation computation, bsum.Build()); + + std::vector operands = {Parameter(&b, 0, shape, "arg0"), + Parameter(&b, 1, shape, "arg1")}; + CompositeCall(&b, computation, absl::MakeSpan(operands), + /*name=*/"foo.bar", + /*attributes=*/"{n = 1 : i32, tensor = dense<1> : tensor}", + /*version=*/1); + + TF_ASSERT_OK_AND_ASSIGN(const auto module, BuildHloModule(b)); + EXPECT_THAT(GetRoot(*module), + GmockMatch(m::Call(m::Parameter(), m::Parameter()))); + ASSERT_TRUE(GetRoot(*module)->frontend_attributes().map().contains("foo")); + EXPECT_EQ(GetRoot(*module)->frontend_attributes().map().at("foo"), "bar"); +} + TEST(XlaBuilderTest, BinopHasDegenerateBroadcast) { XlaBuilder b(TestName()); auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x"); diff --git a/third_party/xla/xla/client/xla_computation.cc b/third_party/xla/xla/client/xla_computation.cc index c92de63495d190..fc558462d1a576 100644 --- a/third_party/xla/xla/client/xla_computation.cc +++ b/third_party/xla/xla/client/xla_computation.cc @@ -18,6 +18,9 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" +#include "xla/service/hlo.pb.h" +#include "xla/shape.h" #include "xla/status_macros.h" #include "xla/util.h" diff --git a/third_party/xla/xla/client/xla_computation.h b/third_party/xla/xla/client/xla_computation.h index e21a92d6300654..52a54aa113b178 100644 --- a/third_party/xla/xla/client/xla_computation.h +++ b/third_party/xla/xla/client/xla_computation.h @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/status/statusor.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/status_macros.h" diff --git a/third_party/xla/xla/core/host_offloading/README.md b/third_party/xla/xla/core/host_offloading/README.md new file mode 100644 index 00000000000000..22f6449bce3b09 --- /dev/null +++ b/third_party/xla/xla/core/host_offloading/README.md @@ -0,0 +1,8 @@ +# XLA Host Offloading + +XLA host offloading allows us to run part of the HLO module on the host attached +to the accelerator device (TPU or GPU) using the XLA:CPU compiler. On JAX side +it is available as `jax.experimental.compute_on` API. + +With `compute_on` annotation, JAX + XLA can be used to implement +[ZeRO-Offload](https://arxiv.org/abs/2101.06840) host offloading. \ No newline at end of file diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index c35ea757728c8c..8c7603ab7a7bfa 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -38,6 +38,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/debug_options_parsers.h" #include "xla/parse_flags_from_env.h" +#include "xla/stream_executor/cuda/nvjitlink_support.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" #include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" @@ -82,7 +83,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { #ifdef XLA_CPU_USE_ACL opts.set_xla_cpu_use_acl(true); #endif - opts.set_xla_cpu_use_thunk_runtime(false); + opts.set_xla_cpu_use_thunk_runtime(true); opts.set_xla_cpu_enable_concurrency_optimized_scheduler(false); opts.set_xla_cpu_prefer_vector_width(256); @@ -144,7 +145,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_enable_dumping(true); opts.set_xla_gpu_enable_custom_fusions(false); - opts.set_xla_gpu_enable_address_computation_fusion(true); + opts.set_xla_gpu_enable_dynamic_slice_fusion(true); opts.set_xla_gpu_nccl_termination_timeout_seconds(-1); opts.set_xla_gpu_enable_shared_constants(true); opts.set_xla_gpu_enable_nccl_user_buffers(false); @@ -166,9 +167,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_highest_priority_async_stream(true); opts.set_xla_gpu_enable_pipelined_collectives(false); - opts.set_xla_gpu_enable_pipelined_all_reduce(false); - opts.set_xla_gpu_enable_pipelined_all_gather(false); - opts.set_xla_gpu_enable_pipelined_reduce_scatter(false); + opts.set_xla_gpu_enable_pipelined_all_reduce(true); + opts.set_xla_gpu_enable_pipelined_all_gather(true); + opts.set_xla_gpu_enable_pipelined_reduce_scatter(true); opts.set_xla_gpu_enable_pipelined_p2p(false); opts.set_xla_gpu_run_post_layout_collective_pipeliner(false); @@ -231,12 +232,11 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_enable_triton_hopper(false); - // We disable this until b/319271534 is fixed due to errors during linking. - // - // TODO(b/319271534): Re-enable once we use libnvjitlink. opts.set_xla_gpu_enable_llvm_module_compilation_parallelism(false); - - opts.set_xla_gpu_enable_libnvptxcompiler(false); + opts.set_xla_gpu_enable_libnvptxcompiler( + stream_executor::IsLibNvPtxCompilerSupported()); + opts.set_xla_gpu_enable_libnvjitlink( + stream_executor::IsLibNvJitLinkSupported()); opts.set_xla_gpu_enable_dot_strength_reduction(true); @@ -246,7 +246,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_p2p_max_nchannels(0); #if GOOGLE_CUDA - opts.set_xla_gpu_mlir_emitter_level(3); + opts.set_xla_gpu_mlir_emitter_level(4); #else opts.set_xla_gpu_mlir_emitter_level(0); #endif @@ -271,9 +271,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_terminate_on_error(false); - opts.set_xla_use_shardy(false); - - opts.set_xla_gpu_shard_autotuning(false); + opts.set_xla_gpu_shard_autotuning(true); opts.set_xla_syntax_sugar_async_ops(false); @@ -283,6 +281,9 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_enable_command_buffers_during_profiling(false); + opts.set_xla_gpu_cudnn_gemm_max_plans(5); + + opts.set_xla_gpu_enable_triton_gemm_int4(false); return opts; } @@ -1109,7 +1110,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, collective_op_types_to_string( debug_options->xla_gpu_disable_async_collectives()), "This disables a certain set of async collectives and turn them into" - " synchornous ones. By default, this is empty which indicates enabling" + " synchronous ones. By default, this is empty which indicates enabling" " async execution for all collectives. A sample usage is: " " --xla_gpu_disable_async_collectives=ALLREDUCE,REDUCESCATTER")); flag_list->push_back(tsl::Flag( @@ -1206,7 +1207,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_enable_cudnn_fmha), debug_options->xla_gpu_enable_cudnn_fmha(), "Use the cuDNN Fused Attention runtime fusion when possible. Note " - "that dropout support and the developement of this feature as a whole is " + "that dropout support and the development of this feature as a whole is " "in progress. Attention with dropout may cause results to diverge with " "and without this flag turned on.")); flag_list->push_back(tsl::Flag( @@ -1242,7 +1243,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, setter_for_legacy_command_buffer_custom_call_targets, "", "Comma-separated list of custom call targets with legacy " "registry API (non FFI API), whose targets supports lowering " - "to command buffer custom command, i.e, custom call target " + "to command buffer custom command, i.e., custom call target " "supports cuda-graph capturing for CUDA devices.")); flag_list->push_back(tsl::Flag( @@ -1298,10 +1299,9 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "expression. Default is all custom fusions registerered in a current " "process.")); flag_list->push_back(tsl::Flag( - "xla_gpu_enable_address_computation_fusion", - bool_setter_for( - &DebugOptions::set_xla_gpu_enable_address_computation_fusion), - debug_options->xla_gpu_enable_address_computation_fusion(), + "xla_gpu_enable_dynamic_slice_fusion", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_dynamic_slice_fusion), + debug_options->xla_gpu_enable_dynamic_slice_fusion(), "Whether to enable XLA address computation fusion")); flag_list->push_back(tsl::Flag( "xla_gpu_nccl_termination_timeout_seconds", @@ -1447,8 +1447,11 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "xla_gpu_enable_pipelined_collectives", bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_collectives), debug_options->xla_gpu_enable_pipelined_collectives(), - "Enable pipelinling of collective instructions (all-reduce, all-gather, " - "and reduce-scatter).")); + "Enable pipelinling of collective instructions. It has the same effect " + "as setting xla_gpu_enable_pipelined_all_reduce, " + "xla_gpu_enable_pipelined_all_gather, " + "xla_gpu_enable_pipelined_reduce_scatter and " + "xla_gpu_enable_pipelined_p2p flags to true.")); flag_list->push_back(tsl::Flag( "xla_gpu_enable_pipelined_all_reduce", bool_setter_for(&DebugOptions::set_xla_gpu_enable_pipelined_all_reduce), @@ -1768,7 +1771,7 @@ void MakeDebugOptionsFlags(std::vector* flag_list, int64_setter_for(&DebugOptions::set_xla_gpu_gemm_rewrite_size_threshold), debug_options->xla_gpu_gemm_rewrite_size_threshold(), "Threshold until which elemental dot emitter is preferred for GEMMs " - "(minumum combined number of elements of both matrices " + "(minimum combined number of elements of both matrices " "in non-batch dimensions to be considered for a rewrite).")); flag_list->push_back(tsl::Flag( "xla_gpu_use_memcpy_local_p2p", @@ -1797,9 +1800,6 @@ void MakeDebugOptionsFlags(std::vector* flag_list, bool_setter_for(&DebugOptions::set_xla_gpu_nccl_terminate_on_error), debug_options->xla_gpu_nccl_terminate_on_error(), "If set, then NCCL errors will terminate the process.")); - flag_list->push_back(tsl::Flag( - "xla_use_shardy", bool_setter_for(&DebugOptions::set_xla_use_shardy), - debug_options->xla_use_shardy(), "Whether to use Shardy.")); flag_list->push_back(tsl::Flag( "xla_gpu_shard_autotuning", bool_setter_for(&DebugOptions::set_xla_gpu_shard_autotuning), @@ -1840,6 +1840,23 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Experimental: Enable command buffers while a profiling active. " "By default, enabling profiling switches from command buffers to " "op-by-op mode.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_cudnn_gemm_max_plans", + int32_setter_for(&DebugOptions::set_xla_gpu_cudnn_gemm_max_plans), + debug_options->xla_gpu_cudnn_gemm_max_plans(), + "Limit for the number of kernel configurations (plans) to use during " + "autotuning of cuDNN GEMM fusions.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_triton_gemm_int4", + bool_setter_for(&DebugOptions::set_xla_gpu_enable_triton_gemm_int4), + debug_options->xla_gpu_enable_triton_gemm_int4(), + "Experimental: Enable Triton gemm for int4 inputs.")); + flag_list->push_back( + tsl::Flag("xla_gpu_async_dot", + bool_setter_for(&DebugOptions::set_xla_gpu_async_dot), + debug_options->xla_gpu_async_dot(), + "Wrap `dot` operations into async computations in an effort to " + "parallelize matrix operations.")); } // NOLINT(readability/fn_size) // Allocates flag_values and flag_objects; this function must not be called more diff --git a/third_party/xla/xla/examples/axpy/BUILD b/third_party/xla/xla/examples/axpy/BUILD index db66693c359d8d..9d3424498f3413 100644 --- a/third_party/xla/xla/examples/axpy/BUILD +++ b/third_party/xla/xla/examples/axpy/BUILD @@ -22,11 +22,11 @@ xla_cc_test( "//xla/service/cpu:cpu_compiler", "//xla/stream_executor:platform", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/examples/axpy/README.md b/third_party/xla/xla/examples/axpy/README.md index 397dd21c8fb6d8..39bacfb18c5659 100644 --- a/third_party/xla/xla/examples/axpy/README.md +++ b/third_party/xla/xla/examples/axpy/README.md @@ -72,10 +72,8 @@ LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); // PlatformUtil::GetPlatform("CUDA")); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, PlatformUtil::GetPlatform("cpu")); -se::StreamExecutorConfig config; -config.ordinal = 0; TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, - platform->GetExecutor(config)); + platform->ExecutorForDevice(0)); // LocalDeviceState and PjRtStreamExecutorDevice describes the state of a // device which can do computation or transfer buffers. Could represent a GPU diff --git a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc index 0bf61a17caf280..49a99ee88a679c 100644 --- a/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc +++ b/third_party/xla/xla/examples/axpy/stablehlo_compile_test.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/service/stream_pool.h" #include "xla/stream_executor/platform.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" @@ -62,10 +62,8 @@ TEST(StableHloAxpyTest, LoadAndRunCpuExecutable) { // PlatformUtil::GetPlatform("CUDA")); TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, PlatformUtil::GetPlatform("cpu")); - se::StreamExecutorConfig config; - config.ordinal = 0; TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, - platform->GetExecutor(config)); + platform->ExecutorForDevice(/*ordinal=*/0)); // LocalDeviceState and PjRtStreamExecutorDevice describes the state of a // device which can do computation or transfer buffers. This could represent a diff --git a/third_party/xla/xla/ffi/BUILD b/third_party/xla/xla/ffi/BUILD index 55912a26009613..2e9deaaf4066e7 100644 --- a/third_party/xla/xla/ffi/BUILD +++ b/third_party/xla/xla/ffi/BUILD @@ -45,9 +45,9 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/stream_executor:device_memory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -76,8 +76,8 @@ xla_cc_test( deps = [ ":execution_context", ":type_id_registry", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -103,8 +103,8 @@ xla_cc_test( srcs = ["execution_state_test.cc"], deps = [ ":execution_state", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -118,8 +118,10 @@ cc_library( ":api", ":execution_context", ":execution_state", + "//xla:executable_run_options", "//xla:shape_util", "//xla:types", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", @@ -148,6 +150,7 @@ cc_library( ":execution_context", ":execution_state", ":type_id_registry", + "//xla:executable_run_options", "//xla:util", "//xla/ffi/api:c_api", "//xla/ffi/api:c_api_internal", @@ -197,12 +200,12 @@ xla_cc_test( "//xla/ffi/api:c_api", "//xla/stream_executor", "//xla/stream_executor:device_memory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -230,8 +233,8 @@ xla_cc_test( srcs = ["type_id_registry_test.cc"], deps = [ ":type_id_registry", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/ffi/api/BUILD b/third_party/xla/xla/ffi/api/BUILD index 9f1fcdc8d4e117..0af899a77c4d9a 100644 --- a/third_party/xla/xla/ffi/api/BUILD +++ b/third_party/xla/xla/ffi/api/BUILD @@ -85,10 +85,10 @@ xla_cc_test( "//xla/ffi:type_id_registry", "//xla/stream_executor:device_memory", "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/ffi/api/api.h b/third_party/xla/xla/ffi/api/api.h index 7675c3ab58f8a8..dccbcc60d25936 100644 --- a/third_party/xla/xla/ffi/api/api.h +++ b/third_party/xla/xla/ffi/api/api.h @@ -33,7 +33,6 @@ limitations under the License. #include #include #include -#include #include // This is a header-only base C++ library that defines templates for decoding @@ -232,16 +231,17 @@ class Ffi { template static std::string StrCat(Args... args); - static inline XLA_FFI_Error* MakeError(const XLA_FFI_Api* api, - XLA_FFI_Error_Code errc, - std::string message); + static XLA_FFI_Error* Sucess(); - static inline XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api, - std::string message); + static XLA_FFI_Error* MakeError(const XLA_FFI_Api* api, + XLA_FFI_Error_Code errc, std::string message); - static inline XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api, - std::string_view struct_name, - size_t expected, size_t actual); + static XLA_FFI_Error* InvalidArgument(const XLA_FFI_Api* api, + std::string message); + + static XLA_FFI_Error* CheckStructSize(const XLA_FFI_Api* api, + std::string_view struct_name, + size_t expected, size_t actual); }; XLA_FFI_Error* Ffi::RegisterStaticHandler(const XLA_FFI_Api* api, @@ -266,8 +266,11 @@ std::string Ffi::StrCat(Args... args) { return ss.str(); } -XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, XLA_FFI_Error_Code errc, - std::string message) { +inline XLA_FFI_Error* Ffi::Sucess() { return nullptr; } + +inline XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, + XLA_FFI_Error_Code errc, + std::string message) { XLA_FFI_Error_Create_Args args; args.struct_size = XLA_FFI_Error_Create_Args_STRUCT_SIZE; args.priv = nullptr; @@ -276,15 +279,15 @@ XLA_FFI_Error* Ffi::MakeError(const XLA_FFI_Api* api, XLA_FFI_Error_Code errc, return api->XLA_FFI_Error_Create(&args); } -XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api, - std::string message) { +inline XLA_FFI_Error* Ffi::InvalidArgument(const XLA_FFI_Api* api, + std::string message) { return MakeError(api, XLA_FFI_Error_Code_INVALID_ARGUMENT, std::move(message)); } -XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api, - std::string_view struct_name, - size_t expected, size_t actual) { +inline XLA_FFI_Error* Ffi::CheckStructSize(const XLA_FFI_Api* api, + std::string_view struct_name, + size_t expected, size_t actual) { if (expected != actual) { return InvalidArgument( api, StrCat("Unexpected ", struct_name, " size: expected ", expected, @@ -306,12 +309,13 @@ namespace internal { // parameter packs. We need this to be able to pattern match FFI handler // signature at compile time. +// A type tag for decoding optional argument. +template +struct OptionalArgTag {}; + // A type tag to forward all remaining args as `RemainingArgs`. struct RemainingArgsTag {}; -// A type tag to forward all remaining results as `RemainingRets`. -struct RemainingRetsTag {}; - // A type tag to distinguish parameters tied to results in the `Binding` // variadic template. In XLA FFI we use destination passing style APIs and don't // return anything from the handler, but instead pass a destination where the @@ -319,6 +323,13 @@ struct RemainingRetsTag {}; template struct RetTag {}; +// A type tag for decoding optional result. +template +struct OptionalRetTag {}; + +// A type tag to forward all remaining results as `RemainingRets`. +struct RemainingRetsTag {}; + // A type tag to distinguish parameters tied to the attributes in the // `Binding` variadic template. template @@ -357,12 +368,30 @@ struct NumTagged { //----------------------------------------------------------------------------// -// Checks if remaining arguments are in the parameter pack. +template +struct IsOptionalArgTag : std::false_type {}; +template +struct IsOptionalArgTag> : std::true_type {}; + +template +struct IsOptionalRetTag : std::false_type {}; +template +struct IsOptionalRetTag> : std::true_type {}; + +// Checks if parameter pack has an optional argument. +template +using HasOptionalArgTag = std::disjunction...>; + +// Checks if parameter pack has remaining arguments. template using HasRemainingArgsTag = std::disjunction...>; -// Checks if remaining results are in the parameter pack. +// Checks if parameter pack has an optional result. +template +using HasOptionalRetTag = std::disjunction...>; + +// Checks if parameter pack has remaining results. template using HasRemainingRetsTag = std::disjunction...>; @@ -413,11 +442,34 @@ class Binding { public: template Binding Arg() && { + static_assert(!internal::HasOptionalArgTag::value, + "argument can't be passed after optional argument"); + static_assert(!internal::HasRemainingArgsTag::value, + "argument can't be passed after remaining arguments"); return {std::move(*this)}; } template Binding> Ret() && { + static_assert(!internal::HasOptionalRetTag::value, + "result can't be passed after optional result"); + static_assert(!internal::HasRemainingRetsTag::value, + "result can't be passed after remaining results"); + return {std::move(*this)}; + } + + template + Binding> OptionalArg() && { + static_assert( + !internal::HasRemainingArgsTag::value, + "optional argument can't be passed after remaining arguments"); + return {std::move(*this)}; + } + + template + Binding> OptionalRet() && { + static_assert(!internal::HasRemainingRetsTag::value, + "optional result can't be passed after remaining results"); return {std::move(*this)}; } @@ -427,7 +479,7 @@ class Binding { return {std::move(*this)}; } - Binding RemainingResults() && { + Binding RemainingRets() && { static_assert(!internal::HasRemainingRetsTag::value, "remaining results can be passed just once"); return {std::move(*this)}; @@ -900,10 +952,20 @@ struct Decode { } }; -} // namespace internal +template +struct Decode> { + static std::optional> call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + if (offsets.args >= ctx.call_frame->args.size) { + return std::optional(std::nullopt); + } + return Decode::call(offsets, ctx, diagnostic); + } +}; template -struct internal::Decode> { +struct Decode> { static std::optional> call(DecodingOffsets& offsets, DecodingContext& ctx, DiagnosticEngine& diagnostic) { @@ -914,7 +976,19 @@ struct internal::Decode> { }; template -struct internal::Decode> { +struct Decode> { + static std::optional>> call( + DecodingOffsets& offsets, DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + if (offsets.rets >= ctx.call_frame->rets.size) { + return std::optional>(std::nullopt); + } + return Decode>::call(offsets, ctx, diagnostic); + } +}; + +template +struct Decode> { using R = typename AttrDecoding::Type; static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, @@ -946,7 +1020,7 @@ struct internal::Decode> { }; template -struct internal::Decode> { +struct Decode> { using R = typename CtxDecoding::Type; static std::optional call(DecodingOffsets& offsets, DecodingContext& ctx, @@ -956,75 +1030,17 @@ struct internal::Decode> { } }; -//===----------------------------------------------------------------------===// -// Expected -//===----------------------------------------------------------------------===// - -// Forward declare. -template -class Unexpected; - -// TODO(slebedev): Replace with `std::expected` when C++23 is available. -template -class Expected { - public: - constexpr Expected(T value) : data_(std::move(value)) {} // NOLINT - constexpr Expected(Unexpected u); // NOLINT - - constexpr operator bool() const { // NOLINT - return has_value(); - } - - constexpr T& operator*() & { return value(); } - constexpr const T& operator*() const& { return value(); } - constexpr T&& operator*() && { return std::move(value()); } - constexpr const T& operator*() const&& { return std::move(value()); } - - constexpr T* operator->() { return &value(); } - constexpr const T* operator->() const { return &value(); } - - constexpr bool has_value() const { return std::holds_alternative(data_); } - constexpr bool has_error() const { return std::holds_alternative(data_); } - - constexpr T& value() & { return std::get(data_); } - constexpr const T& value() const& { return std::get(data_); } - constexpr T&& value() && { return std::get(std::move(data_)); } - constexpr const T& value() const&& { return std::get(std::move(data_)); } - - constexpr E& error() & { return std::get(data_); } - constexpr const E& error() const& { return std::get(data_); } - constexpr E&& error() && { return std::get(std::move(data_)); } - constexpr const E&& error() const&& { return std::get(std::move(data_)); } - - private: - std::variant data_; -}; - -template -class Unexpected { - public: - constexpr Unexpected(E error) : error_(std::move(error)) {} // NOLINT - - private: - template - friend class Expected; - - E error_; -}; - -Unexpected(const char*) -> Unexpected; - -template -constexpr Expected::Expected(Unexpected u) - : data_(std::move(u.error_)) {} +} // namespace internal //===----------------------------------------------------------------------===// // Type-safe wrapper for accessing a variable number of arguments. //===----------------------------------------------------------------------===// -class RemainingArgs { +namespace internal { + +class RemainingArgsBase { public: - RemainingArgs(const XLA_FFI_Args* args, size_t offset) + RemainingArgsBase(const XLA_FFI_Args* args, size_t offset) : args_(args), offset_(offset) { assert(offset <= args_->size && "illegal remaining args offset"); } @@ -1032,43 +1048,26 @@ class RemainingArgs { size_t size() const { return args_->size - offset_; } bool empty() const { return size() == 0; } - template - Expected get(size_t index) const { - size_t idx = offset_ + index; - if (idx >= args_->size) { - return Unexpected("Index out of range."); - } - - DiagnosticEngine diagnostic; - auto value_opt = - ArgDecoding::Decode(args_->types[idx], args_->args[idx], diagnostic); - if (!value_opt.has_value()) { - return Unexpected(diagnostic.Result()); - } - return *value_opt; - } + protected: + const XLA_FFI_Args* args() const { return args_; } + size_t offset() const { return offset_; } private: - const XLA_FFI_Args* args_; // not owned + const XLA_FFI_Args* args_; size_t offset_; }; -template <> -struct internal::Decode { - static std::optional call(DecodingOffsets& offsets, - DecodingContext& ctx, - DiagnosticEngine& diagnostic) { - return RemainingArgs(&ctx.call_frame->args, offsets.args); - } -}; +} // namespace internal //===----------------------------------------------------------------------===// // Type-safe wrapper for accessing a variable number of results. //===----------------------------------------------------------------------===// -class RemainingResults { +namespace internal { + +class RemainingRetsBase { public: - RemainingResults(const XLA_FFI_Rets* rets, size_t offset) + RemainingRetsBase(const XLA_FFI_Rets* rets, size_t offset) : rets_(rets), offset_(offset) { assert(offset <= rets_->size && "illegal remaining rets offset"); } @@ -1076,43 +1075,30 @@ class RemainingResults { size_t size() const { return rets_->size - offset_; } bool empty() const { return size() == 0; } - template - Expected get(size_t index) const { - size_t idx = offset_ + index; - if (idx >= rets_->size) { - return Unexpected("Index out of range."); - } - - DiagnosticEngine diagnostic; - auto value_opt = - RetDecoding::Decode(rets_->types[idx], rets_->rets[idx], diagnostic); - if (!value_opt.has_value()) { - return Unexpected(diagnostic.Result()); - } - return **value_opt; - } + protected: + const XLA_FFI_Rets* rets() const { return rets_; } + size_t offset() const { return offset_; } private: const XLA_FFI_Rets* rets_; // not owned size_t offset_; }; -template <> -struct internal::Decode { - static std::optional call(DecodingOffsets& offsets, - DecodingContext& ctx, - DiagnosticEngine& diagnostic) { - return RemainingResults(&ctx.call_frame->rets, offsets.rets); - } -}; +} // namespace internal //===----------------------------------------------------------------------===// // Type-safe wrapper for accessing dictionary attributes. //===----------------------------------------------------------------------===// -class Dictionary { +namespace internal { + +// Forward declare dictionary attribute decoding defined below. +template +struct DecodeDictionaryAttr; + +class DictionaryBase { public: - explicit Dictionary(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {} + explicit DictionaryBase(const XLA_FFI_Attrs* attrs) : attrs_(attrs) {} size_t size() const { return attrs_->size; } @@ -1120,21 +1106,15 @@ class Dictionary { return Find(name) < attrs_->size; } - template - Expected get(std::string_view name) const { - DiagnosticEngine diagnostic; - auto value_opt = get(name, diagnostic); - if (!value_opt.has_value()) { - return Unexpected(diagnostic.Result()); - } - return *value_opt; - } + protected: + template + friend struct DecodeDictionaryAttr; template std::optional get(std::string_view name, DiagnosticEngine& diagnostic) const { size_t idx = Find(name); - if (idx >= attrs_->size) { + if (XLA_FFI_PREDICT_FALSE(idx >= attrs_->size)) { return diagnostic.Emit("Unexpected attribute: ") << name; } @@ -1161,15 +1141,11 @@ class Dictionary { const XLA_FFI_Attrs* attrs_; }; -// Decode `AttrsTag` into a generic `Dictionary` attribute. -template <> -struct internal::Decode> { - static std::optional call(DecodingOffsets& offsets, - DecodingContext& ctx, - DiagnosticEngine& diagnostic) { - return Dictionary(&ctx.call_frame->attrs); - } -}; +} // namespace internal + +//===----------------------------------------------------------------------===// +// Decoding for aggregate attributes (decoding dictionaries into structs). +//===----------------------------------------------------------------------===// // Decode `AttrsTag` into a type `T` relying on struct decoding defined below. template @@ -1186,6 +1162,13 @@ struct internal::Decode> { // Template metaprogramming for decoding handler signature //===----------------------------------------------------------------------===// +// Forward declare classes for decoding variadic number of arguments and +// results. They are defined in `ffi.h` headers (internal and external), to be +// able to use slightly different implementations for internal and external +// FFI (`absl::StatusOr` vs `ffi::ErrorOr`). +class RemainingArgs; +class RemainingRets; + namespace internal { // A helper struct to extract the type of the handler argument. template @@ -1193,23 +1176,31 @@ struct FnArgType { using Type = T; }; -template <> -struct FnArgType { - using Type = RemainingArgs; +template +struct FnArgType> { + using Type = std::optional; }; template <> -struct FnArgType { - using Type = RemainingResults; +struct FnArgType { + using Type = RemainingArgs; }; -// Extracts the underlying type from the returned result type tag. template struct FnArgType> { using Type = Result; }; -// Extracts the underlying type from the attribute type tag. +template +struct FnArgType> { + using Type = std::optional>; +}; + +template <> +struct FnArgType { + using Type = RemainingRets; +}; + template struct FnArgType> { using Type = typename AttrDecoding::Type; @@ -1220,7 +1211,6 @@ struct FnArgType> { using Type = T; }; -// Extracts the underlying type from the context type tag. template struct FnArgType> { using Type = typename CtxDecoding::Type; @@ -1230,20 +1220,27 @@ struct FnArgType> { // a special decoding rule defined by template specialization. template struct IsTagged : std::false_type {}; + +template +struct IsTagged> : std::true_type {}; template struct IsTagged> : std::true_type {}; template +struct IsTagged> : std::true_type {}; +template struct IsTagged> : std::true_type {}; template struct IsTagged> : std::true_type {}; template struct IsTagged> : std::true_type {}; + template <> struct IsTagged : std::true_type {}; template <> struct IsTagged : std::true_type {}; -// A template for counting regular arguments in the Ts pack. +// A template for counting regular arguments in the Ts pack (arguments that are +// not wrapped into a special tag). template struct NumArgs; @@ -1269,9 +1266,15 @@ class Handler : public Ffi { static constexpr int64_t kNumArgs = internal::NumArgs::value; + static constexpr int64_t kNumOptionalArgs = + internal::NumTagged::value; + static constexpr int64_t kNumRets = internal::NumTagged::value; + static constexpr int64_t kNumOptionalRets = + internal::NumTagged::value; + static constexpr int64_t kNumAttrs = internal::NumTagged::value; @@ -1292,22 +1295,22 @@ class Handler : public Ffi { public: XLA_FFI_Error* Call(const XLA_FFI_CallFrame* call_frame) const override { // Sanity checking call frame struct size. - if (auto* err = CheckStructSize(call_frame->api, "XLA_FFI_CallFrame", - XLA_FFI_CallFrame_STRUCT_SIZE, - call_frame->struct_size)) + if (XLA_FFI_Error* err = CheckStructSize( + call_frame->api, "XLA_FFI_CallFrame", XLA_FFI_CallFrame_STRUCT_SIZE, + call_frame->struct_size)) { return err; + } // Check the API versions. - auto api_version = call_frame->api->api_version; + const XLA_FFI_Api_Version& api_version = call_frame->api->api_version; if (api_version.major_version != XLA_FFI_API_MAJOR || api_version.minor_version != XLA_FFI_API_MINOR) { return InvalidArgument( call_frame->api, StrCat("FFI handler's API version (", XLA_FFI_API_MAJOR, ".", - XLA_FFI_API_MINOR, - ") does not match the framework's API version (", - api_version.major_version, ".", api_version.minor_version, - ")")); + XLA_FFI_API_MINOR, ") does not match the framework's API ", + "version (", api_version.major_version, ".", + api_version.minor_version, ")")); } // Check that handler is called during correct execution stage. @@ -1321,12 +1324,21 @@ class Handler : public Ffi { // Check that the number of passed arguments matches the signature. Each // individual argument decoding will check the actual type. - if (internal::HasRemainingArgsTag::value) { + if constexpr (internal::HasRemainingArgsTag::value) { + if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) { + return InvalidArgument( + call_frame->api, + StrCat("Wrong number of arguments: expected at least ", + kNumArgs - kNumOptionalArgs - 1, " but got ", + call_frame->args.size)); + } + } else if constexpr (internal::HasOptionalArgTag::value) { if (XLA_FFI_PREDICT_FALSE(call_frame->args.size < kNumArgs)) { return InvalidArgument( call_frame->api, StrCat("Wrong number of arguments: expected at least ", - kNumArgs - 1, " but got ", call_frame->args.size)); + kNumArgs - kNumOptionalArgs, " but got ", + call_frame->args.size)); } } else { if (XLA_FFI_PREDICT_FALSE(call_frame->args.size != kNumArgs)) { @@ -1339,12 +1351,21 @@ class Handler : public Ffi { // Check that the number of results matches the signature. Each individual // result decoding will check the actual type. - if (internal::HasRemainingRetsTag::value) { + if constexpr (internal::HasRemainingRetsTag::value) { if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size < kNumRets)) { return InvalidArgument( call_frame->api, - StrCat("Wrong number of results: expected at least ", kNumRets - 1, - " but got ", call_frame->rets.size)); + StrCat("Wrong number of results: expected at least ", + kNumRets - kNumOptionalRets - 1, " but got ", + call_frame->rets.size)); + } + } else if constexpr (internal::HasOptionalRetTag::value) { + if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size < kNumRets)) { + return InvalidArgument( + call_frame->api, + StrCat("Wrong number of results: expected at least ", + kNumRets - kNumOptionalRets, " but got ", + call_frame->rets.size)); } } else { if (XLA_FFI_PREDICT_FALSE(call_frame->rets.size != kNumRets)) { @@ -1515,21 +1536,6 @@ struct AttrDecoding { } }; -template <> -struct AttrDecoding { - using Type = Dictionary; - static std::optional Decode(XLA_FFI_AttrType type, void* attr, - DiagnosticEngine& diagnostic) { - if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { - return diagnostic.Emit("Wrong attribute type: expected ") - << XLA_FFI_AttrType_DICTIONARY << " but got " << type; - } - - auto* attrs = reinterpret_cast(attr); - return Dictionary(attrs); - } -}; - //===----------------------------------------------------------------------===// // Automatic dictionary attributes to structs decoding. //===----------------------------------------------------------------------===// @@ -1574,7 +1580,7 @@ struct DecodeDictionaryAttr { // // Consider using `static auto decoder = ...` below, and compute mapping in // constructor. Add benchmarks first to know what to improve! - Dictionary dict(attrs); + internal::DictionaryBase dict(attrs); std::tuple...> members = { dict.get(names[Is], diagnostic)...}; @@ -1637,7 +1643,7 @@ auto DictionaryDecoder(Members... m) { // type to decode the attribute as a scalar value and cast it to the enum type. #define XLA_FFI_REGISTER_ENUM_ATTR_DECODING(T) \ template <> \ - struct ::xla::ffi::AttrDecoding { \ + struct xla::ffi::AttrDecoding { \ using Type = T; \ using U = std::underlying_type_t; \ static_assert(std::is_enum::value, "Expected enum class"); \ diff --git a/third_party/xla/xla/ffi/api/c_api_internal.h b/third_party/xla/xla/ffi/api/c_api_internal.h index 3c5c2baf1cdbd5..da5ea3295e6764 100644 --- a/third_party/xla/xla/ffi/api/c_api_internal.h +++ b/third_party/xla/xla/ffi/api/c_api_internal.h @@ -74,6 +74,11 @@ typedef void* XLA_FFI_INTERNAL_ExecutionContext_Get( typedef void* XLA_FFI_INTERNAL_ExecutionState_Get( XLA_FFI_ExecutionContext* ctx); +// Returns a pointer to the `Eigen::ThreadPoolDevice` passed via run options, +// which allows FFI handlers to execute tasks in the same thread pool as XLA. +typedef void* XLA_FFI_INTERNAL_IntraOpThreadPool_Get( + XLA_FFI_ExecutionContext* ctx); + //===----------------------------------------------------------------------===// // API access //===----------------------------------------------------------------------===// @@ -89,6 +94,7 @@ struct XLA_FFI_InternalApi { _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_CalledComputation_Get); _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_ExecutionContext_Get); _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_ExecutionState_Get); + _XLA_FFI_INTERNAL_API_STRUCT_FIELD(XLA_FFI_INTERNAL_IntraOpThreadPool_Get); }; #undef _XLA_FFI_INTERNAL_API_STRUCT_FIELD diff --git a/third_party/xla/xla/ffi/api/ffi.h b/third_party/xla/xla/ffi/api/ffi.h index c48996b2270871..c35fdce6478a81 100644 --- a/third_party/xla/xla/ffi/api/ffi.h +++ b/third_party/xla/xla/ffi/api/ffi.h @@ -36,6 +36,7 @@ limitations under the License. #include #include #include +#include #include #include "xla/ffi/api/c_api.h" @@ -75,6 +76,30 @@ enum class DataType : uint8_t { F8E4M3FNUZ = XLA_FFI_DataType_F8E4M3FNUZ, }; +// Create aliases in ::xla::ffi namespace for all DataTypes, for consistency +// with xla that defines PrimitiveType enums in ::xla namespace. +inline constexpr DataType PRED = DataType::PRED; +inline constexpr DataType S8 = DataType::S8; +inline constexpr DataType S16 = DataType::S16; +inline constexpr DataType S32 = DataType::S32; +inline constexpr DataType S64 = DataType::S64; +inline constexpr DataType U8 = DataType::U8; +inline constexpr DataType U16 = DataType::U16; +inline constexpr DataType U32 = DataType::U32; +inline constexpr DataType U64 = DataType::U64; +inline constexpr DataType F16 = DataType::F16; +inline constexpr DataType F32 = DataType::F32; +inline constexpr DataType F64 = DataType::F64; +inline constexpr DataType BF16 = DataType::BF16; +inline constexpr DataType C64 = DataType::C64; +inline constexpr DataType C128 = DataType::C128; +inline constexpr DataType TOKEN = DataType::TOKEN; +inline constexpr DataType F8E5M2 = DataType::F8E5M2; +inline constexpr DataType F8E4M3FN = DataType::F8E4M3FN; +inline constexpr DataType F8E4M3B11FNUZ = DataType::F8E4M3B11FNUZ; +inline constexpr DataType F8E5M2FNUZ = DataType::F8E5M2FNUZ; +inline constexpr DataType F8E4M3FNUZ = DataType::F8E4M3FNUZ; + inline std::ostream& operator<<(std::ostream& os, const DataType dtype) { return os << static_cast(dtype); } @@ -149,7 +174,7 @@ class Span { }; //===----------------------------------------------------------------------===// -// Error and ErrorOr +// Error //===----------------------------------------------------------------------===// enum class ErrorCode : uint8_t { @@ -182,19 +207,89 @@ class Error { Error(XLA_FFI_Error_Code errc, std::string message) : Error(static_cast(errc), std::move(message)) {} - static Error Success() { return Error(); } - bool success() const { return errc_ == ErrorCode::kOk; } bool failure() const { return !success(); } std::optional errc() const { return errc_; } const std::string& message() const { return message_; } + static Error Success() { return Error(); } + + static Error Internal(std::string message) { + return Error(ErrorCode::kInternal, std::move(message)); + } + + static Error InvalidArgument(std::string message) { + return Error(ErrorCode::kInvalidArgument, std::move(message)); + } + private: ErrorCode errc_ = ErrorCode::kOk; std::string message_; }; +//===----------------------------------------------------------------------===// +// Expected and ErrorOr +//===----------------------------------------------------------------------===// + +// Forward declare. +template +class Unexpected; + +// TODO(slebedev): Replace with `std::expected` when C++23 is available. +template +class Expected { + public: + constexpr Expected(T value) : data_(std::move(value)) {} // NOLINT + constexpr Expected(Unexpected u); // NOLINT + + constexpr operator bool() const { // NOLINT + return has_value(); + } + + constexpr T& operator*() & { return value(); } + constexpr const T& operator*() const& { return value(); } + constexpr T&& operator*() && { return std::move(value()); } + constexpr const T& operator*() const&& { return std::move(value()); } + + constexpr T* operator->() { return &value(); } + constexpr const T* operator->() const { return &value(); } + + constexpr bool has_value() const { return std::holds_alternative(data_); } + constexpr bool has_error() const { return std::holds_alternative(data_); } + + constexpr T& value() & { return std::get(data_); } + constexpr const T& value() const& { return std::get(data_); } + constexpr T&& value() && { return std::get(std::move(data_)); } + constexpr const T& value() const&& { return std::get(std::move(data_)); } + + constexpr E& error() & { return std::get(data_); } + constexpr const E& error() const& { return std::get(data_); } + constexpr E&& error() && { return std::get(std::move(data_)); } + constexpr const E&& error() const&& { return std::get(std::move(data_)); } + + private: + std::variant data_; +}; + +template +class Unexpected { + public: + constexpr Unexpected(E error) : error_(std::move(error)) {} // NOLINT + + private: + template + friend class Expected; + + E error_; +}; + +Unexpected(const char*) -> Unexpected; + +template +constexpr Expected::Expected(Unexpected u) + : data_(std::move(u.error_)) {} + template class ErrorOr : public Expected { public: @@ -483,6 +578,42 @@ struct ArgDecoding> { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing a variable number of arguments. +//===----------------------------------------------------------------------===// + +class RemainingArgs : public internal::RemainingArgsBase { + public: + using internal::RemainingArgsBase::RemainingArgsBase; + + template + ErrorOr get(size_t index) const { + size_t idx = offset() + index; + if (XLA_FFI_PREDICT_FALSE(idx >= args()->size)) { + return Unexpected( + Error(ErrorCode::kInvalidArgument, "Index out of range")); + } + + DiagnosticEngine diagnostic; + std::optional value = ArgDecoding::Decode( + args()->types[idx], args()->args[idx], diagnostic); + if (XLA_FFI_PREDICT_FALSE(!value.has_value())) { + return Unexpected(Error::Internal(diagnostic.Result())); + } + + return *value; + } +}; + +template <> +struct internal::Decode { + static std::optional call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + return RemainingArgs(&ctx.call_frame->args, offsets.args); + } +}; + //===----------------------------------------------------------------------===// // Results decoding //===----------------------------------------------------------------------===// @@ -523,6 +654,42 @@ struct RetDecoding> { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing a variable number of results. +//===----------------------------------------------------------------------===// + +class RemainingRets : public internal::RemainingRetsBase { + public: + using internal::RemainingRetsBase::RemainingRetsBase; + + template + ErrorOr> get(size_t index) const { + size_t idx = offset() + index; + if (XLA_FFI_PREDICT_FALSE(idx >= rets()->size)) { + return Unexpected( + Error(ErrorCode::kInvalidArgument, "Index out of range")); + } + + DiagnosticEngine diagnostic; + std::optional> value = RetDecoding::Decode( + rets()->types[idx], rets()->rets[idx], diagnostic); + if (XLA_FFI_PREDICT_FALSE(!value.has_value())) { + return Unexpected(Error::Internal(diagnostic.Result())); + } + + return *value; + } +}; + +template <> +struct internal::Decode { + static std::optional call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + return RemainingRets(&ctx.call_frame->rets, offsets.rets); + } +}; + //===----------------------------------------------------------------------===// // Attributes decoding //===----------------------------------------------------------------------===// @@ -580,6 +747,49 @@ struct AttrDecoding> { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing dictionary attributes. +//===----------------------------------------------------------------------===// + +class Dictionary : public internal::DictionaryBase { + public: + using internal::DictionaryBase::DictionaryBase; + + template + ErrorOr get(std::string_view name) const { + DiagnosticEngine diagnostic; + std::optional value = internal::DictionaryBase::get(name, diagnostic); + if (!value.has_value()) { + return Unexpected(Error::Internal(diagnostic.Result())); + } + return *value; + } +}; + +// Decode `AttrsTag` (all attributes) into a `Dictionary`. +template <> +struct internal::Decode> { + static std::optional call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + return Dictionary(&ctx.call_frame->attrs); + } +}; + +// Decode individual attribute into `Dictionary` type. +template <> +struct AttrDecoding { + using Type = Dictionary; + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { + return diagnostic.Emit("Wrong attribute type: expected ") + << XLA_FFI_AttrType_DICTIONARY << " but got " << type; + } + return Dictionary(reinterpret_cast(attr)); + } +}; + //===----------------------------------------------------------------------===// // Error helpers //===----------------------------------------------------------------------===// @@ -758,6 +968,7 @@ inline std::optional ScratchAllocator::Allocate(size_t size, internal::DestroyError(api_, error); return std::nullopt; } + allocations_.push_back({size, args.data}); return args.data; } diff --git a/third_party/xla/xla/ffi/api/ffi_test.cc b/third_party/xla/xla/ffi/api/ffi_test.cc index a677c4e355ee0c..2bbfd048688bc8 100644 --- a/third_party/xla/xla/ffi/api/ffi_test.cc +++ b/third_party/xla/xla/ffi/api/ffi_test.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include #include +#include #include +#include #include #include @@ -34,8 +36,8 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" @@ -237,12 +239,11 @@ TEST(FfiTest, BufferArgument) { builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto handler = - Ffi::Bind().Arg>().To([&](auto buffer) { - EXPECT_EQ(buffer.typed_data(), storage.data()); - EXPECT_EQ(buffer.dimensions().size(), 2); - return Error::Success(); - }); + auto handler = Ffi::Bind().Arg>().To([&](auto buffer) { + EXPECT_EQ(buffer.typed_data(), storage.data()); + EXPECT_EQ(buffer.dimensions().size(), 2); + return Error::Success(); + }); auto status = Call(*handler, call_frame); TF_ASSERT_OK(status); @@ -270,7 +271,7 @@ TEST(FfiTest, MissingBufferArgument) { CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); @@ -286,7 +287,7 @@ TEST(FfiTest, WrongRankBufferArgument) { builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); @@ -303,7 +304,7 @@ TEST(FfiTest, WrongTypeBufferArgument) { builder.AddBufferArg(memory, PrimitiveType::S32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto handler = Ffi::Bind().Arg>().To( + auto handler = Ffi::Bind().Arg>().To( [](auto) { return Error::Success(); }); auto status = Call(*handler, call_frame); @@ -322,7 +323,7 @@ TEST(FfiTest, TokenArgument) { auto fn = [&](Token tok) { EXPECT_EQ(tok.typed_data(), nullptr); EXPECT_EQ(tok.dimensions().size(), 0); - return ffi::Error::Success(); + return Error::Success(); }; auto handler = Ffi::Bind().Arg().To(fn); @@ -330,6 +331,182 @@ TEST(FfiTest, TokenArgument) { TF_ASSERT_OK(status); } +TEST(FfiTest, RemainingArgs) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto fn = [&](RemainingArgs args) { + EXPECT_EQ(args.size(), 1); + + ErrorOr arg0 = args.get(0); + ErrorOr arg1 = args.get(1); + + EXPECT_TRUE(arg0.has_value()); + EXPECT_FALSE(arg1.has_value()); + + return Error::Success(); + }; + + auto handler = Ffi::Bind().RemainingArgs().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + +TEST(FfiTest, RemainingRets) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/2); + builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + auto fn = [&](Result ret, RemainingRets rets) { + EXPECT_EQ(rets.size(), 1); + + ErrorOr> ret0 = rets.get(0); + ErrorOr> ret1 = rets.get(1); + + EXPECT_TRUE(ret0.has_value()); + EXPECT_FALSE(ret1.has_value()); + + return Error::Success(); + }; + + auto handler = Ffi::Bind().Ret().RemainingRets().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + +TEST(FfiTest, OptionalArgs) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + { // Single optional argument. + auto fn = [&](std::optional arg0) { + EXPECT_TRUE(arg0.has_value()); + return Error::Success(); + }; + + auto handler = Ffi::Bind().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Two optional arguments. + auto fn = [&](std::optional arg0, + std::optional arg1) { + EXPECT_TRUE(arg0.has_value()); + EXPECT_FALSE(arg1.has_value()); + return Error::Success(); + }; + + auto handler = + Ffi::Bind().OptionalArg().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Optional argument after a regular one. + auto fn = [&](AnyBuffer arg0, std::optional arg1) { + EXPECT_FALSE(arg1.has_value()); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Arg().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Remaining arguments after optional one. + auto fn = [&](std::optional arg0, RemainingArgs args) { + EXPECT_TRUE(arg0.has_value()); + EXPECT_EQ(args.size(), 0); + return Error::Success(); + }; + + auto handler = Ffi::Bind().OptionalArg().RemainingArgs().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } +} + +TEST(FfiTest, OptionalRets) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); + builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + { // Single optional result. + auto fn = [&](std::optional> ret0) { + EXPECT_TRUE(ret0.has_value()); + return Error::Success(); + }; + + auto handler = Ffi::Bind().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Two optional results. + auto fn = [&](std::optional> ret0, + std::optional> ret1) { + EXPECT_TRUE(ret0.has_value()); + EXPECT_FALSE(ret1.has_value()); + return Error::Success(); + }; + + auto handler = + Ffi::Bind().OptionalRet().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Optional result after a regular one. + auto fn = [&](Result ret0, + std::optional> ret1) { + EXPECT_FALSE(ret1.has_value()); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Ret().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Remaining results after optional one. + auto fn = [&](std::optional> ret0, RemainingRets rets) { + EXPECT_TRUE(ret0.has_value()); + EXPECT_EQ(rets.size(), 0); + return Error::Success(); + }; + + auto handler = Ffi::Bind().OptionalRet().RemainingRets().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } +} + TEST(FfiTest, AutoBinding) { static constexpr char kI32[] = "i32"; @@ -463,6 +640,150 @@ TEST(FfiTest, ArrayAttr) { TF_ASSERT_OK(status); } +TEST(FfiTest, AttrsAsDictionary) { + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("i32", 42); + attrs.Insert("f32", 42.0f); + attrs.Insert("str", "foo"); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](Dictionary dict) { + EXPECT_EQ(dict.size(), 3); + + EXPECT_TRUE(dict.contains("i32")); + EXPECT_TRUE(dict.contains("f32")); + EXPECT_TRUE(dict.contains("str")); + + ErrorOr i32 = dict.get("i32"); + ErrorOr f32 = dict.get("f32"); + ErrorOr str = dict.get("str"); + + EXPECT_TRUE(i32.has_value()); + EXPECT_TRUE(f32.has_value()); + EXPECT_TRUE(str.has_value()); + + if (i32.has_value()) EXPECT_EQ(*i32, 42); + if (f32.has_value()) EXPECT_EQ(*f32, 42.0f); + if (str.has_value()) EXPECT_EQ(*str, "foo"); + + EXPECT_FALSE(dict.contains("i64")); + EXPECT_FALSE(dict.get("i32").has_value()); + EXPECT_FALSE(dict.get("i64").has_value()); + + return Error::Success(); + }; + + auto handler = Ffi::Bind().Attrs().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + +TEST(FfiTest, DictionaryAttr) { + CallFrameBuilder::FlatAttributesMap dict0; + dict0.try_emplace("i32", 42); + + CallFrameBuilder::FlatAttributesMap dict1; + dict1.try_emplace("f32", 42.0f); + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("dict0", dict0); + attrs.Insert("dict1", dict1); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](Dictionary dict0, Dictionary dict1) { + EXPECT_EQ(dict0.size(), 1); + EXPECT_EQ(dict1.size(), 1); + + EXPECT_TRUE(dict0.contains("i32")); + EXPECT_TRUE(dict1.contains("f32")); + + ErrorOr i32 = dict0.get("i32"); + ErrorOr f32 = dict1.get("f32"); + + EXPECT_TRUE(i32.has_value()); + EXPECT_TRUE(f32.has_value()); + + if (i32.has_value()) EXPECT_EQ(*i32, 42); + if (f32.has_value()) EXPECT_EQ(*f32, 42.0f); + + return Error::Success(); + }; + + auto handler = + Ffi::Bind().Attr("dict0").Attr("dict1").To(fn); + + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + +struct PairOfI32AndF32 { + int32_t i32; + float f32; +}; + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(PairOfI32AndF32, + StructMember("i32"), + StructMember("f32")); + +TEST(FfiTest, StructAttr) { + CallFrameBuilder::FlatAttributesMap dict; + dict.try_emplace("i32", 42); + dict.try_emplace("f32", 42.0f); + + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("str", "foo"); + attrs.Insert("i32_and_f32", dict); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](std::string_view str, PairOfI32AndF32 i32_and_f32) { + EXPECT_EQ(str, "foo"); + EXPECT_EQ(i32_and_f32.i32, 42); + EXPECT_EQ(i32_and_f32.f32, 42.0f); + return Error::Success(); + }; + + auto handler = Ffi::Bind() + .Attr("str") + .Attr("i32_and_f32") + .To(fn); + + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + +TEST(FfiTest, AttrsAsStruct) { + CallFrameBuilder::AttributesBuilder attrs; + attrs.Insert("i32", 42); + attrs.Insert("f32", 42.0f); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0); + builder.AddAttributes(attrs.Build()); + auto call_frame = builder.Build(); + + auto fn = [&](PairOfI32AndF32 i32_and_f32) { + EXPECT_EQ(i32_and_f32.i32, 42); + EXPECT_EQ(i32_and_f32.f32, 42.0f); + return Error::Success(); + }; + + auto handler = Ffi::Bind().Attrs().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); +} + TEST(FfiTest, PointerAttr) { std::string foo = "foo"; @@ -641,14 +962,19 @@ TEST(FfiTest, ScratchAllocator) { // A test only memory allocator that returns a fixed memory address. struct TestDeviceMemoryAllocator final : public se::DeviceMemoryAllocator { - TestDeviceMemoryAllocator() : se::DeviceMemoryAllocator(nullptr) {} + size_t count; + + TestDeviceMemoryAllocator() + : se::DeviceMemoryAllocator(nullptr), count(0) {} absl::StatusOr Allocate(int, uint64_t size, bool, int64_t) final { + count++; return se::OwningDeviceMemory(se::DeviceMemoryBase(kAddr, size), 0, this); } absl::Status Deallocate(int, se::DeviceMemoryBase mem) final { + count--; EXPECT_EQ(mem.opaque(), kAddr); return absl::OkStatus(); } @@ -672,11 +998,25 @@ TEST(FfiTest, ScratchAllocator) { CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build(); CallOptions options; - options.allocator = &allocator; + options.backend_options = CallOptions::GpuOptions{nullptr, &allocator}; auto status = Call(*handler, call_frame, options); TF_ASSERT_OK(status); + EXPECT_EQ(allocator.count, 0); +} + +TEST(FfiTest, ScratchAllocatorUnimplemented) { + auto fn = [&](ScratchAllocator scratch_allocator) { + auto mem = scratch_allocator.Allocate(1024); + EXPECT_FALSE(mem.has_value()); + return Error::Success(); + }; + auto handler = Ffi::Bind().Ctx().To(fn); + CallFrame call_frame = + CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build(); + auto status = Call(*handler, call_frame); + TF_ASSERT_OK(status); } //===----------------------------------------------------------------------===// @@ -747,7 +1087,7 @@ BENCHMARK(BM_AnyBufferArgX4); void BM_BufferArgX1(benchmark::State& state) { auto call_frame = WithBufferArgs(1).Build(); - auto handler = Ffi::Bind().Arg>().To([](auto buffer) { + auto handler = Ffi::Bind().Arg>().To([](auto buffer) { benchmark::DoNotOptimize(buffer); return Error::Success(); }); @@ -767,10 +1107,10 @@ void BM_BufferArgX4(benchmark::State& state) { auto call_frame = WithBufferArgs(4).Build(); auto handler = Ffi::Bind() - .Arg>() - .Arg>() - .Arg>() - .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() .To([](auto b0, auto b1, auto b2, auto b3) { benchmark::DoNotOptimize(b0); benchmark::DoNotOptimize(b1); @@ -794,14 +1134,14 @@ void BM_BufferArgX8(benchmark::State& state) { auto call_frame = WithBufferArgs(8).Build(); auto handler = Ffi::Bind() - .Arg>() - .Arg>() - .Arg>() - .Arg>() - .Arg>() - .Arg>() - .Arg>() - .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() + .Arg>() .To([](auto b0, auto b1, auto b2, auto b3, auto b4, auto b5, auto b6, auto b7) { benchmark::DoNotOptimize(b0); diff --git a/third_party/xla/xla/ffi/call_frame_test.cc b/third_party/xla/xla/ffi/call_frame_test.cc index 2937b53bb5d997..7b767bfb841af8 100644 --- a/third_party/xla/xla/ffi/call_frame_test.cc +++ b/third_party/xla/xla/ffi/call_frame_test.cc @@ -24,8 +24,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/ffi/api/c_api.h" #include "xla/stream_executor/device_memory.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/ffi/execution_context_test.cc b/third_party/xla/xla/ffi/execution_context_test.cc index 7a2a1b33992ede..6a5cdfa40b07b6 100644 --- a/third_party/xla/xla/ffi/execution_context_test.cc +++ b/third_party/xla/xla/ffi/execution_context_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/ffi/type_id_registry.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/ffi/execution_state_test.cc b/third_party/xla/xla/ffi/execution_state_test.cc index d8929246ca0161..dd8244f00183ff 100644 --- a/third_party/xla/xla/ffi/execution_state_test.cc +++ b/third_party/xla/xla/ffi/execution_state_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/ffi/ffi.h b/third_party/xla/xla/ffi/ffi.h index 2ace88ab66ea48..82076c5d128d09 100644 --- a/third_party/xla/xla/ffi/ffi.h +++ b/third_party/xla/xla/ffi/ffi.h @@ -27,6 +27,8 @@ limitations under the License. #include #include #include +#include +#include // IWYU pragma: begin_exports #include "xla/ffi/api/api.h" @@ -38,6 +40,7 @@ limitations under the License. #include "absl/base/optimization.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "xla/executable_run_options.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep #include "xla/ffi/execution_context.h" @@ -49,6 +52,7 @@ limitations under the License. #include "xla/stream_executor/scratch_allocator.h" #include "xla/stream_executor/stream.h" #include "xla/types.h" // IWYU pragma: keep +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/logging.h" @@ -60,6 +64,7 @@ struct DeviceOrdinal {}; // binds `int32_t` with device ordinal struct Allocator {}; // binds `se::DeviceMemoryAllocator*` struct ScratchAllocator {}; // binds `se::OwningScratchAllocator` struct CalledComputation {}; // binds `HloComputation*` +struct IntraOpThreadPool {}; // binds `const Eigen::ThreadPoolDevice*` //===----------------------------------------------------------------------===// // Arguments @@ -238,6 +243,41 @@ struct ArgDecoding> { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing a variable number of arguments. +//===----------------------------------------------------------------------===// + +class RemainingArgs : public internal::RemainingArgsBase { + public: + using internal::RemainingArgsBase::RemainingArgsBase; + + template + absl::StatusOr get(size_t index) const { + size_t idx = offset() + index; + if (ABSL_PREDICT_FALSE(idx >= args()->size)) { + return InvalidArgument("Index out of range."); + } + + DiagnosticEngine diagnostic; + std::optional value = ArgDecoding::Decode( + args()->types[idx], args()->args[idx], diagnostic); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + return Internal("%s", diagnostic.Result()); + } + + return *value; + } +}; + +template <> +struct internal::Decode { + static std::optional call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + return RemainingArgs(&ctx.call_frame->args, offsets.args); + } +}; + //===----------------------------------------------------------------------===// // Results decoding //===----------------------------------------------------------------------===// @@ -271,6 +311,41 @@ struct RetDecoding> { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing a variable number of results. +//===----------------------------------------------------------------------===// + +class RemainingRets : public internal::RemainingRetsBase { + public: + using internal::RemainingRetsBase::RemainingRetsBase; + + template + absl::StatusOr> get(size_t index) const { + size_t idx = offset() + index; + if (ABSL_PREDICT_FALSE(idx >= rets()->size)) { + return InvalidArgument("Index out of range."); + } + + DiagnosticEngine diagnostic; + std::optional> value = RetDecoding::Decode( + rets()->types[idx], rets()->rets[idx], diagnostic); + if (ABSL_PREDICT_FALSE(!value.has_value())) { + return Internal("%s", diagnostic.Result()); + } + + return *value; + } +}; + +template <> +struct internal::Decode { + static std::optional call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + return RemainingRets(&ctx.call_frame->rets, offsets.rets); + } +}; + //===----------------------------------------------------------------------===// // Attributes decoding //===----------------------------------------------------------------------===// @@ -329,6 +404,49 @@ struct AttrDecoding> { } }; +//===----------------------------------------------------------------------===// +// Type-safe wrapper for accessing dictionary attributes. +//===----------------------------------------------------------------------===// + +class Dictionary : public internal::DictionaryBase { + public: + using internal::DictionaryBase::DictionaryBase; + + template + absl::StatusOr get(std::string_view name) const { + DiagnosticEngine diagnostic; + std::optional value = internal::DictionaryBase::get(name, diagnostic); + if (!value.has_value()) { + return Internal("%s", diagnostic.Result()); + } + return *value; + } +}; + +// Decode `AttrsTag` (all attributes) into a `Dictionary`. +template <> +struct internal::Decode> { + static std::optional call(DecodingOffsets& offsets, + DecodingContext& ctx, + DiagnosticEngine& diagnostic) { + return Dictionary(&ctx.call_frame->attrs); + } +}; + +// Decode individual attribute into `Dictionary` type. +template <> +struct AttrDecoding { + using Type = Dictionary; + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { + return diagnostic.Emit("Wrong attribute type: expected ") + << XLA_FFI_AttrType_DICTIONARY << " but got " << type; + } + return Dictionary(reinterpret_cast(attr)); + } +}; + //===----------------------------------------------------------------------===// // Context decoding //===----------------------------------------------------------------------===// @@ -365,7 +483,7 @@ struct CtxDecoding { DiagnosticEngine&) { void* device_allocator = api->internal_api->XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get(ctx); - return reinterpret_cast(device_allocator); + return reinterpret_cast(device_allocator); } }; @@ -399,6 +517,19 @@ struct CtxDecoding { } }; +template <> +struct CtxDecoding { + using Type = const Eigen::ThreadPoolDevice*; + + static std::optional Decode(const XLA_FFI_Api* api, + XLA_FFI_ExecutionContext* ctx, + DiagnosticEngine&) { + void* intra_op_thread_pool = + api->internal_api->XLA_FFI_INTERNAL_IntraOpThreadPool_Get(ctx); + return reinterpret_cast(intra_op_thread_pool); + } +}; + //===----------------------------------------------------------------------===// // UserData //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/ffi/ffi_api.cc b/third_party/xla/xla/ffi/ffi_api.cc index f402ed24b32dcd..c5f07ebc020298 100644 --- a/third_party/xla/xla/ffi/ffi_api.cc +++ b/third_party/xla/xla/ffi/ffi_api.cc @@ -15,12 +15,14 @@ limitations under the License. #include "xla/ffi/ffi_api.h" +#include #include #include #include #include #include #include +#include #include #include "absl/base/optimization.h" @@ -31,6 +33,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "xla/executable_run_options.h" #include "xla/ffi/api/api.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep @@ -56,10 +59,19 @@ struct XLA_FFI_Error { }; struct XLA_FFI_ExecutionContext { - int32_t device_ordinal = -1; + struct CpuContext { + const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr; + }; + + struct GpuContext { + stream_executor::Stream* stream = nullptr; + stream_executor::DeviceMemoryAllocator* allocator = nullptr; + }; + + using BackendContext = std::variant; - stream_executor::Stream* stream = nullptr; - stream_executor::DeviceMemoryAllocator* allocator = nullptr; + int32_t device_ordinal = -1; + BackendContext backend_context = {}; const xla::HloComputation* called_computation = nullptr; const xla::ffi::ExecutionContext* execution_context = nullptr; @@ -76,10 +88,27 @@ bool IsCommandBufferCompatible(XLA_FFI_Handler_Traits traits) { static XLA_FFI_ExecutionContext CreateExecutionContext( const CallOptions& options) { + using BackendContext = XLA_FFI_ExecutionContext::BackendContext; + + // Converts CallOptions to corresponding backend context. + struct BackendVisitor { + BackendContext operator()(const std::monostate&) const { + return std::monostate{}; + } + + BackendContext operator()(const CallOptions::CpuOptions& options) const { + return XLA_FFI_ExecutionContext::CpuContext{options.intra_op_thread_pool}; + } + + BackendContext operator()(const CallOptions::GpuOptions& options) const { + return XLA_FFI_ExecutionContext::GpuContext{options.stream, + options.allocator}; + } + }; + return XLA_FFI_ExecutionContext{ options.device_ordinal, - options.stream, - options.allocator, + std::visit(BackendVisitor{}, options.backend_options), options.called_computation, internal::ScopedExecutionContext::GetCallExecutionContext(options), options.execution_state, @@ -376,12 +405,20 @@ static XLA_FFI_Error* XLA_FFI_Stream_Get(XLA_FFI_Stream_Get_Args* args) { "XLA_FFI_Stream_Get", XLA_FFI_Stream_Get_Args_STRUCT_SIZE, args->struct_size)); - if (args->ctx->stream == nullptr) { + auto* gpu = std::get_if( + &args->ctx->backend_context); + + if (ABSL_PREDICT_FALSE(gpu == nullptr)) { return new XLA_FFI_Error{ - InvalidArgument("XLA FFI stream is not available")}; + Unimplemented("XLA FFI GPU context is not available")}; } - auto handle = args->ctx->stream->platform_specific_handle(); + if (ABSL_PREDICT_FALSE(gpu->stream == nullptr)) { + return new XLA_FFI_Error{ + Unimplemented("XLA FFI GPU stream is not available")}; + } + + auto handle = gpu->stream->platform_specific_handle(); args->stream = handle.stream; return nullptr; @@ -459,6 +496,22 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Allocate( "XLA_FFI_DeviceMemory_Allocate_Args", XLA_FFI_DeviceMemory_Allocate_Args_STRUCT_SIZE, args->struct_size)); + auto* gpu = std::get_if( + &args->ctx->backend_context); + + // TODO(ezhulenev): Device memory allocation should be supported for all + // backends, not just GPU, although for CPU it doesn't make much sense, as + // plain `new` is sufficient. + if (ABSL_PREDICT_FALSE(gpu == nullptr)) { + return new XLA_FFI_Error{ + InvalidArgument("XLA FFI GPU context is not available")}; + } + + if (ABSL_PREDICT_FALSE(gpu->allocator == nullptr)) { + return new XLA_FFI_Error{ + Unimplemented("No device memory allocator available on this platform")}; + } + // TODO(ezhulenev): We happen to have the same alignment requirement for // device memory on CPU and GPU backends, but instead of hardcoding it here // we should query it for the platform XLA FFI handler is registered with. @@ -471,7 +524,7 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Allocate( } absl::StatusOr memory = - args->ctx->allocator->Allocate(args->ctx->device_ordinal, args->size); + gpu->allocator->Allocate(args->ctx->device_ordinal, args->size); if (!memory.ok()) { return new XLA_FFI_Error{std::move(memory).status()}; } @@ -486,7 +539,23 @@ static XLA_FFI_Error* XLA_FFI_DeviceMemory_Free( "XLA_FFI_DeviceMemory_Free_Args", XLA_FFI_DeviceMemory_Free_Args_STRUCT_SIZE, args->struct_size)); - absl::Status status = args->ctx->allocator->Deallocate( + auto* gpu = std::get_if( + &args->ctx->backend_context); + + // TODO(ezhulenev): Device memory allocation should be supported for all + // backends, not just GPU, although for CPU it doesn't make much sense, as + // plain `new` is sufficient. + if (ABSL_PREDICT_FALSE(gpu == nullptr)) { + return new XLA_FFI_Error{ + Unimplemented("XLA FFI GPU context is not available")}; + } + + if (ABSL_PREDICT_FALSE(gpu->allocator == nullptr)) { + return new XLA_FFI_Error{ + Unimplemented("No device memory allocator available on this platform")}; + } + + absl::Status status = gpu->allocator->Deallocate( args->ctx->device_ordinal, stream_executor::DeviceMemoryBase(args->data, args->size)); if (!status.ok()) { @@ -509,7 +578,13 @@ static XLA_FFI_Error* XLA_FFI_INTERNAL_Error_Forward(void* status) { } static void* XLA_FFI_INTERNAL_Stream_Get(XLA_FFI_ExecutionContext* ctx) { - return ctx->stream; + if (auto* gpu = std::get_if( + &ctx->backend_context)) { + return gpu->stream; + } + + return new XLA_FFI_Error{ + InvalidArgument("XLA FFI GPU context is not available")}; } static int32_t XLA_FFI_INTERNAL_DeviceOrdinal_Get( @@ -519,7 +594,13 @@ static int32_t XLA_FFI_INTERNAL_DeviceOrdinal_Get( static void* XLA_FFI_INTERNAL_DeviceMemoryAllocator_Get( XLA_FFI_ExecutionContext* ctx) { - return ctx->allocator; + if (auto* gpu = std::get_if( + &ctx->backend_context)) { + return gpu->allocator; + } + + return new XLA_FFI_Error{ + InvalidArgument("XLA FFI GPU context is not available")}; } static void* XLA_FFI_INTERNAL_CalledComputation_Get( @@ -537,6 +618,16 @@ static void* XLA_FFI_INTERNAL_ExecutionState_Get( return const_cast(ctx->execution_state); } +void* XLA_FFI_INTERNAL_IntraOpThreadPool_Get(XLA_FFI_ExecutionContext* ctx) { + if (auto* cpu = std::get_if( + &ctx->backend_context)) { + return const_cast(cpu->intra_op_thread_pool); + } + + return new XLA_FFI_Error{ + InvalidArgument("XLA FFI CPU context is not available")}; +} + //===----------------------------------------------------------------------===// // XLA FFI Api access //===----------------------------------------------------------------------===// @@ -551,6 +642,7 @@ static XLA_FFI_InternalApi internal_api = { XLA_FFI_INTERNAL_CalledComputation_Get, XLA_FFI_INTERNAL_ExecutionContext_Get, XLA_FFI_INTERNAL_ExecutionState_Get, + XLA_FFI_INTERNAL_IntraOpThreadPool_Get, }; static XLA_FFI_Api api = { diff --git a/third_party/xla/xla/ffi/ffi_api.h b/third_party/xla/xla/ffi/ffi_api.h index 7a6e5aa3df9506..f583fa396c4a42 100644 --- a/third_party/xla/xla/ffi/ffi_api.h +++ b/third_party/xla/xla/ffi/ffi_api.h @@ -19,10 +19,12 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "xla/executable_run_options.h" #include "xla/ffi/api/api.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep @@ -47,11 +49,23 @@ namespace xla::ffi { // Calling XLA FFI handlers //===----------------------------------------------------------------------===// +// Options for calling XLA FFI handlers. Backend specific options must be +// constructed from `xla::ExecuteRunOptions`, to give FFI handlers access to +// XLA runtime internals. struct CallOptions { - int32_t device_ordinal = -1; + struct CpuOptions { + const Eigen::ThreadPoolDevice* intra_op_thread_pool = nullptr; + }; + + struct GpuOptions { + se::Stream* stream = nullptr; + se::DeviceMemoryAllocator* allocator = nullptr; + }; - se::Stream* stream = nullptr; - se::DeviceMemoryAllocator* allocator = nullptr; + using BackendOptions = std::variant; + + int32_t device_ordinal = -1; + BackendOptions backend_options = {}; const HloComputation* called_computation = nullptr; const ExecutionContext* execution_context = nullptr; diff --git a/third_party/xla/xla/ffi/ffi_test.cc b/third_party/xla/xla/ffi/ffi_test.cc index 63f5dbf30e20d2..9fb4ff8e249600 100644 --- a/third_party/xla/xla/ffi/ffi_test.cc +++ b/third_party/xla/xla/ffi/ffi_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "xla/ffi/ffi_api.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/stream.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -317,21 +317,21 @@ TEST(FfiTest, AttrsAsDictionary) { EXPECT_TRUE(dict.contains("f32")); EXPECT_TRUE(dict.contains("str")); - auto i32 = dict.get("i32"); - auto f32 = dict.get("f32"); - auto str = dict.get("str"); + absl::StatusOr i32 = dict.get("i32"); + absl::StatusOr f32 = dict.get("f32"); + absl::StatusOr str = dict.get("str"); - EXPECT_TRUE(i32.has_value()); - EXPECT_TRUE(f32.has_value()); - EXPECT_TRUE(str.has_value()); + EXPECT_TRUE(i32.ok()); + EXPECT_TRUE(f32.ok()); + EXPECT_TRUE(str.ok()); - if (i32) EXPECT_EQ(*i32, 42); - if (f32) EXPECT_EQ(*f32, 42.0f); - if (str) EXPECT_EQ(*str, "foo"); + if (i32.ok()) EXPECT_EQ(*i32, 42); + if (f32.ok()) EXPECT_EQ(*f32, 42.0f); + if (str.ok()) EXPECT_EQ(*str, "foo"); EXPECT_FALSE(dict.contains("i64")); - EXPECT_FALSE(dict.get("i32").has_value()); - EXPECT_FALSE(dict.get("i64").has_value()); + EXPECT_FALSE(dict.get("i32").ok()); + EXPECT_FALSE(dict.get("i64").ok()); return absl::OkStatus(); }; @@ -364,14 +364,14 @@ TEST(FfiTest, DictionaryAttr) { EXPECT_TRUE(dict0.contains("i32")); EXPECT_TRUE(dict1.contains("f32")); - auto i32 = dict0.get("i32"); - auto f32 = dict1.get("f32"); + absl::StatusOr i32 = dict0.get("i32"); + absl::StatusOr f32 = dict1.get("f32"); - EXPECT_TRUE(i32.has_value()); - EXPECT_TRUE(f32.has_value()); + EXPECT_TRUE(i32.ok()); + EXPECT_TRUE(f32.ok()); - if (i32) EXPECT_EQ(*i32, 42); - if (f32) EXPECT_EQ(*f32, 42.0f); + if (i32.ok()) EXPECT_EQ(*i32, 42); + if (f32.ok()) EXPECT_EQ(*f32, 42.0f); return absl::OkStatus(); }; @@ -631,8 +631,14 @@ TEST(FfiTest, RemainingArgs) { auto fn = [&](RemainingArgs args) { EXPECT_EQ(args.size(), 1); - EXPECT_TRUE(args.get(0).has_value()); - EXPECT_FALSE(args.get(1).has_value()); + + absl::StatusOr arg0 = args.get(0); + absl::StatusOr arg1 = args.get(1); + + EXPECT_TRUE(arg0.ok()); + EXPECT_THAT(arg1.status(), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Index out of range"))); + return absl::OkStatus(); }; @@ -651,19 +657,148 @@ TEST(FfiTest, RemainingRets) { builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); auto call_frame = builder.Build(); - auto fn = [&](Result ret, RemainingResults rets) { + auto fn = [&](Result ret, RemainingRets rets) { EXPECT_EQ(rets.size(), 1); - EXPECT_TRUE(rets.get(0).has_value()); - EXPECT_FALSE(rets.get(1).has_value()); + + absl::StatusOr> ret0 = rets.get(0); + absl::StatusOr> ret1 = rets.get(1); + + EXPECT_TRUE(ret0.ok()); + EXPECT_THAT(ret1.status(), StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Index out of range"))); + return absl::OkStatus(); }; - auto handler = Ffi::Bind().Ret().RemainingResults().To(fn); + auto handler = Ffi::Bind().Ret().RemainingRets().To(fn); auto status = Call(*handler, call_frame); TF_ASSERT_OK(status); } +TEST(FfiTest, OptionalArgs) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/1, /*num_rets=*/0); + builder.AddBufferArg(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + { // Single optional argument. + auto fn = [&](std::optional arg0) { + EXPECT_TRUE(arg0.has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Two optional arguments. + auto fn = [&](std::optional arg0, + std::optional arg1) { + EXPECT_TRUE(arg0.has_value()); + EXPECT_FALSE(arg1.has_value()); + return absl::OkStatus(); + }; + + auto handler = + Ffi::Bind().OptionalArg().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Optional argument after a regular one. + auto fn = [&](AnyBuffer arg0, std::optional arg1) { + EXPECT_FALSE(arg1.has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().Arg().OptionalArg().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Remaining arguments after optional one. + auto fn = [&](std::optional arg0, RemainingArgs args) { + EXPECT_TRUE(arg0.has_value()); + EXPECT_EQ(args.size(), 0); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().OptionalArg().RemainingArgs().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } +} + +TEST(FfiTest, OptionalRets) { + std::vector storage(4, 0.0f); + se::DeviceMemoryBase memory(storage.data(), 4 * sizeof(float)); + + CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/1); + builder.AddBufferRet(memory, PrimitiveType::F32, /*dims=*/{2, 2}); + auto call_frame = builder.Build(); + + { // Single optional result. + auto fn = [&](std::optional> ret0) { + EXPECT_TRUE(ret0.has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Two optional results. + auto fn = [&](std::optional> ret0, + std::optional> ret1) { + EXPECT_TRUE(ret0.has_value()); + EXPECT_FALSE(ret1.has_value()); + return absl::OkStatus(); + }; + + auto handler = + Ffi::Bind().OptionalRet().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Optional result after a regular one. + auto fn = [&](Result ret0, + std::optional> ret1) { + EXPECT_FALSE(ret1.has_value()); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().Ret().OptionalRet().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } + + { // Remaining results after optional one. + auto fn = [&](std::optional> ret0, RemainingRets rets) { + EXPECT_TRUE(ret0.has_value()); + EXPECT_EQ(rets.size(), 0); + return absl::OkStatus(); + }; + + auto handler = Ffi::Bind().OptionalRet().RemainingRets().To(fn); + auto status = Call(*handler, call_frame); + + TF_ASSERT_OK(status); + } +} + TEST(FfiTest, RunOptionsCtx) { auto call_frame = CallFrameBuilder(/*num_args=*/0, /*num_rets=*/0).Build(); auto* expected = reinterpret_cast(0x01234567); @@ -674,7 +809,7 @@ TEST(FfiTest, RunOptionsCtx) { }; CallOptions options; - options.stream = expected; + options.backend_options = CallOptions::GpuOptions{expected}; auto handler = Ffi::Bind().Ctx().To(fn); auto status = Call(*handler, call_frame, options); diff --git a/third_party/xla/xla/hlo/evaluator/BUILD b/third_party/xla/xla/hlo/evaluator/BUILD index b857a8a15ad532..47574f0ff59f38 100644 --- a/third_party/xla/xla/hlo/evaluator/BUILD +++ b/third_party/xla/xla/hlo/evaluator/BUILD @@ -135,10 +135,12 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/hlo/ir:hlo", + "//xla/service:call_graph", "//xla/service:dynamic_dimension_inference", "//xla/service:hlo_element_type_converter", "//xla/service:hlo_module_config", "//xla/service:shape_inference", + "//xla/service:tuple_points_to_analysis", "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:test_utils", diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc index 9b51dca7721011..9fe65193b84d97 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.cc @@ -146,8 +146,9 @@ absl::StatusOr Compare(const Shape& shape, Comparison comparison, std::optional GetInstructionStaticValueAsBool( const HloInstruction* instruction) { HloEvaluator evaluator; - absl::StatusOr static_value = evaluator.Evaluate( - instruction, /*recursively_evaluate_nonconstant_operands=*/true); + absl::StatusOr static_value = + evaluator.Evaluate(instruction, /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { return static_value->GetFirstElement(); } @@ -232,10 +233,12 @@ struct DynamicOrStaticInteger { }; std::optional GetInstructionValueAsInteger( - const HloInstruction* instruction) { + const HloInstruction* instruction, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { HloEvaluator evaluator; - absl::StatusOr static_value = evaluator.Evaluate( - instruction, /*recursively_evaluate_nonconstant_operands=*/true); + absl::StatusOr static_value = + evaluator.Evaluate(instruction, precomputed_analyses, + /*recursively_evaluate_nonconstant_operands=*/true); if (static_value.ok()) { if (instruction->shape().element_type() == PrimitiveType::PRED) { return DynamicOrStaticInteger{ @@ -274,14 +277,16 @@ struct ParamIndexAndValue { }; std::optional TryParsingInstructionAsParameterAndInteger( - const HloInstruction* instruction) { + const HloInstruction* instruction, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { // Skip copies. if (instruction->opcode() == HloOpcode::kCopy) { - return TryParsingInstructionAsParameterAndInteger(instruction->operand(0)); + return TryParsingInstructionAsParameterAndInteger(instruction->operand(0), + precomputed_analyses); } if (instruction->opcode() == HloOpcode::kCopyDone) { return TryParsingInstructionAsParameterAndInteger( - instruction->operand(0)->operand(1)); + instruction->operand(0)->operand(1), precomputed_analyses); } ParamIndexAndValue result; if (Match(instruction, match::GetTupleElement().WithOperand( @@ -289,7 +294,7 @@ std::optional TryParsingInstructionAsParameterAndInteger( result.param_index = instruction->tuple_index(); } std::optional integer_value = - GetInstructionValueAsInteger(instruction); + GetInstructionValueAsInteger(instruction, precomputed_analyses); result.value = std::move(integer_value); if (!result.IsValid()) { return std::nullopt; @@ -318,11 +323,12 @@ using WhileCondComparisonOrNoOp = std::variant; std::optional ParseComparisonOperand( - const HloInstruction* operand) { + const HloInstruction* operand, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { if (operand->opcode() == HloOpcode::kCopy || operand->opcode() == HloOpcode::kCopyStart || operand->opcode() == HloOpcode::kCopyDone) { - return ParseComparisonOperand(operand->operand(0)); + return ParseComparisonOperand(operand->operand(0), precomputed_analyses); } std::optional param_index; if (Match(operand, match::GetTupleElement().WithOperand( @@ -330,7 +336,7 @@ std::optional ParseComparisonOperand( param_index = operand->tuple_index(); } std::optional operand_value = - GetInstructionValueAsInteger(operand); + GetInstructionValueAsInteger(operand, precomputed_analyses); if (!param_index.has_value() && !operand_value.has_value()) { return std::nullopt; } @@ -338,12 +344,13 @@ std::optional ParseComparisonOperand( } std::optional PatternMatchLoopCondComparison( - const HloInstruction* comparison) { + const HloInstruction* comparison, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { CHECK_EQ(comparison->opcode(), HloOpcode::kCompare); std::optional lhs = - ParseComparisonOperand(comparison->operand(0)); + ParseComparisonOperand(comparison->operand(0), precomputed_analyses); std::optional rhs = - ParseComparisonOperand(comparison->operand(1)); + ParseComparisonOperand(comparison->operand(1), precomputed_analyses); if (!lhs.has_value() || !rhs.has_value()) { return std::nullopt; } @@ -353,18 +360,21 @@ std::optional PatternMatchLoopCondComparison( // Finds the while loop condition comparison by matching the loop condition root // with known patterns. std::optional PatternMatchLoopCondRoot( - const HloInstruction* loop_cond_root) { + const HloInstruction* loop_cond_root, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { if (loop_cond_root->opcode() == HloOpcode::kCopy) { - return PatternMatchLoopCondRoot(loop_cond_root->operand(0)); + return PatternMatchLoopCondRoot(loop_cond_root->operand(0), + precomputed_analyses); } if (loop_cond_root->opcode() == HloOpcode::kCopyDone) { - return PatternMatchLoopCondRoot(loop_cond_root->operand(0)->operand(1)); + return PatternMatchLoopCondRoot(loop_cond_root->operand(0)->operand(1), + precomputed_analyses); } if (loop_cond_root->opcode() == HloOpcode::kCompare) { // Base pattern #1: gte-0 comp gte-1 // Base pattern #2: constant comp gte // Base pattern #3: gte comp constant - return PatternMatchLoopCondComparison(loop_cond_root); + return PatternMatchLoopCondComparison(loop_cond_root, precomputed_analyses); } // Base pattern #4: gte is a boolean scalar and it was return immediately. if (Match(loop_cond_root, match::GetTupleElement().WithOperand( @@ -390,7 +400,8 @@ std::optional PatternMatchLoopCondRoot( const HloInstruction* to_apply_root = to_apply->root_instruction(); if (Match(to_apply_root, match::Tuple())) { return PatternMatchLoopCondRoot( - to_apply_root->operand(loop_cond_root->tuple_index())); + to_apply_root->operand(loop_cond_root->tuple_index()), + precomputed_analyses); } } // Recursive pattern #2: @@ -400,23 +411,26 @@ std::optional PatternMatchLoopCondRoot( match::GetTupleElement().WithOperand(0, match::Tuple()))) { const HloInstruction* new_cond_root = loop_cond_root->operand(0)->operand(loop_cond_root->tuple_index()); - return PatternMatchLoopCondRoot(new_cond_root); + return PatternMatchLoopCondRoot(new_cond_root, precomputed_analyses); } return std::nullopt; } std::optional PatternMatchInductionVarUpdate( - const HloInstruction* induction_var_update, int64_t tuple_index) { + const HloInstruction* induction_var_update, int64_t tuple_index, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { if (induction_var_update->opcode() == HloOpcode::kCopy) { return PatternMatchInductionVarUpdate(induction_var_update->operand(0), - tuple_index); + tuple_index, precomputed_analyses); } if (induction_var_update->opcode() == HloOpcode::kCopyDone) { return PatternMatchInductionVarUpdate( - induction_var_update->operand(0)->operand(1), tuple_index); + induction_var_update->operand(0)->operand(1), tuple_index, + precomputed_analyses); } std::optional update_param_index_and_value = - TryParsingInstructionAsParameterAndInteger(induction_var_update); + TryParsingInstructionAsParameterAndInteger(induction_var_update, + precomputed_analyses); if (update_param_index_and_value.has_value()) { if (update_param_index_and_value->param_index.has_value()) { @@ -450,12 +464,14 @@ std::optional PatternMatchInductionVarUpdate( const HloInstruction* update_lhs = induction_var_update->operand(0); VLOG(3) << "PatternMatchInductionVarUpdate, LHS: " << update_lhs->ToString(); std::optional update_lhs_param_index_and_value = - TryParsingInstructionAsParameterAndInteger(update_lhs); + TryParsingInstructionAsParameterAndInteger(update_lhs, + precomputed_analyses); const HloInstruction* update_rhs = induction_var_update->operand(1); VLOG(3) << "PatternMatchInductionVarUpdate, RHS: " << update_rhs->ToString(); std::optional update_rhs_param_index_and_value = - TryParsingInstructionAsParameterAndInteger(update_rhs); + TryParsingInstructionAsParameterAndInteger(update_rhs, + precomputed_analyses); if (!update_lhs_param_index_and_value.has_value() || !update_lhs_param_index_and_value->value.has_value() || @@ -496,14 +512,16 @@ std::optional PatternMatchInductionVarUpdate( // using pattern matching. std::optional PatternMatchInductionVarUpdateFromLoopBodyRoot( - const HloInstruction* loop_body_root, int64_t tuple_index) { + const HloInstruction* loop_body_root, int64_t tuple_index, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { if (loop_body_root->opcode() != HloOpcode::kTuple || loop_body_root->operand_count() <= tuple_index) { return std::nullopt; } const HloInstruction* induction_var_update = loop_body_root->operand(tuple_index); - return PatternMatchInductionVarUpdate(induction_var_update, tuple_index); + return PatternMatchInductionVarUpdate(induction_var_update, tuple_index, + precomputed_analyses); } std::optional PatternMatchLoopCondVarOverride( @@ -528,7 +546,8 @@ std::optional EvaluateWhileLoopParamInitValue( } const HloInstruction* element_instruction = param_instruction->operand(tuple_index); - return GetInstructionValueAsInteger(element_instruction); + return GetInstructionValueAsInteger(element_instruction, + /*precomputed_analyses=*/{}); } } // namespace @@ -634,14 +653,16 @@ std::optional HandleStaticLoopComparison( } std::optional PatternMatchParseWhileLoop( - const HloInstruction* while_op) { + const HloInstruction* while_op, + HloEvaluator::PrecomputedAnalyses precomputed_analyses) { VLOG(3) << "PatternMatchParseWhileLoop, while_op: " << while_op->name(); const HloComputation* while_cond = while_op->while_condition(); const HloComputation* while_body = while_op->while_body(); const HloInstruction* while_operand = while_op->operand(0); // Try to parse the loop condition comparison. std::optional loop_comparison_or_noop = - PatternMatchLoopCondRoot(while_cond->root_instruction()); + PatternMatchLoopCondRoot(while_cond->root_instruction(), + precomputed_analyses); if (!loop_comparison_or_noop.has_value()) { return std::nullopt; } @@ -704,7 +725,8 @@ std::optional PatternMatchParseWhileLoop( induction_var_init = EvaluateWhileLoopParamInitValue( while_operand, *loop_comparison.lhs.param_index); induction_var_update = PatternMatchInductionVarUpdateFromLoopBodyRoot( - while_body->root_instruction(), *loop_comparison.lhs.param_index); + while_body->root_instruction(), *loop_comparison.lhs.param_index, + precomputed_analyses); lhs_is_induction_var = true; } } else { @@ -714,7 +736,8 @@ std::optional PatternMatchParseWhileLoop( induction_var_init = EvaluateWhileLoopParamInitValue( while_operand, *loop_comparison.rhs.param_index); induction_var_update = PatternMatchInductionVarUpdateFromLoopBodyRoot( - while_body->root_instruction(), *loop_comparison.rhs.param_index); + while_body->root_instruction(), *loop_comparison.rhs.param_index, + precomputed_analyses); lhs_is_induction_var = false; } } @@ -920,7 +943,7 @@ absl::StatusOr HloEvaluator::Evaluate( } absl::StatusOr HloEvaluator::Evaluate( - const HloInstruction* instruction, + const HloInstruction* instruction, PrecomputedAnalyses precomputed_analyses, bool recursively_evaluate_nonconstant_operands) { arg_literals_.clear(); evaluated_.clear(); @@ -930,7 +953,7 @@ absl::StatusOr HloEvaluator::Evaluate( absl::MakeCleanup([this] { enable_partial_evaluation_ = false; }); enable_partial_evaluation_ = recursively_evaluate_nonconstant_operands; TF_RETURN_IF_ERROR( - EvaluateInternal(instruction, /*shape_index=*/{}, + EvaluateInternal(instruction, precomputed_analyses, /*shape_index=*/{}, recursively_evaluate_nonconstant_operands)); const Literal& result = GetEvaluatedLiteralFor(instruction); if (!result.IsKnown()) { @@ -943,8 +966,8 @@ bool HloEvaluator::TryEvaluate(const HloInstruction* instruction, Literal* result, bool recursively_evaluate_nonconstant_operands) { CHECK(result != nullptr); - auto result_or = - Evaluate(instruction, recursively_evaluate_nonconstant_operands); + auto result_or = Evaluate(instruction, /*precomputed_analyses=*/{}, + recursively_evaluate_nonconstant_operands); if (!result_or.ok()) { VLOG(1) << "TryEvaluate failed:" << result_or.status(); return false; @@ -1066,11 +1089,12 @@ absl::StatusOr HloEvaluator::EvaluateDotOp( } absl::Status HloEvaluator::EvaluateParameterFromCallerArgument( - const HloInstruction* parameter, const ShapeIndex& shape_index) { + const HloInstruction* parameter, const ShapeIndex& shape_index, + PrecomputedAnalyses analyses) { CHECK(!evaluated_.contains(parameter)); const HloComputation* parent_computation = parameter->parent(); std::vector computation_callers = - call_graph_cache_->GetComputationCallers(parent_computation); + analyses.call_graph->GetComputationCallers(parent_computation); // If the parent computation has multiple callers, we cannot determine from // which caller the arguments are passed. if (computation_callers.size() != 1) { @@ -1093,11 +1117,11 @@ absl::Status HloEvaluator::EvaluateParameterFromCallerArgument( HloComputation* while_body = computation_caller->while_body(); TF_ASSIGN_OR_RETURN( const LogicalBuffer* logical_buffer, - tuple_points_to_analysis_cache_->GetBufferDefinedAt( + analyses.tuple_points_to->GetBufferDefinedAt( while_body->parameter_instruction(parameter->parameter_number()), shape_index)); const TuplePointsToAnalysis::BufferAliasVector& buffer_aliases = - tuple_points_to_analysis_cache_->GetBufferAliases(*logical_buffer); + analyses.tuple_points_to->GetBufferAliases(*logical_buffer); bool unchanged_in_return = false; for (const BufferAlias& buffer_alias : buffer_aliases) { if (buffer_alias.instruction() == while_body->root_instruction() && @@ -1109,7 +1133,8 @@ absl::Status HloEvaluator::EvaluateParameterFromCallerArgument( return MakeEvalErrorDueToParamOrInfeed(*parameter); } } - TF_RETURN_IF_ERROR(EvaluateInternal(caller_operand, shape_index, true)); + TF_RETURN_IF_ERROR( + EvaluateInternal(caller_operand, analyses, shape_index, true)); const Literal& caller_operand_literal = GetEvaluatedLiteralFor(caller_operand); evaluated_[parameter] = @@ -1154,7 +1179,8 @@ DimensionVector HloEvaluator::MakeDimMultipliers(const Shape& shape) { } absl::Status HloEvaluator::EvaluateInternal( - const HloInstruction* instruction, const ShapeIndex& shape_index, + const HloInstruction* instruction, PrecomputedAnalyses precomputed_analyses, + const ShapeIndex& shape_index, bool recursively_evaluate_nonconstant_operands) { // Don't need to evaluate this instruction again if it has already been // evaluated. @@ -1170,34 +1196,44 @@ absl::Status HloEvaluator::EvaluateInternal( if (instruction->opcode() == HloOpcode::kGetTupleElement) { ShapeIndex new_shape_index = shape_index; new_shape_index.push_front(instruction->tuple_index()); - TF_RETURN_IF_ERROR( - EvaluateInternal(instruction->operand(0), new_shape_index, - /*recursively_evaluate_nonconstant_operands=*/true)); + TF_RETURN_IF_ERROR(EvaluateInternal( + instruction->operand(0), precomputed_analyses, new_shape_index, + /*recursively_evaluate_nonconstant_operands=*/true)); } else if (instruction->opcode() == HloOpcode::kTuple && !shape_index.empty()) { ShapeIndex new_shape_index = shape_index; int64_t tuple_index = new_shape_index.front(); new_shape_index.pop_front(); TF_RETURN_IF_ERROR( - EvaluateInternal(instruction->operand(tuple_index), new_shape_index, + EvaluateInternal(instruction->operand(tuple_index), + precomputed_analyses, new_shape_index, /*recursively_evaluate_nonconstant_operands=*/true)); } else if (instruction->opcode() == HloOpcode::kParameter) { - if (!call_graph_cache_) { - HloModule* module = instruction->GetModule(); - call_graph_cache_ = CallGraph::Build(module); - } - if (!tuple_points_to_analysis_cache_) { - HloModule* module = instruction->GetModule(); - absl::StatusOr> - tuple_points_to_analysis = TuplePointsToAnalysis::Run(module); - if (tuple_points_to_analysis.ok()) { - tuple_points_to_analysis_cache_ = - *std::move(tuple_points_to_analysis); - } - } - if (call_graph_cache_ && tuple_points_to_analysis_cache_) { - absl::Status argument_eval_status = - EvaluateParameterFromCallerArgument(instruction, shape_index); + CallGraph* call_graph = + (precomputed_analyses.call_graph != nullptr) + ? precomputed_analyses.call_graph + : std::invoke([this, instruction]() -> CallGraph* { + call_graph_cache_ = + CallGraph::Build(instruction->GetModule()); + return call_graph_cache_.get(); + }); + TuplePointsToAnalysis* tuple_points_to_analysis = + (precomputed_analyses.tuple_points_to != nullptr) + ? precomputed_analyses.tuple_points_to + : std::invoke([this, instruction]() -> TuplePointsToAnalysis* { + absl::StatusOr> + tuple_points_to_analysis = + TuplePointsToAnalysis::Run(instruction->GetModule()); + if (!tuple_points_to_analysis.ok()) { + return nullptr; + } + tuple_points_to_analysis_cache_ = + *std::move(tuple_points_to_analysis); + return tuple_points_to_analysis_cache_.get(); + }); + if (call_graph && tuple_points_to_analysis) { + absl::Status argument_eval_status = EvaluateParameterFromCallerArgument( + instruction, shape_index, {tuple_points_to_analysis, call_graph}); if (!argument_eval_status.ok()) { VLOG(4) << "Failed to evaluate parameter " << instruction->name() << " from caller. Reason: " << argument_eval_status.message(); @@ -1209,7 +1245,7 @@ absl::Status HloEvaluator::EvaluateInternal( } else { for (HloInstruction* operand : instruction->operands()) { TF_RETURN_IF_ERROR(EvaluateInternal( - operand, /*shape_index=*/{}, + operand, precomputed_analyses, /*shape_index=*/{}, /*recursively_evaluate_nonconstant_operands=*/true)); // Except for the above and following cases, we do not support handling // unknown operands for other HLOs. So mark the result as unknown. @@ -1821,6 +1857,7 @@ class FftTransform { auto generate_twiddles = [](int64_t length, bool inverse) { std::vector twiddles; // Need only half the twiddles. + twiddles.reserve(length / 2); for (int64_t k = 0; k < length / 2; k++) { twiddles.push_back(Twiddle(k, length, inverse)); } @@ -3445,7 +3482,7 @@ absl::StatusOr CreateScalarLiteral(int64_t value, absl::StatusOr TryParseAndEvaluateWhileInductionVar( const HloInstruction* while_hlo) { std::optional parsed_while_loop = - PatternMatchParseWhileLoop(while_hlo); + PatternMatchParseWhileLoop(while_hlo, /*precomputed_analyses=*/{}); if (!parsed_while_loop.has_value() || parsed_while_loop->is_dynamic()) { return FailedPrecondition( "Cannot evaluate a while loop's induction variable since the loop " @@ -3486,7 +3523,8 @@ absl::Status HloEvaluator::HandleWhile(const HloInstruction* while_hlo) { auto lcv = GetEvaluatedLiteralFor(while_hlo->operand(0)).Clone(); if (!lcv.IsKnown()) { std::optional parsed_while_loop = - PatternMatchParseWhileLoop(while_hlo); + PatternMatchParseWhileLoop(while_hlo, + /*precomputed_analyses=*/{}); evaluated_[while_hlo] = Literal::CreateFromShapeWithUnknownLeafArrays(while_hlo->shape()); if (!parsed_while_loop.has_value() || parsed_while_loop->is_dynamic() || diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h index 2f91c39c857c9c..5f004073b7a3a4 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator.h @@ -55,43 +55,18 @@ limitations under the License. namespace xla { -// Represents a parsed static while loop. We normalize the loop representation -// so that it starts from the induction_var_init_value and increments by -// step_size until it exceeds or goes below loop_bound. -struct ParsedStaticWhileLoop { - // The number of iterations to be executed. - int64_t trip_count = -1; - // The tuple index of the induction variable in the while argument tuple. - int64_t induction_var_index = -1; - // The induction variable's initial value. - int64_t induction_var_init_value = -1; - // The induction variable is incremented by this number (could be negative) - // in each iteration. - int64_t step_size = -1; - int64_t loop_bound = -1; -}; - -// Indicates whether a parsed while loop is static or dynamic. If the loop is -// static, it contains a value for StaticLoopInfo; otherwise the loop is -// dynamic. We consider a loop dynamic if its induction variable's initial -// value or the loop bound's value depends on the while's parent computation's -// parameter. -struct ParsedWhileLoop { - std::optional static_while_loop; - bool is_dynamic() const { return !static_while_loop.has_value(); } -}; -constexpr ParsedWhileLoop kParsedDynamicWhileLoop = ParsedWhileLoop(); - -// Tries to parse a while loop using a set of predefined patterns. -// Returns the parsing result. -std::optional PatternMatchParseWhileLoop( - const HloInstruction* while_op); - // Responsible for evaluating HLO and obtain literal as the evaluation results. // // This class is not thread-safe. class HloEvaluator : public ConstDfsHloVisitorWithDefault { public: + // Precomputed analyses that can be passed to Evaluate functions to avoid + // recomputation during evaluation. + struct PrecomputedAnalyses { + TuplePointsToAnalysis* tuple_points_to; + CallGraph* call_graph; + }; + // Only evaluate up to max_loop_iterations per while-loop execution if // specified. explicit HloEvaluator(int64_t max_loop_iterations = -1); @@ -167,8 +142,12 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // within its parent computation until it encounters something that cannot be // evaluated, such as an Infeed or a Parameter instruction. // It makes best effort to partially evaluate a dependency if possible. + // The caller may pass in non-null `precomputed_analyses` to avoid + // recomputation during evaluation; the caller must ensure that any + // precomputed analyses were performed on the module containing `instruction`. absl::StatusOr Evaluate( const HloInstruction* instruction, + PrecomputedAnalyses precomputed_analyses = {}, bool recursively_evaluate_nonconstant_operands = false); // Same as Evaluate, except returning false on error and accepts an output @@ -270,13 +249,20 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { // marked as undetermined unless it has been previously evaluated using // EvaluateInternal. Such partial evaluation reduces the computation and // memory overhead in cases where we need only one tuple element by avoiding - // the evaluation of a full tuple. + // the evaluation of a full tuple. Any non-null `precomputed_analyses` will be + // used instead of recomputing. absl::Status EvaluateInternal( - const HloInstruction* instruction, const ShapeIndex& shape_index = {}, + const HloInstruction* instruction, + PrecomputedAnalyses precomputed_analyses, + const ShapeIndex& shape_index = {}, bool recursively_evaluate_nonconstant_operands = false); + // Evaluates the result of a `parameter` instruction by traversing the call + // graph as given in `analyses`. `shape_index` has the same effect as in + // EvaluateInternal above. absl::Status EvaluateParameterFromCallerArgument( - const HloInstruction* parameter, const ShapeIndex& shape_index); + const HloInstruction* parameter, const ShapeIndex& shape_index, + PrecomputedAnalyses analyses); // Helper method to extract a list of int64_t from evaluated instruction for // start_indices for DynamicSlice and DynamicUpdateSlice. @@ -518,6 +504,41 @@ class HloEvaluator : public ConstDfsHloVisitorWithDefault { std::unique_ptr> MatmulArray2D(const Array2D& lhs, const Array2D& rhs); +// Represents a parsed static while loop. We normalize the loop representation +// so that it starts from the induction_var_init_value and increments by +// step_size until it exceeds or goes below loop_bound. +struct ParsedStaticWhileLoop { + // The number of iterations to be executed. + int64_t trip_count = -1; + // The tuple index of the induction variable in the while argument tuple. + int64_t induction_var_index = -1; + // The induction variable's initial value. + int64_t induction_var_init_value = -1; + // The induction variable is incremented by this number (could be negative) + // in each iteration. + int64_t step_size = -1; + int64_t loop_bound = -1; +}; + +// Indicates whether a parsed while loop is static or dynamic. If the loop is +// static, it contains a value for StaticLoopInfo; otherwise the loop is +// dynamic. We consider a loop dynamic if its induction variable's initial +// value or the loop bound's value depends on the while's parent computation's +// parameter. +struct ParsedWhileLoop { + std::optional static_while_loop; + bool is_dynamic() const { return !static_while_loop.has_value(); } +}; +constexpr ParsedWhileLoop kParsedDynamicWhileLoop = ParsedWhileLoop(); + +// Tries to parse a while loop using a set of predefined patterns. +// Returns the parsing result. Any non-null `precompute_analyses` will be used +// instead of recomputing, and it is the caller's responsibility to ensure that +// the analyses are valid for the module that contains `while_op`. +std::optional PatternMatchParseWhileLoop( + const HloInstruction* while_op, + HloEvaluator::PrecomputedAnalyses precomputed_analyses = {}); + // Functionality exposed for testing. Do not rely on anything in this namespace // outside this file. namespace internal { diff --git a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc index 72dc6f84c4ade6..901c99fe1b66d3 100644 --- a/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc +++ b/third_party/xla/xla/hlo/evaluator/hlo_evaluator_test.cc @@ -50,10 +50,12 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" +#include "xla/service/call_graph.h" #include "xla/service/dynamic_dimension_inference.h" #include "xla/service/hlo_element_type_converter.h" #include "xla/service/hlo_module_config.h" #include "xla/service/shape_inference.h" +#include "xla/service/tuple_points_to_analysis.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" @@ -167,14 +169,15 @@ class HloEvaluatorTest : public HloTestBase { TF_ASSERT_OK_AND_ASSIGN( Literal result, evaluator_.Evaluate( - instruction, + instruction, /*precomputed_analyses=*/{}, /*recursively_evaluate_nonconstant_operands=*/true)); EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); } void TestRecursiveEvaluationFailure(HloInstruction* instruction) { - absl::StatusOr result = evaluator_.Evaluate( - instruction, /*recursively_evaluate_nonconstant_operands=*/true); + absl::StatusOr result = + evaluator_.Evaluate(instruction, /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true); EXPECT_TRUE(!result.ok()); } @@ -5035,6 +5038,79 @@ TEST_F(HloEvaluatorTest, GetTupleElementInterleavedWithTupleSucceeds) { TestRecursivelyEvaluateInstruction(gte2, expected); } +// Tests that we can evaluate a parameter instruction through the call graph. +TEST_F(HloEvaluatorTest, ParameterThroughCallSucceeds) { + constexpr absl::string_view kHloModule = R"( + HloModule parameter_through_call + + %identity { + ROOT %param = s32[] parameter(0) + } + + ENTRY parameter_through_call { + %constant = s32[] constant(42) + ROOT %call = s32[] call(s32[] %constant), to_apply=%identity + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + const HloInstruction* parameter_instruction = nullptr; + for (const auto* computation : hlo_module->computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_instruction = instruction; + } + } + } + ASSERT_NE(parameter_instruction, nullptr); + + Literal expected = LiteralUtil::CreateR0(42); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator_.Evaluate(parameter_instruction, /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +// As above, but with analyses precomputed. +TEST_F(HloEvaluatorTest, ParameterThroughCallSucceedsWithPrecomputation) { + constexpr absl::string_view kHloModule = R"( + HloModule parameter_through_call + + %identity { + ROOT %param = s32[] parameter(0) + } + + ENTRY parameter_through_call { + %constant = s32[] constant(42) + ROOT %call = s32[] call(s32[] %constant), to_apply=%identity + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + const HloInstruction* parameter_instruction = nullptr; + for (const auto* computation : hlo_module->computations()) { + for (const auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_instruction = instruction; + } + } + } + ASSERT_NE(parameter_instruction, nullptr); + + Literal expected = LiteralUtil::CreateR0(42); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr tuple_points_to, + TuplePointsToAnalysis::Run(hlo_module.get())); + std::unique_ptr call_graph = CallGraph::Build(hlo_module.get()); + TF_ASSERT_OK_AND_ASSIGN( + Literal result, + evaluator_.Evaluate(parameter_instruction, + {tuple_points_to.get(), call_graph.get()}, + /*recursively_evaluate_nonconstant_operands=*/true)); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + class PatternMatchParseWhileLoopTest : public HloTestBase {}; TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedInsideOfCond) { @@ -5084,6 +5160,59 @@ TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedInsideOfCond) { EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 5); } +TEST_F(PatternMatchParseWhileLoopTest, + LoopBoundDefinedInsideOfCondWithPrecomputation) { + constexpr absl::string_view kHloModule = R"( + HloModule accumulated_all_reduce + + %while_condition { + %param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %loop_bound = s32[] constant(5) + ROOT result = pred[] compare(%gte.0, %loop_bound), direction=LT + } + + %while_body { + %param = (s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0) + %gte.0 = s32[] get-tuple-element(%param), index=0 + %gte.1 = f32[1024, 1024] get-tuple-element(%param), index=1 + %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2 + %accumulation = f32[1024, 1024] add(f32[1024, 1024] %gte.1, f32[1024, 1024] %gte.2) + %constant = s32[] constant(1) + %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant) + ROOT %loop_result = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %accumulation) + } + + ENTRY accumulated_all_reduce { + %param.1 = f32[1024, 1024] parameter(0) + %constant.0 = s32[] constant(0) + %accumulation_buffer_init = f32[] constant(0) + %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={} + %while_init = (s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer) + %while = (s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body + ROOT %result = f32[1024, 1024] get-tuple-element((s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=2 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(kHloModule)); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr tuple_points_to, + TuplePointsToAnalysis::Run(hlo_module.get())); + std::unique_ptr call_graph = CallGraph::Build(hlo_module.get()); + + HloInstruction* while_op = + hlo_module->entry_computation()->root_instruction()->mutable_operand(0); + std::optional parsed_while_loop = PatternMatchParseWhileLoop( + while_op, {tuple_points_to.get(), call_graph.get()}); + ASSERT_TRUE(parsed_while_loop.has_value()); + EXPECT_FALSE(parsed_while_loop->is_dynamic()); + EXPECT_EQ(parsed_while_loop->static_while_loop->trip_count, 5); + EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_index, 0); + EXPECT_EQ(parsed_while_loop->static_while_loop->induction_var_init_value, 0); + EXPECT_EQ(parsed_while_loop->static_while_loop->step_size, 1); + EXPECT_EQ(parsed_while_loop->static_while_loop->loop_bound, 5); +} + TEST_F(PatternMatchParseWhileLoopTest, LoopBoundDefinedOutsideOfCond) { constexpr absl::string_view kHloModule = R"( HloModule accumulated_all_reduce diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD index 6730e0317aad7c..09b40eaa7127cd 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/BUILD @@ -33,6 +33,7 @@ cc_library( compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", + ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_solver", ":auto_sharding_strategy", @@ -110,6 +111,7 @@ cc_library( hdrs = ["auto_sharding_memory.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_proto_cc", ":auto_sharding_strategy", "//xla:status_macros", @@ -148,6 +150,7 @@ cc_library( ], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_proto_cc", "//xla:shape_util", "//xla/hlo/ir:hlo", @@ -170,6 +173,7 @@ cc_library( hdrs = ["auto_sharding_cost_graph.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_strategy", ":matrix", "//xla:shape_util", @@ -188,7 +192,9 @@ cc_library( hdrs = ["auto_sharding_option.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_util", + "//xla:array", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -219,6 +225,7 @@ cc_library( compatible_with = get_compatible_with_libtpu_portable(), deps = [ ":auto_sharding_cost_graph", + ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_wrapper", @@ -247,13 +254,16 @@ cc_library( hdrs = ["cluster_environment.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_util", ":profiling_result", + "//xla:array", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service/spmd:spmd_partitioner", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], ) @@ -271,6 +281,7 @@ cc_library( hdrs = ["auto_sharding_util.h"], compatible_with = get_compatible_with_libtpu_portable(), deps = [ + ":auto_sharding_device_mesh", ":auto_sharding_strategy", "//xla:array", "//xla:shape_tree", @@ -326,6 +337,20 @@ tf_proto_library( visibility = ["//visibility:public"], ) +cc_library( + name = "auto_sharding_device_mesh", + srcs = ["auto_sharding_device_mesh.cc"], + hdrs = [ + "auto_sharding_device_mesh.h", + ], + compatible_with = get_compatible_with_libtpu_portable(), + deps = [ + "//xla:array", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/types:span", + ], +) + build_test( name = "auto_sharding_runner_build_test", targets = [ @@ -343,6 +368,8 @@ xla_cc_test( ], deps = [ ":auto_sharding", + ":auto_sharding_cost_graph", + ":auto_sharding_device_mesh", ":auto_sharding_option", ":auto_sharding_strategy", ":auto_sharding_util", @@ -356,6 +383,8 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -363,7 +392,6 @@ xla_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc index b0f17e9ffbddbf..042e4547a28e09 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc @@ -50,6 +50,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_memory.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" @@ -121,7 +122,7 @@ std::vector CommunicationReshardingCostVector( double ComputeMemoryReshardingCost(const Shape& shape, const HloSharding& src_sharding, const HloSharding& dst_sharding, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { int64_t src_n_dim = NumTileDimensions(src_sharding); int64_t dst_n_dim = NumTileDimensions(dst_sharding); @@ -889,7 +890,7 @@ double ComputeSortCommunicationCost(const int64_t sort_dim, // Enumerate all 1d partition strategies. void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -961,7 +962,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, } void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -969,7 +970,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, absl::Span tensor_dims); void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -1012,7 +1013,7 @@ void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, } void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -1075,7 +1076,7 @@ void BuildStrategyAndCostForOp(const HloInstruction* ins, const Shape& shape, } void EnumerateAll1DPartitionReshape( - const HloInstruction* ins, const Array& device_mesh, + const HloInstruction* ins, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, bool only_allow_divisible, const std::string& suffix) { @@ -1129,14 +1130,14 @@ void EnumerateAll1DPartitionReshape( } void BuildStrategyAndCostForReshape( - const HloInstruction* ins, const Array& device_mesh, + const HloInstruction* ins, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, absl::Span tensor_dims); // Enumerate all partitions for reshape. Batch dim is always partitioned. void EnumeratePartitionReshape(const HloInstruction* ins, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, const InstructionBatchDimMap& batch_dim_map, @@ -1181,7 +1182,7 @@ void EnumeratePartitionReshape(const HloInstruction* ins, } void BuildStrategyAndCostForReshape( - const HloInstruction* ins, const Array& device_mesh, + const HloInstruction* ins, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, absl::Span tensor_dims) { @@ -1876,7 +1877,7 @@ std::unique_ptr CreateReshapeStrategies( const InstructionBatchDimMap& batch_dim_map, const AutoShardingOption& option, StrategyGroups& strategy_groups, const CallGraph& call_graph) { - const Array& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; int mesh_nn_dims = VectorGreaterThanOneElementCount(device_mesh.dimensions()); std::unique_ptr strategy_group = CreateLeafStrategyGroup( @@ -1989,6 +1990,7 @@ AutoShardingSolverResult CallSolver( request.mutable_max_cost()->set_coeff(*max_cost); } for (const auto& [edge, edge_cost] : cost_graph.edge_costs_) { + const auto normalized_edge_cost = Normalize(edge_cost); AutoShardingSolverRequest_Pair raw_edge; raw_edge.set_first(edge.first); raw_edge.set_second(edge.second); @@ -1997,8 +1999,8 @@ AutoShardingSolverResult CallSolver( AutoShardingSolverRequest_Costs mij; for (NodeStrategyIdx i = 0; i < edge_cost.n_; i++) { for (NodeStrategyIdx j = 0; j < edge_cost.m_; j++) { - rij.add_costs(edge_cost(i, j).communication_cost); - mij.add_costs(edge_cost(i, j).memory_cost); + rij.add_costs(normalized_edge_cost(i, j).communication_cost); + mij.add_costs(normalized_edge_cost(i, j).memory_cost); } } request.mutable_resharding_costs()->Add(std::move(rij)); @@ -2335,7 +2337,7 @@ absl::Status InsertReshardReshapes( absl::flat_hash_map>& preserve_shardings) { const std::vector& instructions = sequence.instructions(); - const Array& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; // Post process: fix some corner cases. ReshardingCache resharding_cache_entity; ReshardingCache* resharding_cache = &resharding_cache_entity; @@ -3016,7 +3018,6 @@ void FindReplicateSet( for (size_t i = 0; i < cur->operand_count(); ++i) { HloInstruction* operand = cur->mutable_operand(i); - operand = PassThroughCustomCallMarkerOperand(operand, cur); if (!visited.contains(operand) && !IsAlwaysReplicated(operand) && GetShardingStrategy(operand, strategy_map, cost_graph, s_val) @@ -3040,9 +3041,6 @@ absl::Status GenerateReduceScatter( // Propagation ends at output. const HloInstruction* output = instructions.back(); - if (IsCustomCallMarker(output)) { - output = output->operand(0); - } // A debug option: whether to do all-gather after backward pass. // This controls the location of all-gather. @@ -3118,8 +3116,7 @@ absl::Status GenerateReduceScatter( while (true) { path.push_back(root); if (root->opcode() == HloOpcode::kGetTupleElement) { - root = PassThroughCustomCallMarkerOperand(root->mutable_operand(0), - root); + root = root->mutable_operand(0); } else { break; } @@ -3215,14 +3212,6 @@ absl::Status GenerateReduceScatter( insert_all_gather.push_back(alias_map.at(to_split)); } else { insert_all_gather.push_back(to_split); - - if (to_split->opcode() == HloOpcode::kGetTupleElement && - IsCustomCallMarker(to_split->operand(0)) && - to_split->users().size() == 1 && - to_split->users().front() == output) { - insert_all_gather.push_back(PassThroughCustomCallMarkerOperand( - to_split->mutable_operand(0), to_split)); - } } } } else { @@ -3304,8 +3293,8 @@ absl::Status GenerateReduceScatter( void AnnotateShardingWithSimpleHeuristic( HloModule* module, const std::string& heuristic, const AliasMap& alias_map, const ClusterEnvironment& cluster_env) { - const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh_1d = cluster_env.device_mesh_1d_; int64_t num_devices = device_mesh.num_elements(); // Count the non-one mesh dimension. @@ -3325,6 +3314,7 @@ void AnnotateShardingWithSimpleHeuristic( if (heuristic == "shard-largest") { std::vector lengths; + lengths.reserve(inst->shape().rank()); for (int64_t i = 0; i < inst->shape().rank(); ++i) { lengths.push_back(inst->shape().dimensions(i)); } @@ -3425,7 +3415,7 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, const AutoShardingOption& option) { int mesh_dim = option.force_batch_dim_to_mesh_dim; int batch_dim = batch_map.at(GetBatchDimMapKey(ins)); - const Array& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; if (shape.dimensions(batch_dim) % device_mesh.dim(mesh_dim) != 0) { return absl::InvalidArgumentError( @@ -3463,8 +3453,8 @@ absl::Status FilterStrategy(const HloInstruction* ins, const Shape& shape, HloSharding GetReduceScatterOutput(const HloInstruction* ins, const ShardingStrategy& strategy, const ClusterEnvironment& cluster_env) { - const Array& device_mesh = cluster_env.device_mesh_; - const Array& device_mesh_1d = cluster_env.device_mesh_1d_; + const DeviceMesh& device_mesh = cluster_env.device_mesh_; + const DeviceMesh& device_mesh_1d = cluster_env.device_mesh_1d_; if (ins->opcode() == HloOpcode::kDot) { const DotDimensionNumbers& dot_dnums = ins->dot_dimension_numbers(); @@ -3975,14 +3965,16 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // batch_dim_map = spmd::BuildInstructionBatchDimMap(sequence); // ----- Read parameters of device mesh ----- - Array original_device_mesh(option_.device_mesh_shape); + spmd::DeviceMesh original_device_mesh(option_.device_mesh_shape); original_device_mesh.SetValues(option_.device_mesh_ids); const int64_t original_memory_budget = option_.memory_budget_per_device; std::vector> partial_mesh_shapes; if (option_.solve_nd_sharding_iteratively) { // Generate partial mesh shapes to optimize iteratively. - partial_mesh_shapes = spmd::DecomposeMeshShapes(option_.device_mesh_shape); + partial_mesh_shapes = spmd::DecomposeMeshShapes(option_.device_mesh_shape, + option_.device_mesh_alpha, + option_.device_mesh_beta); } else { partial_mesh_shapes = {option_.device_mesh_shape}; } @@ -4000,7 +3992,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( std::vector mesh_shape = partial_mesh_shapes[mesh_idx]; LOG(INFO) << "Processing partial mesh shape: " << spmd::ToString(mesh_shape); - Array device_mesh(mesh_shape); + spmd::DeviceMesh device_mesh(mesh_shape); int64_t total_devices = 1; for (int64_t i : mesh_shape) { @@ -4024,10 +4016,7 @@ absl::StatusOr AutoShardingImplementation::RunAutoSharding( // use the actual device order only for the final full mesh. device_mesh.SetValues(option_.device_mesh_ids); } else { - std::vector device_mesh_ids = - std::vector(total_devices); - std::iota(device_mesh_ids.begin(), device_mesh_ids.end(), 0); - device_mesh.SetValues(device_mesh_ids); + device_mesh.FillIota(0); } // TODO (zhuohan): Include the prof result as an option. @@ -4640,24 +4629,4 @@ absl::StatusOr AutoSharding::Run( return module_is_changed; } -absl::StatusOr DummyAutoSharding::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - // ----- Set Dummy Replicated Sharding ----- - HloComputation* entry = module->entry_computation(); - - for (HloInstruction* inst : entry->instructions()) { - const Shape& out_shape = inst->shape(); - if (out_shape.IsTuple()) { - ShapeTree tuple_sharding(out_shape, - HloSharding::Replicate()); - inst->set_sharding(HloSharding::Tuple(tuple_sharding)); - } else { - inst->set_sharding(HloSharding::Replicate()); - } - } - - return true; -} - } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h index 4695efc60d0dea..bdc137a0e462c6 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.h @@ -31,8 +31,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" -#include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_solver.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" @@ -49,18 +49,6 @@ limitations under the License. namespace xla { -class DummyAutoSharding : public HloModulePass { - public: - DummyAutoSharding() = default; - ~DummyAutoSharding() override = default; - absl::string_view name() const override { return "dummy_auto_sharding"; } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; -}; - enum class AutoShardingResult { kModuleUnchanged, kModuleChangedShardingPerformed, @@ -140,7 +128,7 @@ namespace spmd { // Their comments can be found in their definitions in *.cc files. HloSharding Tile(const Shape& shape, absl::Span tensor_dims, absl::Span mesh_dims, - const Array& device_mesh); + const DeviceMesh& device_mesh); std::vector CommunicationReshardingCostVector( const StrategyGroup* strategy_group, const Shape& shape, @@ -319,7 +307,7 @@ std::unique_ptr CreateTupleStrategyGroup(size_t instruction_id); // Enumerate all 1d partition strategies. void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, @@ -329,7 +317,7 @@ void EnumerateAll1DPartition(const HloInstruction* ins, const Shape& shape, // Enumerate all partitions recursively. void EnumerateAllPartition(const HloInstruction* ins, const Shape& shape, - const Array& device_mesh, + const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env, const StrategyMap& strategy_map, std::unique_ptr& strategy_group, diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc index 85127883e21937..9d28df32b04f5c 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -34,6 +35,24 @@ limitations under the License. namespace xla { namespace spmd { +EdgeReshardingCostMatrix Normalize(const EdgeReshardingCostMatrix& edge_cost) { + double min_communication_cost = std::numeric_limits::max(); + for (int i = 0; i < edge_cost.n_; ++i) { + for (int j = 0; j < edge_cost.m_; ++j) { + min_communication_cost = + std::min(min_communication_cost, edge_cost(i, j).communication_cost); + } + } + if (min_communication_cost >= 0) return edge_cost; + EdgeReshardingCostMatrix normalized_edge_cost = edge_cost; + for (int i = 0; i < edge_cost.n_; ++i) { + for (int j = 0; j < edge_cost.m_; ++j) { + normalized_edge_cost(i, j).communication_cost -= min_communication_cost; + } + } + return normalized_edge_cost; +} + CostGraph::CostGraph(const StrategyGroups& strategy_groups, const AssociativeDotPairs& associative_dot_pairs) { node_lens_.reserve(strategy_groups.size()); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h index fda06ee8ec1e7b..3d6bac1b139196 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h @@ -55,6 +55,10 @@ struct EdgeReshardingCost { using EdgeReshardingCostMatrix = Matrix; +// Normalizes the edge cost matrix by a fixed constant to ensure there are no +// negative communication costs. +EdgeReshardingCostMatrix Normalize(const EdgeReshardingCostMatrix& edge_cost); + // A graph data structure to simplify the edge cost graph. It merges nodes and // performs path compression. class CostGraph { diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.cc new file mode 100644 index 00000000000000..07ab282f0bfa38 --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.cc @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" + +#include + +#include "absl/types/span.h" +#include "xla/array.h" + +namespace xla { +namespace spmd { + +namespace { +bool AreValuesIota(const absl::Span values) { + for (int i = 1; i < values.size(); ++i) { + if (values[i] - values[i - 1] != 1) { + return false; + } + } + return true; +} +} // namespace + +void DeviceMesh::SetValues(absl::Span values) { + device_array.SetValues(values); + is_iota = AreValuesIota(values); +} +} // namespace spmd +} // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h new file mode 100644 index 00000000000000..919ea64027833b --- /dev/null +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h @@ -0,0 +1,86 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_DEVICE_MESH_H_ +#define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_DEVICE_MESH_H_ + +#include +#include + +#include "absl/functional/function_ref.h" +#include "absl/types/span.h" +#include "xla/array.h" + +namespace xla { +namespace spmd { +struct DeviceMesh { + Array device_array; + bool is_iota; + + explicit DeviceMesh(absl::Span sizes) + : device_array(sizes), is_iota(false) {} + + void FillIota(const int64_t value) { + device_array.FillIota(value); + is_iota = true; + } + + void SetValues(absl::Span values); + + int64_t num_dimensions() const { return device_array.num_dimensions(); } + + // Returns the size of the dimension at the given index. + int64_t dim(int64_t n) const { return device_array.dim(n); } + + // Returns a vector containing the dimensions of the array. + absl::Span dimensions() const { + return device_array.dimensions(); + } + + // Returns the total number of elements in the array. + int64_t num_elements() const { return device_array.num_elements(); } + + std::string ToString() const { return device_array.ToString(); } + + void Reshape(absl::Span new_dimensions) { + device_array.Reshape(new_dimensions); + } + + void TransposeDimensions(absl::Span permutation) { + device_array.TransposeDimensions(permutation); + is_iota = false; + } + + const int64_t& operator()(absl::Span indexes) const { + return device_array(indexes); + } + + int64_t& operator()(absl::Span indexes) { + return device_array(indexes); + } + + void Each(absl::FunctionRef, int64_t*)> f) { + device_array.Each(f); + } + + void Each( + absl::FunctionRef, int64_t)> f) const { + device_array.Each(f); + } +}; +} // namespace spmd +} // namespace xla + +#endif // XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_AUTO_SHARDING_DEVICE_MESH_H_ diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc index 3c62712ab41e27..9224da821db47b 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_dot_handler.cc @@ -36,10 +36,10 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" -#include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h" #include "xla/hlo/experimental/auto_sharding/cluster_environment.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -57,14 +57,8 @@ namespace xla { namespace spmd { namespace { -using DimMap = StableMap; -using MeshDims = absl::Span; - -struct Enumeration { - MeshDims mesh_dims; - int64_t i; - int64_t j; -}; +using MeshDimSet = StableSet; +using DimMap = StableMap; // Contains base functionality common to both DotHandler and ConvHandler. class HandlerBase { @@ -88,7 +82,6 @@ class HandlerBase { option_(option), call_graph_(call_graph), device_mesh_(cluster_env.device_mesh_), - device_mesh_1d_(cluster_env.device_mesh_1d_), lhs_(ins->operand(0)), rhs_(ins->operand(1)) {} @@ -100,12 +93,14 @@ class HandlerBase { double compute_cost, double communication_cost); HloSharding CreateInputSpec(const HloInstruction* ins, const DimMap& dim_map, - const Array& device_mesh) const { + const DeviceMesh& device_mesh) const { if (dim_map.empty()) return HloSharding::Replicate(); - std::vector tensor_dims, mesh_dims; - for (const auto& [tensor_dim, mesh_dim] : dim_map) { + std::vector tensor_dims; + std::vector> mesh_dims; + for (const auto& [tensor_dim, mesh_dim_set] : dim_map) { tensor_dims.push_back(tensor_dim); - mesh_dims.push_back(mesh_dim); + mesh_dims.push_back( + std::vector(mesh_dim_set.begin(), mesh_dim_set.end())); } return Tile(ins->shape(), tensor_dims, mesh_dims, device_mesh); } @@ -116,7 +111,7 @@ class HandlerBase { const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const std::optional& expected_output_dim_map, - const Array& device_mesh, double compute_cost = 0, + double compute_cost = 0, const std::optional>& communication_cost_fn = std::nullopt); @@ -126,7 +121,7 @@ class HandlerBase { const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const std::optional& expected_output_dim_map, - const Array& device_mesh, double compute_cost = 0, + double compute_cost = 0, const std::optional>& communication_cost_fn = std::nullopt); @@ -137,7 +132,7 @@ class HandlerBase { virtual void AppendAllGatherWindowedEinsumStrategyForOperand( int operand_num, const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const DimMap& output_dim_map, - const Array& device_mesh, double compute_cost) {} + double compute_cost) {} // Given an existing (allreduce) sharding candidate, generate a corresponding // candidate by additionally sharding (if possible) the dot/conv output, such @@ -146,62 +141,72 @@ class HandlerBase { virtual void AppendReduceScatterWindowedEinsumStrategy( const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const DimMap& output_dim_map, - const Array& device_mesh, double compute_cost) {} + double compute_cost) {} std::optional GetShardingFromUser(const HloSharding& lhs_spec, const HloSharding& rhs_spec); - // Enumerates combinations of the given mesh + tensor dimensions. - void Enumerate(std::function split_func, - size_t num_outer_dims = 2, size_t num_inner_dims = 2, - bool half = false) { - absl::Span mesh_shape = device_mesh_.dimensions(); - for (int64_t dim0 = 0; dim0 < mesh_shape.size(); ++dim0) { - for (int64_t dim1 = 0; dim1 < mesh_shape.size(); ++dim1) { - if (dim0 == dim1) continue; - for (int64_t i = 0; i < num_outer_dims; ++i) { - for (int64_t j = half ? i + 1 : 0; j < num_inner_dims; ++j) { - split_func({{dim0, dim1}, i, j}); - } - } - } - } - } - // Given a set of tensor dims, and a set of mesh dims, enumerates all mappings // where a subset of all tensor dims is mapped to a subset of mesh dims, such // that each tensor dim is mapped to at most mesh dim, and no two tensor dims // are mapped to the same mesh dim. - // TODO(b/226977360): We might need to generalize this to also allow cases - // where a tensor dim can be mapped to multiple mesh dims. - void EnumerateGeneral(std::function split_func, - int tensor_rank, int current_tensor_dim, - const absl::flat_hash_set& unassigned_mesh_dims, - const DimMap& current_dim_map) { - if (current_tensor_dim == tensor_rank) { + void Enumerate(std::function split_func, int tensor_rank, + int current_mesh_dim_idx, + const std::vector& unassigned_mesh_dims, + const DimMap& current_dim_map) { + if (current_mesh_dim_idx == unassigned_mesh_dims.size()) { split_func(current_dim_map); return; } - // current_tensor_dim is unsharded - EnumerateGeneral(split_func, tensor_rank, current_tensor_dim + 1, - unassigned_mesh_dims, current_dim_map); - // current_tensor_dim is sharded across one of the remaining mesh dims - for (int mesh_dim : unassigned_mesh_dims) { + // Current mesh dim is not assigned to any tensor dim + Enumerate(split_func, tensor_rank, current_mesh_dim_idx + 1, + unassigned_mesh_dims, current_dim_map); + + for (int i = 0; i < tensor_rank; ++i) { DimMap updated_dim_map = current_dim_map; - updated_dim_map[current_tensor_dim] = mesh_dim; - absl::flat_hash_set updated_unassigned_mesh_dims = - unassigned_mesh_dims; - updated_unassigned_mesh_dims.erase( - updated_unassigned_mesh_dims.find(mesh_dim)); - EnumerateGeneral(split_func, tensor_rank, current_tensor_dim + 1, - updated_unassigned_mesh_dims, updated_dim_map); + if (!updated_dim_map[i].empty() && !option_.allow_mixed_mesh_shape) { + continue; + } + updated_dim_map[i].insert(unassigned_mesh_dims[current_mesh_dim_idx]); + Enumerate(split_func, tensor_rank, current_mesh_dim_idx + 1, + unassigned_mesh_dims, updated_dim_map); + } + } + + bool IsMeshDimSetNonTrivial(const MeshDimSet& mesh_dim_set) { + return absl::c_any_of(mesh_dim_set, [&](int mesh_dim) { + return device_mesh_.dim(mesh_dim) > 1; + }); + } + + bool IsFullyReplicatedSharding(const DimMap& dim_map, + const DeviceMesh& device_mesh) { + if (dim_map.empty()) { + return true; + } + for (const auto& [_, mesh_dim_set] : dim_map) { + if (IsMeshDimSetNonTrivial(mesh_dim_set)) { + return false; + } } + return true; } - // Enumerates *half* of the combinations (if inner & outer dims are the same). - void EnumerateHalf(std::function split_func, - size_t num_outer_dims = 2, size_t num_inner_dims = 2) { - Enumerate(split_func, num_outer_dims, num_inner_dims, true); + bool IsFullyReplicatedStrategy(const DimMap& output_dim_map, + const DimMap& lhs_dim_map, + const DimMap& rhs_dim_map, + const DeviceMesh& device_mesh) { + return IsFullyReplicatedSharding(output_dim_map, device_mesh) && + IsFullyReplicatedSharding(lhs_dim_map, device_mesh) && + IsFullyReplicatedSharding(rhs_dim_map, device_mesh); + } + + bool IsFullySharded(const DimMap& dim_map, int num_mesh_dims) { + int num_mesh_dims_used = 0; + for (const auto& [_, mesh_dims] : dim_map) { + num_mesh_dims_used += mesh_dims.size(); + } + return num_mesh_dims_used >= num_mesh_dims; } // Sorts strategies in the increasing order of their memory costs. Anecdotal @@ -219,8 +224,7 @@ class HandlerBase { const AutoShardingOption& option_; const CallGraph& call_graph_; - const Array& device_mesh_; - const Array& device_mesh_1d_; + const DeviceMesh& device_mesh_; const HloInstruction* lhs_; const HloInstruction* rhs_; }; @@ -257,12 +261,13 @@ class DotHandler : public HandlerBase { void AppendAllGatherWindowedEinsumStrategyForOperand( int operand_num, const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const DimMap& output_dim_map, - const Array& device_mesh, double compute_cost) override; + double compute_cost) override; - void AppendReduceScatterWindowedEinsumStrategy( - const std::string& name, const DimMap& lhs_dim_map, - const DimMap& rhs_dim_map, const DimMap& output_dim_map, - const Array& device_mesh, double compute_cost) override; + void AppendReduceScatterWindowedEinsumStrategy(const std::string& name, + const DimMap& lhs_dim_map, + const DimMap& rhs_dim_map, + const DimMap& output_dim_map, + double compute_cost) override; absl::Status RegisterStrategies(); @@ -347,29 +352,28 @@ void HandlerBase::AppendNewStrategy(const std::string& name, })); } -// Given lhs and rhs dim maps, infers a sharding for the output by relying on -// the sharding_propagation pass. Given that this is a relatively new change -// (as of 11/2023), we also take an optional expected output dim map as an -// argument, to verify that sharding propagation in fact infers the sharding -// we expect (and to crash if it doesn't). +// Given lhs and rhs dim maps, infers a sharding for the output by relying +// on the sharding_propagation pass. Given that this is a relatively new +// change (as of 11/2023), we also take an optional expected output dim map +// as an argument, to verify that sharding propagation in fact infers the +// sharding we expect (and to crash if it doesn't). // TODO(b/309638633) As we build more confidence in this, we should remove // this expected_output_dim_map argument and fully rely on sharding // propagation. void HandlerBase::MaybeAppendInternal( const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, - const std::optional& expected_output_dim_map, - const Array& device_mesh, double compute_cost, + const std::optional& expected_output_dim_map, double compute_cost, const std::optional>& communication_cost_fn) { - HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh); - HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh); + HloSharding lhs_spec = CreateInputSpec(lhs_, lhs_dim_map, device_mesh_); + HloSharding rhs_spec = CreateInputSpec(rhs_, rhs_dim_map, device_mesh_); std::optional output_spec = GetShardingFromUser(lhs_spec, rhs_spec); if (output_spec.has_value()) { if (expected_output_dim_map.has_value()) { HloSharding expected_output_spec = - CreateInputSpec(ins_, *expected_output_dim_map, device_mesh); + CreateInputSpec(ins_, *expected_output_dim_map, device_mesh_); // TODO(b/308687597) Once the bug is resolved, we ideally either want // have a CHECK statement verifying that the sharding inferred by // sharding propagation is in fact what we expect, or we trust sharding @@ -389,7 +393,7 @@ void HandlerBase::MaybeAppendInternal( } } else { CHECK(expected_output_dim_map.has_value()); - output_spec = CreateInputSpec(ins_, *expected_output_dim_map, device_mesh); + output_spec = CreateInputSpec(ins_, *expected_output_dim_map, device_mesh_); LOG(WARNING) << "Sharding propagation could not infer output sharding for:\n " << ins_->ToString() << "\n LHS Spec: " << lhs_spec @@ -407,29 +411,27 @@ void HandlerBase::MaybeAppendInternal( void HandlerBase::MaybeAppend( const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, - const std::optional& expected_output_dim_map, - const Array& device_mesh, double compute_cost, + const std::optional& expected_output_dim_map, double compute_cost, const std::optional>& communication_cost_fn) { MaybeAppendInternal(name, lhs_dim_map, rhs_dim_map, expected_output_dim_map, - device_mesh, compute_cost, communication_cost_fn); + compute_cost, communication_cost_fn); if (!option_.generate_windowed_einsum_strategies || !expected_output_dim_map.has_value()) { return; } if (absl::StrContains(name, "allreduce")) { CHECK(communication_cost_fn.has_value()); - AppendReduceScatterWindowedEinsumStrategy(name, lhs_dim_map, rhs_dim_map, - *expected_output_dim_map, - device_mesh, compute_cost); + AppendReduceScatterWindowedEinsumStrategy( + name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map, compute_cost); } else { CHECK(!communication_cost_fn.has_value()); AppendAllGatherWindowedEinsumStrategyForOperand( 0, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map, - device_mesh, compute_cost); + compute_cost); AppendAllGatherWindowedEinsumStrategyForOperand( 1, name, lhs_dim_map, rhs_dim_map, *expected_output_dim_map, - device_mesh, compute_cost); + compute_cost); } } @@ -460,14 +462,15 @@ std::optional HandlerBase::GetShardingFromUser( } void HandlerBase::SortStrategies() { - absl::c_sort(strategy_group_->strategies, - [](const ShardingStrategy& s1, const ShardingStrategy& s2) { - if (s1.memory_cost == s2.memory_cost) { - return s1.name < s2.name; - } else { - return s1.memory_cost < s2.memory_cost; - } - }); + absl::c_stable_sort( + strategy_group_->strategies, + [](const ShardingStrategy& s1, const ShardingStrategy& s2) { + if (s1.memory_cost == s2.memory_cost) { + return s1.name < s2.name; + } else { + return s1.memory_cost < s2.memory_cost; + } + }); } /************** DotHandler function definitions **************/ @@ -543,6 +546,15 @@ DotHandler::DotHandler( } } +std::string ToString(const MeshDimSet& set) { return absl::StrJoin(set, "-"); } +std::string ToString(const DimMap& map) { + std::vector strings; + for (const auto& [tdim, mdims] : map) { + strings.push_back(absl::StrCat("[", tdim, ": ", ToString(mdims), "]")); + } + return absl::StrJoin(strings, ", "); +} + std::string DotHandler::GenerateNameForDotSharding(const DimMap& output_dim_map, const DimMap& lhs_dim_map) { std::string name; @@ -552,12 +564,12 @@ std::string DotHandler::GenerateNameForDotSharding(const DimMap& output_dim_map, absl::string_view identifier) { for (size_t i = 0; i < out_dims.size(); ++i) { int output_batch_dim = out_dims[i]; - int mesh_dim = -1; + MeshDimSet mesh_dim_set; auto it = dim_map.find(output_batch_dim); - if (it != dim_map.end() && it->second >= 0) { - mesh_dim = it->second; + if (it != dim_map.end() && !it->second.empty()) { + mesh_dim_set = it->second; } - absl::StrAppend(&name, identifier, mesh_dim); + absl::StrAppend(&name, identifier, ToString(mesh_dim_set)); } }; @@ -577,9 +589,9 @@ std::string DotHandler::GenerateNameForDotSharding(const DimMap& output_dim_map, bool contraction_dim_sharded = false; for (size_t i = 0; i < lhs_con_dims_.size(); ++i) { if (auto it = lhs_dim_map.find(lhs_con_dims_[i]); - it != lhs_dim_map.end() && it->second >= 0) { + it != lhs_dim_map.end() && !it->second.empty()) { contraction_dim_sharded = - contraction_dim_sharded || (device_mesh_.dim(it->second) > 1); + contraction_dim_sharded || IsMeshDimSetNonTrivial(it->second); } } @@ -589,34 +601,17 @@ std::string DotHandler::GenerateNameForDotSharding(const DimMap& output_dim_map, return name; } -bool IsFullyReplicatedSharding(const DimMap& dim_map, - const Array& device_mesh) { - if (dim_map.empty()) { - return true; - } - for (const auto& [_, mesh_dim] : dim_map) { - if (device_mesh.dim(mesh_dim) > 1) { - return false; +void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( + const DimMap& output_dim_map) { + // This early return is added to ensure parity with the older strategy + // generation code. Removing it will only increase the search space. + for (const auto& [_, mesh_dims] : output_dim_map) { + if (mesh_dims.size() > 1 && + mesh_dims.size() != device_mesh_.num_dimensions()) { + return; } } - return true; -} -bool IsFullyReplicatedStrategy(const DimMap& output_dim_map, - const DimMap& lhs_dim_map, - const DimMap& rhs_dim_map, - const Array& device_mesh) { - return IsFullyReplicatedSharding(output_dim_map, device_mesh) && - IsFullyReplicatedSharding(lhs_dim_map, device_mesh) && - IsFullyReplicatedSharding(rhs_dim_map, device_mesh); -} - -bool IsFullySharded(const DimMap& dim_map, int num_mesh_dims) { - return dim_map.size() >= num_mesh_dims; -} - -void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( - const DimMap& output_dim_map) { DimMap lhs_dim_map, rhs_dim_map; absl::flat_hash_set used_mesh_dims; @@ -626,11 +621,11 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( int lhs_batch_dim = lhs_batch_dims_[i]; int rhs_batch_dim = rhs_batch_dims_[i]; auto it = output_dim_map.find(output_batch_dim); - if (it != output_dim_map.end() && it->second >= 0) { - int mesh_dim = it->second; - used_mesh_dims.insert(mesh_dim); - lhs_dim_map[lhs_batch_dim] = mesh_dim; - rhs_dim_map[rhs_batch_dim] = mesh_dim; + if (it != output_dim_map.end() && !it->second.empty()) { + const StableSet& mesh_dim_set = it->second; + used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end()); + lhs_dim_map[lhs_batch_dim] = mesh_dim_set; + rhs_dim_map[rhs_batch_dim] = mesh_dim_set; } } @@ -640,10 +635,10 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( int lhs_space_dim = lhs_space_dims_[i]; int output_space_dim = out_lhs_space_dims_[i]; auto it = output_dim_map.find(output_space_dim); - if (it != output_dim_map.end() && it->second >= 0) { - int mesh_dim = it->second; - used_mesh_dims.insert(mesh_dim); - lhs_dim_map[lhs_space_dim] = mesh_dim; + if (it != output_dim_map.end() && !it->second.empty()) { + const StableSet& mesh_dim_set = it->second; + used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end()); + lhs_dim_map[lhs_space_dim] = mesh_dim_set; } } @@ -652,10 +647,10 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( int rhs_space_dim = rhs_space_dims_[i]; int output_space_dim = out_rhs_space_dims_[i]; auto it = output_dim_map.find(output_space_dim); - if (it != output_dim_map.end() && it->second >= 0) { - int mesh_dim = it->second; - used_mesh_dims.insert(mesh_dim); - rhs_dim_map[rhs_space_dim] = mesh_dim; + if (it != output_dim_map.end() && !it->second.empty()) { + const MeshDimSet& mesh_dim_set = it->second; + used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end()); + rhs_dim_map[rhs_space_dim] = mesh_dim_set; } } @@ -669,7 +664,7 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( // generation code. Removing it will only increase the search space. IsFullySharded(output_dim_map, device_mesh_.num_dimensions())) { MaybeAppend(GenerateNameForDotSharding(output_dim_map, lhs_dim_map), - lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_); + lhs_dim_map, rhs_dim_map, output_dim_map); } // Generate shardings for contraction dimensions @@ -677,10 +672,10 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( return; } - absl::flat_hash_set unused_mesh_dims; + std::vector unused_mesh_dims; for (size_t i = 0; i < device_mesh_.num_dimensions(); ++i) { if (!used_mesh_dims.contains(i) && device_mesh_.dim(i) > 1) { - unused_mesh_dims.insert(i); + unused_mesh_dims.push_back(i); } } @@ -698,11 +693,11 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( DimMap lhs_dim_map_with_contractions = lhs_dim_map; DimMap rhs_dim_map_with_contractions = rhs_dim_map; - for (const auto& [reducton_dim_index, mesh_dim] : reduction_dim_map) { + for (const auto& [reduction_dim_index, mesh_dim_set] : reduction_dim_map) { lhs_dim_map_with_contractions - [lhs_con_dims_[reduction_dims[reducton_dim_index]]] = mesh_dim; + [lhs_con_dims_[reduction_dims[reduction_dim_index]]] = mesh_dim_set; rhs_dim_map_with_contractions - [rhs_con_dims_[reduction_dims[reducton_dim_index]]] = mesh_dim; + [rhs_con_dims_[reduction_dims[reduction_dim_index]]] = mesh_dim_set; } // Skip fully the replicated strategy here as we add that outside of // HandleDot in auto_sharding_strategy. @@ -719,8 +714,10 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( double memory_cost = ByteSizeOfShapeWithSharding(ins_->shape(), output_sharding); double total_cost = 0; - for (const auto& [_, mesh_dim] : reduction_dim_map) { - total_cost += cluster_env_.AllReduceCost(memory_cost, mesh_dim); + for (const auto& [_, mesh_dim_set] : reduction_dim_map) { + for (int mesh_dim : mesh_dim_set) { + total_cost += cluster_env_.AllReduceCost(memory_cost, mesh_dim); + } } return total_cost; }; @@ -728,65 +725,58 @@ void DotHandler::GenerateDotShardingStrategiesFromOutputSharding( MaybeAppend(GenerateNameForDotSharding(output_dim_map, lhs_dim_map_with_contractions), lhs_dim_map_with_contractions, rhs_dim_map_with_contractions, - output_dim_map, device_mesh_, + output_dim_map, /*compute_cost=*/0, communication_cost_fn); }; - EnumerateGeneral(split_func, reduction_dims.size(), - /*current_tensor_dim=*/0, unused_mesh_dims, - /*current_dim_map=*/{}); + Enumerate(split_func, reduction_dims.size(), + /*current_mesh_dim_idx=*/0, unused_mesh_dims, + /*current_dim_map=*/{}); } void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand( int operand_num, const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const DimMap& output_dim_map, - const Array& device_mesh, double compute_cost) { + double compute_cost) { const HloInstruction* operand = ins_->operand(operand_num); const DimMap& operand_dim_map = operand_num == 0 ? lhs_dim_map : rhs_dim_map; - absl::flat_hash_set sharded_tensor_dims; absl::flat_hash_set used_mesh_dims; - for (const auto [tensor_dim, mesh_dim] : operand_dim_map) { - if (device_mesh.dim(mesh_dim) == 1) { - continue; - } - sharded_tensor_dims.insert(tensor_dim); - used_mesh_dims.insert(mesh_dim); + for (const auto& [tensor_dim, mesh_dim_set] : operand_dim_map) { + used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end()); } if (used_mesh_dims.size() == device_mesh_.num_dimensions() || - sharded_tensor_dims.size() == operand->shape().rank()) { + used_mesh_dims.size() == operand->shape().rank()) { return; } for (int64_t tensor_dim = 0; tensor_dim < operand->shape().rank(); ++tensor_dim) { - if (sharded_tensor_dims.contains(tensor_dim)) { + if (auto it = operand_dim_map.find(tensor_dim); + it != operand_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) { continue; } - for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions(); + for (int mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions(); ++mesh_dim) { - if (used_mesh_dims.contains(mesh_dim) || - (device_mesh.dim(mesh_dim) == 1)) { + if (used_mesh_dims.contains(mesh_dim)) { continue; } DimMap further_sharded_dim_map = operand_dim_map; - further_sharded_dim_map[tensor_dim] = mesh_dim; + further_sharded_dim_map[tensor_dim] = MeshDimSet{mesh_dim}; - auto updated_communication_cost_fn = + auto communication_cost_fn = [](const HloSharding& output_sharding) -> double { // TODO(331684721): Model costs for windowed einsum return 100.0; }; - std::string updated_name = - absl::StrCat(absl::StrFormat("WindowedEinsum @ {%d,%d,%d}", - operand_num, tensor_dim, mesh_dim), - name); + std::string updated_name = absl::StrCat( + name, absl::StrFormat("|ag_windowed_einsum_o%dt%dm%d", operand_num, + tensor_dim, mesh_dim)); MaybeAppendInternal( updated_name, operand_num == 0 ? further_sharded_dim_map : lhs_dim_map, operand_num == 1 ? further_sharded_dim_map : rhs_dim_map, - output_dim_map, device_mesh, compute_cost, - updated_communication_cost_fn); + output_dim_map, compute_cost, communication_cost_fn); } } } @@ -794,62 +784,56 @@ void DotHandler::AppendAllGatherWindowedEinsumStrategyForOperand( void DotHandler::AppendReduceScatterWindowedEinsumStrategy( const std::string& name, const DimMap& lhs_dim_map, const DimMap& rhs_dim_map, const DimMap& output_dim_map, - const Array& device_mesh, double compute_cost) { - absl::flat_hash_set sharded_tensor_dims; + double compute_cost) { absl::flat_hash_set used_mesh_dims; - for (const auto [tensor_dim, mesh_dim] : output_dim_map) { - if (device_mesh.dim(mesh_dim) == 1) { - continue; - } - sharded_tensor_dims.insert(tensor_dim); - used_mesh_dims.insert(mesh_dim); + for (const auto& [tensor_dim, mesh_dim_set] : output_dim_map) { + used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end()); } + if (used_mesh_dims.size() == device_mesh_.num_dimensions() || - sharded_tensor_dims.size() == ins_->shape().rank()) { + used_mesh_dims.size() == ins_->shape().rank()) { return; } for (int64_t tensor_dim = 0; tensor_dim < ins_->shape().rank(); ++tensor_dim) { - if (sharded_tensor_dims.contains(tensor_dim)) { + if (auto it = output_dim_map.find(tensor_dim); + it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) { continue; } - for (int64_t mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions(); + for (int mesh_dim = 0; mesh_dim < device_mesh_.num_dimensions(); ++mesh_dim) { - if (used_mesh_dims.contains(mesh_dim) || - (device_mesh.dim(mesh_dim) == 1)) { + if (used_mesh_dims.contains(mesh_dim)) { continue; } DimMap further_sharded_dim_map = output_dim_map; - further_sharded_dim_map[tensor_dim] = mesh_dim; + further_sharded_dim_map[tensor_dim] = MeshDimSet{mesh_dim}; - auto updated_communication_cost_fn = + auto communication_cost_fn = [](const HloSharding& output_sharding) -> double { // TODO(331684721): Model costs for windowed einsum return 100.0; }; std::string updated_name = absl::StrCat( - absl::StrFormat("WindowedEinsum @ {%d,%d}", tensor_dim, mesh_dim), - name); + name, + absl::StrFormat("|rs_windowed_einsum_t%dm%d", tensor_dim, mesh_dim)); MaybeAppendInternal(updated_name, lhs_dim_map, rhs_dim_map, - further_sharded_dim_map, device_mesh, compute_cost, - updated_communication_cost_fn); + further_sharded_dim_map, compute_cost, + communication_cost_fn); } } } absl::Status DotHandler::RegisterStrategies() { - absl::flat_hash_set all_mesh_dims; - for (int i = 0; i < device_mesh_.num_dimensions(); ++i) { - all_mesh_dims.insert(i); - } - EnumerateGeneral( + std::vector all_mesh_dims(device_mesh_.num_dimensions()); + std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0); + Enumerate( /*split_func=*/ [&](const DimMap& output_dim_map) { GenerateDotShardingStrategiesFromOutputSharding(output_dim_map); }, - ins_->shape().rank(), /*current_tensor_dim=*/0, all_mesh_dims, + ins_->shape().rank(), /*current_mesh_dim_idx=*/0, all_mesh_dims, /*current_dim_map=*/{}); SortStrategies(); return absl::OkStatus(); @@ -887,27 +871,27 @@ void ConvHandler::GenerateConvolutionShardingStrategiesFromOutputSharding( // Propagate batch dim sharding auto it = output_dim_map.find(out_batch_dim_); - if (it != output_dim_map.end() && device_mesh_.dim(it->second) > 1) { - int mesh_dim = it->second; - lhs_dim_map[lhs_batch_dim_] = mesh_dim; - used_mesh_dims.insert(mesh_dim); - absl::StrAppend(&name, "b", mesh_dim); + if (it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) { + const MeshDimSet& mesh_dim_set = it->second; + lhs_dim_map[lhs_batch_dim_] = mesh_dim_set; + used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end()); + absl::StrAppend(&name, "b", ToString(mesh_dim_set)); } else { absl::StrAppend(&name, "b-1"); } // Propagate out channel dim sharding it = output_dim_map.find(out_out_channel_dim_); - if (it != output_dim_map.end() && device_mesh_.dim(it->second) > 1) { - int mesh_dim = it->second; - lhs_dim_map[rhs_out_channel_dim_] = mesh_dim; - used_mesh_dims.insert(mesh_dim); - absl::StrAppend(&name, "oc", mesh_dim); + if (it != output_dim_map.end() && IsMeshDimSetNonTrivial(it->second)) { + const MeshDimSet& mesh_dim_set = it->second; + lhs_dim_map[rhs_out_channel_dim_] = mesh_dim_set; + used_mesh_dims.insert(mesh_dim_set.begin(), mesh_dim_set.end()); + absl::StrAppend(&name, "oc", ToString(mesh_dim_set)); } else { absl::StrAppend(&name, "oc-1"); } - MaybeAppend(name, lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_); + MaybeAppend(name, lhs_dim_map, rhs_dim_map, output_dim_map); // Generate shardings for contraction dimensions if (used_mesh_dims.size() == device_mesh_.num_dimensions()) { @@ -925,12 +909,12 @@ void ConvHandler::GenerateConvolutionShardingStrategiesFromOutputSharding( return; } - for (int64_t mesh_dim : unused_mesh_dims) { + for (int mesh_dim : unused_mesh_dims) { DimMap lhs_dim_map_with_contractions = lhs_dim_map; DimMap rhs_dim_map_with_contractions = rhs_dim_map; - lhs_dim_map_with_contractions[lhs_in_channel_dim_] = mesh_dim; - rhs_dim_map_with_contractions[rhs_in_channel_dim_] = mesh_dim; + lhs_dim_map_with_contractions[lhs_in_channel_dim_] = MeshDimSet{mesh_dim}; + rhs_dim_map_with_contractions[rhs_in_channel_dim_] = MeshDimSet{mesh_dim}; absl::StrAppend(&name, "ic", mesh_dim, "@allreduce"); auto communication_cost_fn = [&](const HloSharding& output_sharding) { @@ -940,7 +924,7 @@ void ConvHandler::GenerateConvolutionShardingStrategiesFromOutputSharding( }; MaybeAppend(name, lhs_dim_map_with_contractions, - rhs_dim_map_with_contractions, output_dim_map, device_mesh_, + rhs_dim_map_with_contractions, output_dim_map, /*compute_cost=*/0, communication_cost_fn); } } @@ -965,15 +949,13 @@ absl::Status ConvHandler::RegisterStrategies() { SplitDepthwise(false); } - absl::flat_hash_set all_mesh_dims; - for (int i = 0; i < device_mesh_.num_dimensions(); ++i) { - all_mesh_dims.insert(i); - } - EnumerateGeneral( + std::vector all_mesh_dims(device_mesh_.num_dimensions()); + std::iota(all_mesh_dims.begin(), all_mesh_dims.end(), 0); + Enumerate( [&](const DimMap& output_dim_map) { GenerateConvolutionShardingStrategiesFromOutputSharding(output_dim_map); }, - 2, /*current_tensor_dim=*/0, all_mesh_dims, + 2, /*current_mesh_dim_idx=*/0, all_mesh_dims, /*current_dim_map=*/{}); // If force_batch_dim_to_mesh_dim is set, filter out invalid strategies @@ -991,38 +973,37 @@ absl::Status ConvHandler::RegisterStrategies() { void ConvHandler::SplitDepthwise(bool forward) { std::function split_func = [&](const DimMap& output_dim_map) { - int out_batch_mesh_dim = -1; - int out_out_channel_mesh_dim = -1; + MeshDimSet out_batch_mesh_dim_set; + MeshDimSet out_out_channel_mesh_dim_set; if (auto it = output_dim_map.find(out_batch_dim_); it != output_dim_map.end()) { - out_batch_mesh_dim = it->second; + out_batch_mesh_dim_set = it->second; } if (auto it = output_dim_map.find(out_out_channel_dim_); it != output_dim_map.end()) { - out_out_channel_mesh_dim = it->second; + out_out_channel_mesh_dim_set = it->second; } - if (out_batch_mesh_dim == -1 || out_out_channel_mesh_dim == -1) { + if (out_batch_mesh_dim_set.empty() || + out_out_channel_mesh_dim_set.empty()) { return; } DimMap lhs_dim_map, rhs_dim_map; lhs_dim_map[lhs_batch_dim_] = - forward ? out_batch_mesh_dim : out_out_channel_mesh_dim; + forward ? out_batch_mesh_dim_set : out_out_channel_mesh_dim_set; lhs_dim_map[lhs_in_channel_dim_] = - forward ? out_out_channel_mesh_dim : out_batch_mesh_dim; + forward ? out_out_channel_mesh_dim_set : out_batch_mesh_dim_set; - rhs_dim_map[rhs_out_channel_dim_] = out_out_channel_mesh_dim; + rhs_dim_map[rhs_out_channel_dim_] = out_out_channel_mesh_dim_set; - MaybeAppend(absl::StrCat("b", out_batch_mesh_dim, "oc", - out_out_channel_mesh_dim, "@depthwise"), - lhs_dim_map, rhs_dim_map, output_dim_map, device_mesh_); + MaybeAppend( + absl::StrCat("b", ToString(out_batch_mesh_dim_set), "oc", + ToString(out_out_channel_mesh_dim_set), "|depthwise"), + lhs_dim_map, rhs_dim_map, output_dim_map); }; - absl::flat_hash_set all_mesh_dims; - for (int i = 0; i < device_mesh_.num_dimensions(); ++i) { - all_mesh_dims.insert(i); - } - EnumerateGeneral(split_func, 2, /*current_tensor_dim=*/0, all_mesh_dims, - /*current_dim_map=*/{}); + std::vector all_mesh_dims(device_mesh_.num_dimensions()); + Enumerate(split_func, 2, /*current_mesh_dim_idx=*/0, all_mesh_dims, + /*current_dim_map=*/{}); } } // namespace @@ -1062,7 +1043,7 @@ absl::Status HandleConv(std::unique_ptr& strategy_group, strategy_group = CreateLeafStrategyGroup(instruction_id, ins, strategy_map, strategy_groups); - auto conv_as_dot_dims = + const dot_as_convolution_util::DotConvolutionDimsInfo& conv_as_dot_dims = dot_as_convolution_util::ParseConvolutionDimsInfo(ins); if (conv_as_dot_dims.conv_spatial_dims.empty()) { DotHandler handler( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc index d2f124b3e3f85c..56ca325679483e 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.cc @@ -82,8 +82,6 @@ std::string AutoShardingOption::ToString() const { absl::StrCat("allow_recompute_heavy_op: ", allow_recompute_heavy_op)); lines.push_back( absl::StrCat("allow_mixed_mesh_shape: ", allow_mixed_mesh_shape)); - lines.push_back( - absl::StrCat("grad_acc_num_micro_batches: ", grad_acc_num_micro_batches)); lines.push_back(absl::StrCat("solve_nd_sharding_iteratively: ", solve_nd_sharding_iteratively)); lines.push_back( diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h index 51eceae5b5909d..5e73af7929b181 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_option.h @@ -119,11 +119,6 @@ struct AutoShardingOption { // If true, allow adding 1d strategies in 2d logical mesh. bool allow_mixed_mesh_shape = true; - // The number of micro batches if gradient accumulation is used. - // If this is not 1, the cost of all-reduce for gradient synchronization - // is divided by this number. - int grad_acc_num_micro_batches = 1; - // If true, N-D sharding (e.g., N maybe be 2 or 3) will be solved in N // iterations, where one iteration chooses one tensor dimension to shard. If // false, solve N-D sharding directly, i.e., generating all possible sharding diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc index f204ff43496d61..67a6fed7149280 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc @@ -67,6 +67,10 @@ using ::operations_research::MPVariable; // solver cannot guarantee exact numerical precision. constexpr double kMaxCostEpsilon = 1.0001; +// Memory contributions in the Mixed ILP are converted to units in this range; +// beware that significantly larger / smaller values can cause numerical issues. +constexpr double kMemoryMultiplier = 1e6; + bool AutoShardingSolverOutput::operator==( const AutoShardingSolverOutput& other) const { return s_val == other.s_val && cost == other.cost && @@ -261,7 +265,7 @@ std::optional> ReduceMemoryTerms( reduced_groups.push_back({group.prims().begin(), group.prims().end()}); } } - solver.MakeIntVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(), + solver.MakeNumVarArray(reduced_groups.size(), 0.0, MPSolver::infinity(), absl::StrCat("group_", prim_type), &group_vars); for (int64_t group_idx = 0; group_idx < group_vars.size(); ++group_idx) { MPConstraint* constraint = solver.MakeRowConstraint( @@ -271,7 +275,7 @@ std::optional> ReduceMemoryTerms( for (const int64_t prim_idx : reduced_groups[group_idx]) { for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) { double memory_cost = memory_costs.at(prim_idx).costs(j); - memory_cost /= request.memory_budget() / 100.0; + memory_cost /= request.memory_budget() / kMemoryMultiplier; const double accumulated_coefficient = constraint->GetCoefficient(prim_vars[prim_idx][j]); constraint->SetCoefficient(prim_vars[prim_idx][j], @@ -302,9 +306,12 @@ void AddMemoryTerms( time_idx <= intervals[prim_idx].second; ++time_idx) { if (!reduced_times.contains(time_idx)) continue; if (!constraints.contains(time_idx)) { - MPConstraint* constraint = solver.MakeRowConstraint( - -MPSolver::infinity(), 100.0, absl::StrCat("mem[", time_idx, "]")); - if (overbudget_var) constraint->SetCoefficient(overbudget_var, -100.0); + MPConstraint* constraint = + solver.MakeRowConstraint(-MPSolver::infinity(), kMemoryMultiplier, + absl::StrCat("mem[", time_idx, "]")); + if (overbudget_var) { + constraint->SetCoefficient(overbudget_var, -kMemoryMultiplier); + } constraints[time_idx] = constraint; } MPConstraint* constraint = constraints[time_idx]; @@ -314,7 +321,7 @@ void AddMemoryTerms( } for (int64_t j = 0; j < prim_vars[prim_idx].size(); ++j) { double memory_cost = memory_costs.at(prim_idx).costs(j); - memory_cost /= request.memory_budget() / 100.0; + memory_cost /= request.memory_budget() / kMemoryMultiplier; const double accumulated_coefficient = constraint->GetCoefficient(prim_vars[prim_idx][j]); constraint->SetCoefficient(prim_vars[prim_idx][j], diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc index 5a237b2f979a67..05631bc2090b1a 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_solver_test.cc @@ -441,6 +441,46 @@ TEST(CallORToolsSolverTest, HandlesReducedIntervalsAndGroups) { EXPECT_EQ(result, expected_result); } +TEST(CallORToolsSolverTest, HandlesGroupsWithTinyMemoryCosts) { + AutoShardingSolverRequest request = DefaultAutoShardingSolverRequest(); + const std::vector> node_intervals = + {{5, -1}, {5, -1}, {2, 3}, {3, 4}, {100, -1}, {0, 4}}; + const std::vector> edge_intervals = + {{1, 2}, {2, 3}}; + const std::vector> node_groups = {{0, 1}}; + const std::vector> edge_groups = {}; + const CostMatrix memory_costs = {{1, 1, 1, 1}, // These values are tiny and + {2, 2, 2}, // shouldn't be rounded up. + {300, 300, 300, 300, 300, 300, 300}, + {4000, 4000, 4000, 4000, 4000, 4000, 4000}, + {50000, 50000, 50000}}; + const CostMatrix memory_edge_costs = {{0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0}, + {0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0}}; + request.clear_live(); + request.clear_memory_costs(); + AddIntervals(request.mutable_node_intervals(), node_intervals); + AddIntervals(request.mutable_edge_intervals(), edge_intervals); + AddGroups(request.mutable_node_groups(), node_groups); + AddGroups(request.mutable_edge_groups(), edge_groups); + AddCosts(request.mutable_memory_costs(), memory_costs); + AddCosts(request.mutable_memory_edge_costs(), memory_edge_costs); + request.set_enable_memory_edge_costs(true); + request.set_memory_budget(4321); + + const AutoShardingSolverResult result = CallORToolsSolver(request); + + const std::vector s_val = {0, 0, 0, 0, 0}; + const double objective_value = 7650.0; + const AutoShardingSolverOutput expected_output = {s_val, objective_value}; + const AutoShardingSolverResult expected_result = {expected_output, false}; + EXPECT_EQ(result, expected_result); +} + TEST(CallORToolsSolverTest, SolvesWithEquivalences) { const AutoShardingSolverRequest request = AutoShardingSolverRequestWithEquivalences(); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc index 6c4ae8251033b9..4414c4f7340dbc 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.cc @@ -39,6 +39,7 @@ limitations under the License. #include "absl/types/span.h" #include "xla/array.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_wrapper.h" @@ -77,7 +78,7 @@ std::optional ConstructImprovedSharding( std::pair ComputeSliceShardingAndCommunicationCostFromOperand( const HloSharding& input_spec, const Shape& old_shape, - const Shape& new_shape, const Array& device_mesh, + const Shape& new_shape, const DeviceMesh& device_mesh, const ClusterEnvironment& cluster_env) { if (input_spec.IsReplicated()) { return std::make_pair(input_spec, 0); @@ -135,7 +136,7 @@ BuildStrategyAndCost( const ClusterEnvironment& cluster_env, AutoShardingOption& option, const CallGraph& call_graph, const HloCostAnalysis& hlo_cost_analysis, bool trying_multiple_mesh_shapes) { - // const Array& device_mesh = cluster_env.device_mesh_; + // const DeviceMesh& device_mesh = cluster_env.device_mesh_; StrategyMap strategy_map; // This map stores all of the trimmed strategies due to user specified // sharding. The key is the instruction id, the value is the strategies. This @@ -149,8 +150,7 @@ BuildStrategyAndCost( const std::vector& instructions = sequence.instructions(); // Add penalty for replicated tensors - double replicated_penalty = std::round(cluster_env.AllReduceCost(1, 0) + - cluster_env.AllReduceCost(1, 1)); + double replicated_penalty = cluster_env.GetDefaultReplicatedPenalty(); int64_t max_depth = -1; for (auto iter : depth_map) { @@ -739,16 +739,13 @@ BuildStrategyAndCost( break; } case HloOpcode::kIota: { - // For an unknown reason, we do not generate partially replicated - // strategies for iota ops. This can be changed if we find that our - // search isn't exhaustive enough for certain ops. strategy_group = CreateAllStrategiesGroup( ins, ins->shape(), instruction_id, strategy_groups, cluster_env, strategy_map, option, replicated_penalty, batch_dim_map, call_graph, only_allow_divisible, /* create_replicated_strategies */ true, - /* create_partially_replicated_strategies */ false) + /* create_partially_replicated_strategies */ true) .value(); break; } @@ -828,15 +825,7 @@ BuildStrategyAndCost( } }; - if (IsCustomCallMarker(ins)) { - const HloInstruction* operand = ins->operand(0); - const StrategyGroup* src_strategy_group = - strategy_map.at(operand).get(); - CHECK(src_strategy_group->is_tuple); - strategy_group = MaybeFollowInsStrategyGroup( - src_strategy_group, ins->shape(), instruction_id, strategy_groups, - cluster_env, pretrimmed_strategy_map); - } else if (IsSPMDFullToShardShapeCustomCall(ins)) { + if (IsSPMDFullToShardShapeCustomCall(ins)) { return absl::InternalError( "An SPMDFullToShardShape call found outside a manually " "partitioned sub-graph."); @@ -1034,24 +1023,6 @@ BuildStrategyAndCost( strategy_map[ins] = std::move(strategy_group); } // end of for loop - // If gradient accumulation is used, adjust the cost of all-reduce for - // gradient synchronization. - if (option.grad_acc_num_micro_batches > 1) { - // find gradient-computation instructions - std::vector grad_insts = - GetGradientComputationInstructions(instructions); - for (const HloInstruction* inst : grad_insts) { - StrategyGroup* stra_vector = strategy_map[inst].get(); - CHECK(!stra_vector->is_tuple); - - for (auto& stra : stra_vector->strategies) { - if (absl::StrContains(stra.name, "allreduce")) { - stra.communication_cost /= option.grad_acc_num_micro_batches; - } - } - } - } - return std::make_tuple(std::move(strategy_map), std::move(strategy_groups), std::move(associative_dot_pairs)); } diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc index 4f9d533a4edcec..ab040e9c77edd7 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_test.cc @@ -23,12 +23,15 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_cost_graph.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" @@ -46,7 +49,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace op = xla::testing::opcode_matchers; @@ -68,25 +71,72 @@ using ::testing::Pair; using ::testing::ResultOf; using ::testing::UnorderedElementsAre; -using DummyAutoShardingTest = HloTestBase; - -TEST_F(DummyAutoShardingTest, ReplicatedShardingDummy) { - constexpr absl::string_view kHloString = R"( -HloModule module -ENTRY %elementwise { - %param0 = f32[5,7,11,13]{3,2,1,0} parameter(0) - %param1 = f32[5,7,11,13]{3,2,1,0} parameter(1) - %add = f32[5,7,11,13]{3,2,1,0} add(%param0, %param1) - ROOT %copy = f32[5,7,11,13]{3,2,1,0} copy(%add) -})"; - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(kHloString)); - TF_ASSERT_OK_AND_ASSIGN(bool changed, DummyAutoSharding().Run(module.get())); - EXPECT_TRUE(changed); - auto* instruction = FindInstruction(module.get(), "param0"); - ASSERT_NE(instruction, nullptr); - EXPECT_THAT(instruction, op::Sharding("{replicated}")); +TEST(DeviceMeshTest, IotaDeviceMesh2DStartsWith0) { + DeviceMesh device_mesh({2, 4}); + device_mesh.FillIota(0); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4)); + EXPECT_EQ(device_mesh.num_elements(), 8); +} + +TEST(DeviceMeshTest, IotaDeviceMesh3DStartsWithNonZero) { + DeviceMesh device_mesh({2, 4, 8}); + device_mesh.FillIota(55); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); +} + +TEST(DeviceMeshTest, ExplicitSetValuesInferIotaIotaValues) { + DeviceMesh device_mesh({2, 4, 8}); + std::vector device_mesh_values(64); + absl::c_iota(device_mesh_values, 34); + device_mesh.SetValues(device_mesh_values); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); +} + +TEST(DeviceMeshTest, ExplicitSetValuesInferIotaNonIotaValues) { + DeviceMesh device_mesh({2, 4, 8}); + std::vector device_mesh_values(64); + absl::c_iota(device_mesh_values, 34); + device_mesh_values[54] = 54; + device_mesh.SetValues(device_mesh_values); + EXPECT_FALSE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); +} + +TEST(DeviceMeshTest, ReshapeTestWithoutIota) { + DeviceMesh device_mesh({2, 4, 8}); + std::vector device_mesh_values(64); + absl::c_iota(device_mesh_values, 34); + device_mesh_values[54] = 54; + device_mesh.SetValues(device_mesh_values); + EXPECT_FALSE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); + + device_mesh.Reshape({2, 32}); + EXPECT_FALSE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 32)); + EXPECT_EQ(device_mesh.num_elements(), 64); +} + +TEST(DeviceMeshTest, ReshapeTestWithIota) { + DeviceMesh device_mesh({2, 4, 8}); + std::vector device_mesh_values(64); + absl::c_iota(device_mesh_values, 34); + device_mesh.SetValues(device_mesh_values); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 4, 8)); + EXPECT_EQ(device_mesh.num_elements(), 64); + + device_mesh.Reshape({2, 32}); + EXPECT_TRUE(device_mesh.is_iota); + EXPECT_THAT(device_mesh.dimensions(), ElementsAre(2, 32)); + EXPECT_EQ(device_mesh.num_elements(), 64); } class AutoShardingTest : public HloTestBase { @@ -346,6 +396,40 @@ ENTRY %elementwise { op::Sharding("{devices=[2,1,2]0,2,1,3 last_tile_dim_replicate}"))); } +TEST_F(AutoShardingTest, IotaPartiallyReplicatedShardingTest) { + constexpr absl::string_view kHloString = R"( +HloModule module + +ENTRY %elementwise { + iota1 = s32[11,1026]{1,0} iota(), iota_dimension=1 + param1 = s32[11,1026]{1,0} parameter(0), sharding={devices=[1,16,16]<=[16,16]T(1,0) last_tile_dim_replicate} + copy1 = s32[11,1026]{1,0} copy(iota1) + ROOT add1 = s32[11,1026]{1,0} add(copy1, param1) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + AutoSharding( + /* option */ { + .enable = true, + .preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings, + .only_allow_divisible_input_output = false, + .device_mesh_shape = {16, 16}, + .device_mesh_alpha = {1.0, 1.0}, + .device_mesh_beta = {0.01, 1.0}}) + .Run(module.get())); + VLOG(10) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* iota = FindInstruction(module.get(), "iota1"); + ASSERT_NE(iota, nullptr); + EXPECT_THAT( + iota, op::Sharding( + "{devices=[1,16,16]<=[16,16]T(1,0) last_tile_dim_replicate}")); +} + TEST_F(AutoShardingTest, SliceMixedUserShardingTest) { constexpr absl::string_view kHloString = R"( HloModule module @@ -520,7 +604,7 @@ ENTRY %RngBitGenerator (p0: u64[2]) -> (u64[2], u32[16,16]) { TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); VLOG(10) << module->ToString(); EXPECT_TRUE(changed); - auto* instruction = FindInstruction(module.get(), "p0"); + const HloInstruction* instruction = FindInstruction(module.get(), "p0"); ASSERT_NE(instruction, nullptr); EXPECT_THAT(instruction, op::Sharding("{replicated}")); } @@ -653,14 +737,49 @@ ENTRY %RngBitGenerator { TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); VLOG(10) << module->ToString(); EXPECT_TRUE(changed); - auto* param0 = FindInstruction(module.get(), "param.0"); - auto* param1 = FindInstruction(module.get(), "param.1"); + const HloInstruction* param0 = FindInstruction(module.get(), "param.0"); + const HloInstruction* param1 = FindInstruction(module.get(), "param.1"); ASSERT_NE(param0, nullptr); ASSERT_NE(param0, nullptr); EXPECT_THAT(param0, op::Sharding("{replicated}")); EXPECT_THAT(param1, op::Sharding("{replicated}")); } +TEST_F(AutoShardingTest, DotMixedMeshStrategies) { + constexpr absl::string_view kHloString = R"( +HloModule module +ENTRY %entry { + %param0 = f32[8192,23]{1,0} parameter(0), sharding={devices=[4,1]0,1,2,3} + %param1 = f32[23,23]{1,0} parameter(1) + %dot = f32[8192,23]{1,0} dot(%param0, %param1), lhs_contracting_dims={1}, rhs_contracting_dims={1} + ROOT %copy = f32[8192,23]{1,0} copy(%dot) +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + AutoShardingOption option; + option.enable = true; + option.device_mesh_shape = {2, 2}; + option.device_mesh_ids = {0, 1, 2, 3}; + option.device_mesh_alpha = {1.0, 1.0}; + option.device_mesh_beta = {0.01, 1.0}; + option.solve_nd_sharding_iteratively = false; + option.preserve_shardings = + AutoShardingOption::PreserveShardingsType::kKeepAllShardings; + TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); + VLOG(2) << module->ToString(); + EXPECT_TRUE(changed); + const HloInstruction* param0 = FindInstruction(module.get(), "param0"); + const HloInstruction* param1 = FindInstruction(module.get(), "param1"); + const HloInstruction* dot = FindInstruction(module.get(), "dot"); + ASSERT_NE(param0, nullptr); + ASSERT_NE(param1, nullptr); + ASSERT_NE(dot, nullptr); + EXPECT_THAT(param0, op::Sharding("{devices=[4,1]0,1,2,3}")); + EXPECT_THAT(param1, op::Sharding("{replicated}")); + EXPECT_THAT(dot, op::Sharding("{devices=[4,1]0,1,2,3}")); +} + TEST_F(AutoShardingTest, DotLHSTwoNonContractingDims) { constexpr absl::string_view kHloString = R"( HloModule module @@ -683,9 +802,9 @@ ENTRY %entry { TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); VLOG(2) << module->ToString(); EXPECT_TRUE(changed); - auto* param0 = FindInstruction(module.get(), "param0"); - auto* param1 = FindInstruction(module.get(), "param1"); - auto* dot = FindInstruction(module.get(), "dot"); + const HloInstruction* param0 = FindInstruction(module.get(), "param0"); + const HloInstruction* param1 = FindInstruction(module.get(), "param1"); + const HloInstruction* dot = FindInstruction(module.get(), "dot"); ASSERT_NE(param0, nullptr); ASSERT_NE(param1, nullptr); ASSERT_NE(dot, nullptr); @@ -736,9 +855,9 @@ ENTRY %entry { TF_ASSERT_OK_AND_ASSIGN(bool changed, AutoSharding(option).Run(module.get())); VLOG(2) << module->ToString(); EXPECT_TRUE(changed); - auto* param0 = FindInstruction(module.get(), "param0"); - auto* param1 = FindInstruction(module.get(), "param1"); - auto* dot = FindInstruction(module.get(), "dot"); + const HloInstruction* param0 = FindInstruction(module.get(), "param0"); + const HloInstruction* param1 = FindInstruction(module.get(), "param1"); + const HloInstruction* dot = FindInstruction(module.get(), "dot"); ASSERT_NE(param0, nullptr); ASSERT_NE(param1, nullptr); ASSERT_NE(dot, nullptr); @@ -2448,6 +2567,36 @@ ENTRY entry { input_output_alias_config_after.ToString()); } +TEST(NormalizeTest, NormalizeHandlesNegativeCosts) { + EdgeReshardingCostMatrix edge_cost(2, 2); + edge_cost(0, 0).communication_cost = -100; + edge_cost(0, 1).communication_cost = 200; + edge_cost(1, 0).communication_cost = 300; + edge_cost(1, 1).communication_cost = 400; + + const EdgeReshardingCostMatrix normalized_edge_cost = Normalize(edge_cost); + + EXPECT_EQ(normalized_edge_cost(0, 0).communication_cost, 0); + EXPECT_EQ(normalized_edge_cost(0, 1).communication_cost, 300); + EXPECT_EQ(normalized_edge_cost(1, 0).communication_cost, 400); + EXPECT_EQ(normalized_edge_cost(1, 1).communication_cost, 500); +} + +TEST(NormalizeTest, NormalizeHandlesPositiveCosts) { + EdgeReshardingCostMatrix edge_cost(2, 2); + edge_cost(0, 0).communication_cost = 100; + edge_cost(0, 1).communication_cost = 200; + edge_cost(1, 0).communication_cost = 300; + edge_cost(1, 1).communication_cost = 400; + + const EdgeReshardingCostMatrix normalized_edge_cost = Normalize(edge_cost); + + EXPECT_EQ(normalized_edge_cost(0, 0).communication_cost, 100); + EXPECT_EQ(normalized_edge_cost(0, 1).communication_cost, 200); + EXPECT_EQ(normalized_edge_cost(1, 0).communication_cost, 300); + EXPECT_EQ(normalized_edge_cost(1, 1).communication_cost, 400); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc index 4b86f967ab7da2..6cb0c6b5c2ef9f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.cc @@ -37,12 +37,12 @@ limitations under the License. #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "json/json.h" #include "xla/array.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -59,16 +59,10 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" namespace xla { namespace spmd { -inline const HloInstruction* PassThroughCustomCallMarkerGetSource( - const HloInstruction* ins); -inline HloInstruction* PassThroughCustomCallMarkerUser( - HloInstruction* raw_user, const HloInstruction* inst); - std::optional GetInputSharding(const HloInstruction* ins, int64_t op_index, const HloSharding& output_sharding, @@ -109,26 +103,6 @@ std::optional GetInputSharding(const HloInstruction* ins, return inferred_sharding; } -// Return whether the instruction is an activation from another pipeline stage. -bool IsActivationFromAnotherStage(const HloInstruction* ins, - const InstructionBatchDimMap& batch_dim_map) { - if (!(ins->opcode() == HloOpcode::kParameter && - batch_dim_map.contains(GetBatchDimMapKey(ins)))) { - return false; - } - - for (const HloInstruction* user : ins->users()) { - if (!(user->opcode() == HloOpcode::kTuple && user->users().size() == 1 && - user->users().front()->IsCustomCall(kPipelineMarker) && - absl::StrContains(user->users().front()->metadata().op_type(), - "start"))) { - return false; - } - } - - return true; -} - // Propagate sharding for dim-wise operations (e.g., slice, pad) which works // independently on each dimension. // The sharding can successfully propagate if the operation only happens @@ -194,11 +168,6 @@ InstructionDepthMap BuildInstructionDepthMap( if (degree_dict[inst] == 0) { depth_map[inst] = 0; - // Add some initial depth for activations from other pipeline stages. - if (IsActivationFromAnotherStage(inst, batch_dim_map)) { - depth_map[inst] = 20; - } - current_frontier.push_back(inst); collected++; } @@ -246,10 +215,6 @@ InstructionDepthMap BuildInstructionDepthMap( if (reset) { depth_map[node] = 0; - } else if (node->opcode() == HloOpcode::kGetTupleElement && - IsCustomCallMarker(node->operand(0))) { - depth_map[node] = - depth_map.at(PassThroughCustomCallMarkerGetSource(node)); } else { int64_t max_depth = depth_map.at(inst) + delta; for (const HloInstruction* operand : node->operands()) { @@ -813,12 +778,6 @@ InstructionBatchDimMap BuildInstructionBatchDimMap( batch_map[ins->name()] = batch_dim_of_source; } } - - if (ins->IsCustomCall(kPipelineMarker) && - absl::StrContains(ins->metadata().op_type(), "start")) { - // Reset the status after meet a new pipeline marker. - set_the_next_dot_conv = true; - } } int64_t previous_cnt = 0; while (true) { @@ -911,7 +870,7 @@ void RemoveDuplicatedStrategy(std::unique_ptr& strategy_group) { } } -bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, +bool IsDivisible(const HloInstruction* ins, const DeviceMesh& device_mesh, absl::Span tensor_dims, absl::Span mesh_dims) { CHECK_EQ(tensor_dims.size(), mesh_dims.size()); @@ -968,8 +927,7 @@ void TryReduceWithCommonAncestor(InstructionSet& replicated_set, for (HloInstruction* node : boundary_set) { HloInstruction* cur = node; while (cur->operand_count() == 1) { - HloInstruction* operand = - PassThroughCustomCallMarkerOperand(cur->mutable_operand(0), cur); + HloInstruction* operand = cur->mutable_operand(0); if (replicated_set.contains(operand)) { path.insert(cur); } @@ -1007,8 +965,7 @@ void UseAllReduceForGradAcc(InstructionSet& replicated_set, // Find the add instruction for grad accumulation, skip the identity marker // for remat and other elementwise ops. - HloInstruction* add = - PassThroughCustomCallMarkerUser(inst->users().front(), inst); + HloInstruction* add = inst->users().front(); if (add->opcode() == HloOpcode::kGetTupleElement || add->opcode() == HloOpcode::kTranspose) { if (add->users().size() != 1) { @@ -1025,7 +982,7 @@ void UseAllReduceForGradAcc(InstructionSet& replicated_set, } CHECK_EQ(add->users().size(), 1); // Skip the end marker of backward computation - add = PassThroughCustomCallMarkerUser(add->users().front(), add); + add = add->users().front(); // Do not partition the dot, add and parameter, so we can generate // all-reduce for grad accumulation. @@ -1037,7 +994,7 @@ void UseAllReduceForGradAcc(InstructionSet& replicated_set, replicated_set.erase(cur); for (auto x : cur->operands()) { - dfs_remove(PassThroughCustomCallMarkerOperand(x, cur)); + dfs_remove(x); } }; @@ -1123,7 +1080,7 @@ int64_t NumTileDimensions(const HloSharding& spec) { } bool TileAssignmentMatchesMesh(const HloSharding& spec, - const Array& mesh) { + const DeviceMesh& mesh) { int sharded_dims = 0; for (int i = 0; i < spec.tile_assignment().num_dimensions(); ++i) { if (spec.tile_assignment().dim(i) > 1) { @@ -1138,39 +1095,23 @@ bool TileAssignmentMatchesMesh(const HloSharding& spec, return sharded_dims <= 0; } -absl::StatusOr> GetTensorDimToMeshDimNoCrash( - int64_t tensor_shape_rank, const HloSharding& spec, - const Array& device_mesh, bool consider_reverse_device_meshes) { - if (spec.IsReplicated()) { - return std::vector(tensor_shape_rank, -1); - } - // Check the compatibility of tensor_shape_rank and spec - if (tensor_shape_rank != spec.TiledDataRank()) { - return absl::InvalidArgumentError( - "Tensor shape rank should be equal to the tiled data rank of the input " - "spec."); - } - +absl::StatusOr> GetMeshDimPermutationOrderInShardingSpec( + const HloSharding& spec, const DeviceMesh& device_mesh, + bool consider_reverse_device_meshes) { auto check_mesh = [&](const Array& mesh) -> std::optional> { // Permute the dimensions (or axes in numpy term), find the transform that // makes tile_assignment == device_mesh. std::vector axes(mesh.num_dimensions()); absl::c_iota(axes, 0); - bool found = false; do { Array transposed_mesh = Transpose(mesh, axes); if (std::equal(transposed_mesh.begin(), transposed_mesh.end(), spec.tile_assignment().array().begin())) { - found = true; - break; + return axes; } } while (absl::c_next_permutation(axes)); - if (found) { - return std::optional>(axes); - } else { - return std::nullopt; - } + return std::nullopt; }; // This is an expensive search, as we try all possible meshes obtained by @@ -1178,7 +1119,6 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( // the somewhat rare kReverse HLO op. The hope therefore is that most calls to // the function that reach here will find a mapping within the first iteration // of the loop below. - bool found = false; std::vector axes(device_mesh.num_dimensions()); size_t num_subsets = consider_reverse_device_meshes ? (1 << device_mesh.num_dimensions()) : 1; @@ -1199,24 +1139,35 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( *device = device_mesh(original_indices); }); if (auto result = check_mesh(new_mesh); result.has_value()) { - axes = result.value(); - found = true; - break; + return result.value(); } } + return absl::NotFoundError(absl::StrCat("Could not find mapping for ", + spec.ToString(), " with device mesh ", + device_mesh.ToString())); +} - if (!found) { - return absl::NotFoundError( - absl::StrCat("Could not find mapping for ", spec.ToString(), - " with device mesh ", device_mesh.ToString())); +absl::StatusOr> GetTensorDimToMeshDimNoCrash( + int64_t tensor_shape_rank, const HloSharding& spec, + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes) { + if (spec.IsReplicated()) { + return std::vector(tensor_shape_rank, -1); + } + // Check the compatibility of tensor_shape_rank and spec + if (tensor_shape_rank != spec.TiledDataRank()) { + return absl::InvalidArgumentError( + "Tensor shape rank should be equal to the tiled data rank of the input " + "spec."); } - if (!TileAssignmentMatchesMesh(spec, device_mesh)) { return absl::InvalidArgumentError( "Device mesh and tile assignment need to have the same number of " "sharded dims."); } + TF_ASSIGN_OR_RETURN(std::vector axes, + GetMeshDimPermutationOrderInShardingSpec( + spec, device_mesh, consider_reverse_device_meshes)); // Transform tile_assignment_dimensions using found transformation (axes). std::vector tensor_dim_to_device_dim(tensor_shape_rank, -1); int mesh_index = 0; @@ -1234,7 +1185,7 @@ absl::StatusOr> GetTensorDimToMeshDimNoCrash( std::vector GetTensorDimToMeshDim( int64_t tensor_shape_rank, const HloSharding& spec, - const Array& device_mesh, bool consider_reverse_device_meshes) { + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes) { auto mapping_or = GetTensorDimToMeshDimNoCrash( tensor_shape_rank, spec, device_mesh, consider_reverse_device_meshes); if (mapping_or.ok()) { @@ -1244,9 +1195,10 @@ std::vector GetTensorDimToMeshDim( } } -absl::StatusOr ComputeIntermediateShape( - const HloSharding& src_sharding, const HloSharding& dst_sharding, - const Shape& shape, const Array& device_mesh) { +absl::StatusOr ComputeIntermediateShape(const HloSharding& src_sharding, + const HloSharding& dst_sharding, + const Shape& shape, + const DeviceMesh& device_mesh) { int64_t src_n_dim = NumTileDimensions(src_sharding); const HloSharding* sharding_1d; @@ -1282,7 +1234,7 @@ absl::StatusOr ComputeIntermediateShape( HloInstruction* ReshardTensor(HloInstruction* tensor, const HloSharding& src_sharding, const HloSharding& dst_sharding, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { const Shape& shape = tensor->shape(); HloComputation* computation = tensor->parent(); @@ -1330,7 +1282,7 @@ HloInstruction* ReshardTensor(HloInstruction* tensor, absl::Status FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( HloInstruction* inst, const std::vector>& dst_shardings, - const Array& device_mesh) { + const DeviceMesh& device_mesh) { size_t tuple_size = inst->shape().tuple_shapes_size(); const HloSharding& current_sharding = inst->sharding(); @@ -1394,7 +1346,7 @@ absl::Status FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( absl::Status FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* inst, const HloSharding& dst_sharding, - const Array& device_mesh, + const DeviceMesh& device_mesh, absl::flat_hash_map>& preserve_shardings) { const HloInstruction* operand = inst->operand(0); @@ -1436,7 +1388,7 @@ absl::Status FixMixedMeshShapeReshardingGetTupleElement( absl::Status FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const HloSharding& dst_sharding, - const Array& device_mesh, + const DeviceMesh& device_mesh, ReshardingCache* resharding_cache) { HloInstruction* operand = inst->mutable_operand(operand_num); if (operand->opcode() == HloOpcode::kOutfeed || @@ -1535,7 +1487,7 @@ bool IsDivisible(int64_t numerator, int64_t denominator) { } std::vector> GetReplicaGroupsAlongOneDimension( - const Array& device_mesh, int32_t communication_dim) { + const DeviceMesh& device_mesh, int32_t communication_dim) { CHECK_LT(communication_dim, device_mesh.num_dimensions()); std::vector indices(device_mesh.num_dimensions(), 0); std::vector> replica_groups; @@ -1556,10 +1508,10 @@ std::vector> GetReplicaGroupsAlongOneDimension( } // Create a HloSharding that tiles some tensor dims on some device mesh dims. -HloSharding Tile(const Shape& tensor_shape, - absl::Span tensor_dims, - absl::Span mesh_dims, - const Array& device_mesh) { +HloSharding TileV1(const Shape& tensor_shape, + absl::Span tensor_dims, + const std::vector>& mesh_dims, + const DeviceMesh& device_mesh) { CHECK_EQ(tensor_dims.size(), mesh_dims.size()); CHECK(tensor_shape.IsArray()); std::vector tile_assignment_dimensions(tensor_shape.rank(), 1); @@ -1567,8 +1519,12 @@ HloSharding Tile(const Shape& tensor_shape, // Split on certain mesh dimensions int64_t split_prod = 1; for (size_t i = 0; i < tensor_dims.size(); ++i) { - tile_assignment_dimensions[tensor_dims[i]] = device_mesh.dim(mesh_dims[i]); - split_prod *= device_mesh.dim(mesh_dims[i]); + int64_t num_devices_for_tensor_dim = 1; + for (int64_t mesh_dim_idx : mesh_dims[i]) { + num_devices_for_tensor_dim *= device_mesh.dim(mesh_dim_idx); + } + tile_assignment_dimensions[tensor_dims[i]] = num_devices_for_tensor_dim; + split_prod *= num_devices_for_tensor_dim; } // Replicate on remaining mesh dimensions bool replicate_on_last_tile_dim = false; @@ -1582,35 +1538,58 @@ HloSharding Tile(const Shape& tensor_shape, std::vector tile_assignment_devices; tile_assignment_devices.reserve(device_mesh.num_elements()); - std::vector tmp_indices(device_mesh.num_dimensions(), 0); - std::function)> + std::function)> generate_tile_assignment_devices; - generate_tile_assignment_devices = [&](int64_t tensor_dim, + generate_tile_assignment_devices = [&](int64_t current_tensor_dim, + int64_t current_mesh_dim_idx, std::vector mesh_indices) { - if (tensor_dim == tensor_shape.rank() - 1) { - AppendFlattenElements(&tile_assignment_devices, device_mesh, mesh_indices, - -1, tmp_indices); + int64_t current_tensor_dim_index = + GetIndex(tensor_dims, current_tensor_dim); + bool proceed_to_next_tensor_dim = false; + if (current_tensor_dim_index >= 0) { + proceed_to_next_tensor_dim = + (current_mesh_dim_idx == + mesh_dims[current_tensor_dim_index].size() - 1); } else { - int64_t next_tensor_dim = tensor_dim + 1; - int64_t next_mesh_dim = -1; + proceed_to_next_tensor_dim = true; + } + + if (proceed_to_next_tensor_dim && + current_tensor_dim == tensor_shape.rank() - 1) { + AppendFlattenElements(&tile_assignment_devices, device_mesh.device_array, + mesh_indices); + return; + } - int64_t index = GetIndex(tensor_dims, next_tensor_dim); - if (index >= 0) { - next_mesh_dim = mesh_dims[index]; + int64_t next_tensor_dim, next_mesh_dim_idx = -1, next_mesh_dim = -1; + if (proceed_to_next_tensor_dim) { + next_tensor_dim = current_tensor_dim + 1; + next_mesh_dim_idx = -1; + int64_t next_tensor_dim_index = GetIndex(tensor_dims, next_tensor_dim); + if (next_tensor_dim_index >= 0) { + next_mesh_dim_idx = 0; + next_mesh_dim = mesh_dims[next_tensor_dim_index][0]; } + } else { + next_tensor_dim = current_tensor_dim; + next_mesh_dim_idx = current_mesh_dim_idx + 1; + next_mesh_dim = mesh_dims[current_tensor_dim_index][next_mesh_dim_idx]; + } - for (int64_t i = 0; i < tile_assignment_dimensions[next_tensor_dim]; - ++i) { - if (next_mesh_dim != -1) { - mesh_indices[next_mesh_dim] = i; - } - generate_tile_assignment_devices(next_tensor_dim, mesh_indices); + int64_t limit = + (next_mesh_dim_idx >= 0) ? device_mesh.dim(next_mesh_dim) : 1; + for (int64_t i = 0; i < limit; ++i) { + if (next_mesh_dim != -1) { + mesh_indices[next_mesh_dim] = i; } + generate_tile_assignment_devices(next_tensor_dim, next_mesh_dim_idx, + mesh_indices); } }; std::vector mesh_indices(device_mesh.num_dimensions(), -1); - generate_tile_assignment_devices(-1, mesh_indices); + generate_tile_assignment_devices(/*current_tensor_dim=*/-1, + /*current_mesh_dim_idx=*/-1, mesh_indices); // Make HloSharding Array tile_assignment(tile_assignment_dimensions); @@ -1625,6 +1604,93 @@ HloSharding Tile(const Shape& tensor_shape, : HloSharding::Tile(std::move(tile_assignment)); } +HloSharding TileV2(const Shape& tensor_shape, + absl::Span tensor_dims, + const std::vector>& mesh_dims, + const DeviceMesh& device_mesh) { + CHECK_EQ(tensor_dims.size(), mesh_dims.size()); + CHECK(tensor_shape.IsArray()); + std::vector tile_assignment_dimensions(tensor_shape.rank(), 1); + std::vector transpose_perm; + absl::Span reshape_dims = device_mesh.dimensions(); + + struct TensorDimWithIndex { + int64_t tensor_dim; + int64_t idx_in_vector; + }; + + std::vector sorted_tensor_dims(tensor_dims.size()); + for (size_t i = 0; i < tensor_dims.size(); ++i) { + sorted_tensor_dims[i].tensor_dim = tensor_dims[i]; + sorted_tensor_dims[i].idx_in_vector = i; + } + + absl::c_sort(sorted_tensor_dims, + [](const TensorDimWithIndex& a, const TensorDimWithIndex& b) { + return a.tensor_dim < b.tensor_dim; + }); + + // Split on certain mesh dimensions + int64_t split_prod = 1; + for (const TensorDimWithIndex& tensor_dim_with_index : sorted_tensor_dims) { + int64_t tensor_dim = tensor_dim_with_index.tensor_dim; + const std::vector& mesh_dims_for_this_tensor_dim = + mesh_dims[tensor_dim_with_index.idx_in_vector]; + int64_t num_devices_for_tensor_dim = 1; + for (int64_t mesh_dim_idx : mesh_dims_for_this_tensor_dim) { + num_devices_for_tensor_dim *= device_mesh.dim(mesh_dim_idx); + transpose_perm.push_back(mesh_dim_idx); + } + tile_assignment_dimensions[tensor_dim] = num_devices_for_tensor_dim; + split_prod *= num_devices_for_tensor_dim; + } + // Replicate on remaining mesh dimensions + bool replicate_on_last_tile_dim = false; + if (split_prod < device_mesh.num_elements()) { + tile_assignment_dimensions.push_back(device_mesh.num_elements() / + split_prod); + replicate_on_last_tile_dim = true; + } + + for (int i = 0; i < device_mesh.num_dimensions(); ++i) { + if (absl::c_find(transpose_perm, i) == transpose_perm.end()) { + transpose_perm.push_back(i); + } + } + + // Make HloSharding + TileAssignment tile_assignment(tile_assignment_dimensions, reshape_dims, + transpose_perm); + + return replicate_on_last_tile_dim + ? HloSharding::PartialTile(std::move(tile_assignment)) + : HloSharding::Tile(std::move(tile_assignment)); +} + +HloSharding Tile(const Shape& tensor_shape, + absl::Span tensor_dims, + const std::vector>& mesh_dims, + const DeviceMesh& device_mesh) { + if (device_mesh.is_iota) { + return TileV2(tensor_shape, tensor_dims, mesh_dims, device_mesh); + } + return TileV1(tensor_shape, tensor_dims, mesh_dims, device_mesh); +} + +HloSharding Tile(const Shape& tensor_shape, + absl::Span tensor_dims, + absl::Span mesh_dims, + const DeviceMesh& device_mesh) { + std::vector> mesh_dims_general(mesh_dims.size()); + for (int i = 0; i < mesh_dims.size(); ++i) { + mesh_dims_general[i].push_back(mesh_dims[i]); + } + if (device_mesh.is_iota) { + return TileV2(tensor_shape, tensor_dims, mesh_dims_general, device_mesh); + } + return TileV1(tensor_shape, tensor_dims, mesh_dims_general, device_mesh); +} + AliasMap BuildAliasMap(const HloModule* module, const HloInputOutputAliasConfig& alias_config) { AliasMap alias_map; @@ -1633,10 +1699,6 @@ AliasMap BuildAliasMap(const HloModule* module, const auto& parameter_instructions = entry->parameter_instructions(); const HloInstruction* output_tuple = entry->root_instruction(); - if (IsCustomCallMarker(output_tuple)) { - output_tuple = output_tuple->operand(0); - } - absl::flat_hash_map> parameter_index_to_operand_map; alias_config.ForEachAlias([&](const ShapeIndex& output_index, @@ -2079,26 +2141,32 @@ absl::StatusOr AdjustShardingsWithPartialMeshShape( } std::vector> DecomposeMeshShapes( - std::vector mesh_shape) { + const std::vector& mesh_shape, + const std::vector& mesh_alpha, + const std::vector& mesh_beta) { // Get the ranking order based on the size of each value. std::vector ranking_order; std::vector> partial_mesh_shapes; - std::vector> pairs(mesh_shape.size()); + std::vector> tuples( + mesh_shape.size()); for (size_t i = 0; i < mesh_shape.size(); i++) { - pairs[i] = {mesh_shape[i], i}; + // Here we prioritize the throughput term (beta) over the latency term + // (alpha), assuming that collectives are more often throughput-bound. This + // is currently somewhat of an arbitrary choice and can be changed. + tuples[i] = {mesh_beta[i], mesh_alpha[i], mesh_shape[i], i}; } // For vector of size 3, the sorted indices happen to be the same as their // rankings. mesh_shapes over 3 elements are not supported by AutoSharding. - std::sort(pairs.begin(), pairs.end(), - std::greater>()); + std::sort(tuples.begin(), tuples.end(), + std::greater>()); std::vector partial_mesh_shape(mesh_shape.size(), 1); // Starts from the largest dimension of mesh_shape. - for (size_t i = 0; i < pairs.size(); i++) { - if (pairs[i].first == 1) { - break; + for (size_t i = 0; i < tuples.size(); i++) { + if (std::get<2>(tuples[i]) == 1) { + continue; } - partial_mesh_shape[pairs[i].second] = pairs[i].first; + partial_mesh_shape[std::get<3>(tuples[i])] = std::get<2>(tuples[i]); // Needs to copy partial_mesh_shape. partial_mesh_shapes.push_back(partial_mesh_shape); } @@ -2209,30 +2277,37 @@ std::vector> InferMeshShapesToTry( const HloModule& module) { int64_t sharding_1d = -1; absl::flat_hash_set> shardings_nd; + int max_shardings_nd_dimension = -1; std::function process_sharding; - process_sharding = [&sharding_1d, &shardings_nd, - &process_sharding](const HloSharding& sharding) { + process_sharding = [&](const HloSharding& sharding) { if (sharding.IsTuple()) { for (const HloSharding& child : sharding.tuple_elements()) { process_sharding(child); } - } else if (!sharding.IsReplicated() && !sharding.IsTileMaximal() && - !sharding.IsManual()) { - absl::Span dims = sharding.tile_assignment().dimensions(); - std::vector dims_greater_than_one; - for (const int64_t dim : dims) { - if (dim > 1) { - dims_greater_than_one.push_back(dim); - } - } - if (dims_greater_than_one.size() == 1) { - CHECK(sharding_1d == -1 || sharding_1d == dims_greater_than_one[0]); - sharding_1d = dims_greater_than_one[0]; - } else { - std::sort(dims_greater_than_one.begin(), dims_greater_than_one.end()); - shardings_nd.insert(dims_greater_than_one); + return; + } + if (sharding.IsReplicated() || sharding.IsTileMaximal() || + sharding.IsManual()) { + return; + } + absl::Span dims = sharding.tile_assignment().dimensions(); + std::vector dims_greater_than_one; + for (const int64_t dim : dims) { + if (dim > 1) { + dims_greater_than_one.push_back(dim); } } + if (dims_greater_than_one.size() == 1) { + CHECK(sharding_1d == -1 || sharding_1d == dims_greater_than_one[0]); + sharding_1d = dims_greater_than_one[0]; + } else { + std::sort(dims_greater_than_one.begin(), dims_greater_than_one.end()); + shardings_nd.insert(dims_greater_than_one); + + max_shardings_nd_dimension = + std::max(max_shardings_nd_dimension, + static_cast(dims_greater_than_one.size())); + } }; for (const HloComputation* comp : module.computations()) { @@ -2243,20 +2318,29 @@ std::vector> InferMeshShapesToTry( } } + for (auto mesh_shape_it = shardings_nd.begin(), end = shardings_nd.end(); + mesh_shape_it != end;) { + // `erase()` will invalidate `mesh_shape_it`, so advance `mesh_shape_it` + // first. + auto copy_it = mesh_shape_it++; + if (copy_it->size() < max_shardings_nd_dimension) { + shardings_nd.erase(copy_it); + } + } + if (shardings_nd.empty() && sharding_1d < 0) { return {}; - } else if (shardings_nd.empty()) { - CHECK_GE(sharding_1d, 0); + } + if (shardings_nd.empty()) { return {{1, sharding_1d}}; - } else { - std::vector> result; - for (std::vector mesh : shardings_nd) { - do { - result.push_back(std::vector(mesh)); - } while (std::next_permutation(std::begin(mesh), std::end(mesh))); - } - return result; } + std::vector> result; + for (std::vector mesh : shardings_nd) { + do { + result.push_back(std::vector(mesh)); + } while (std::next_permutation(std::begin(mesh), std::end(mesh))); + } + return result; } std::vector> InferOrEnumerateMeshShapesToTry( @@ -2273,9 +2357,7 @@ std::vector> InferOrEnumerateMeshShapesToTry( dedup_result.insert( absl::btree_multiset(mesh_shape.begin(), mesh_shape.end())); } - mesh_shapes.clear(); - for (const absl::btree_multiset& mesh_shape_set : dedup_result) { mesh_shapes.push_back( std::vector(mesh_shape_set.begin(), mesh_shape_set.end())); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h index a4ea23c922fc06..678030f3520fb4 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding_util.h @@ -33,6 +33,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/array.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -48,10 +49,7 @@ limitations under the License. namespace xla { namespace spmd { -inline constexpr absl::string_view kPipelineMarker = "xla_pipeline_marker"; inline constexpr absl::string_view kIdentityMarker = "identity"; -inline constexpr absl::string_view kPipelineMarkerStartType = "start"; -inline constexpr absl::string_view kPipelineMarkerEndType = "end"; inline constexpr int64_t kAutoShardingPointerSize = 8; @@ -85,7 +83,7 @@ inline std::string ToAdaptiveString(const HloInstruction* ins) { // Return whether the tensor shape is divisible by // the number of devices along multiple dimensions. -bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, +bool IsDivisible(const HloInstruction* ins, const DeviceMesh& device_mesh, absl::Span tensor_dims, absl::Span mesh_dims); @@ -94,9 +92,11 @@ bool IsDivisible(const HloInstruction* ins, const Array& device_mesh, // Append elements of `array` to `result`. The `indices` is a generalized // multi-dimensional index that can index a whole row (use -1 to indicate this). template -void AppendFlattenElements(std::vector* result, const Array& array, - absl::Span indices, int cur_depth, - std::vector cur_indices) { +void AppendFlattenElementsInternal(std::vector* result, + const Array& array, + absl::Span indices, + int cur_depth, + std::vector cur_indices) { if (cur_depth == array.num_dimensions() - 1) { result->push_back(array(cur_indices)); } else { @@ -106,15 +106,25 @@ void AppendFlattenElements(std::vector* result, const Array& array, if (index == -1) { for (int64_t i = 0; i < array.dim(next_depth); ++i) { cur_indices[next_depth] = i; - AppendFlattenElements(result, array, indices, next_depth, cur_indices); + AppendFlattenElementsInternal(result, array, indices, next_depth, + cur_indices); } } else { cur_indices[next_depth] = index; - AppendFlattenElements(result, array, indices, next_depth, cur_indices); + AppendFlattenElementsInternal(result, array, indices, next_depth, + cur_indices); } } } +template +void AppendFlattenElements(std::vector* result, const Array& array, + absl::Span indices) { + std::vector tmp_indices(array.num_dimensions(), 0); + AppendFlattenElementsInternal(result, array, indices, + /*cur_depth=*/-1, tmp_indices); +} + // Return the index of key in a span. -1 means not found. template int64_t GetIndex(absl::Span v, const T& key) { @@ -201,11 +211,6 @@ inline void ReplaceOperand(HloInstruction* inst, } } -// Return whether this instruction is a custom call marker introduced by us. -inline bool IsCustomCallMarker(const HloInstruction* inst) { - return inst->IsCustomCall({kPipelineMarker, kIdentityMarker}); -} - // Return whether this instruction is a TopK custom call. inline bool IsTopKCustomCall(const HloInstruction* inst) { return inst->opcode() == HloOpcode::kCustomCall && @@ -218,70 +223,6 @@ inline bool IsPartialReduceCustomCall(const HloInstruction* inst) { inst->custom_call_target() == "PartialReduce"; } -// Pass through the custom call marker and get the source instruction -inline const HloInstruction* PassThroughCustomCallMarkerGetSource( - const HloInstruction* ins) { - while (ins->opcode() == HloOpcode::kGetTupleElement && - IsCustomCallMarker(ins->operand(0))) { - const HloInstruction* custom_call = ins->operand(0); - const HloInstruction* tuple = custom_call->operand(0); - while (IsCustomCallMarker(tuple)) { - tuple = tuple->operand(0); - } - ins = tuple->operand(ins->tuple_index()); - } - return ins; -} - -// Pass through the custom call marker and get the acutal operand. -inline HloInstruction* PassThroughCustomCallMarkerOperand( - HloInstruction* raw_operand, const HloInstruction* inst) { - if (!IsCustomCallMarker(raw_operand)) { - return raw_operand; - } - - CHECK_EQ(inst->opcode(), HloOpcode::kGetTupleElement); - - int index = inst->tuple_index(); - return raw_operand->mutable_operand(0)->mutable_operand(index); -} - -// Return whether the tuple is only used by a custom call marker. -inline bool IsCustomCallMarkerTuple(const HloInstruction* inst) { - return inst->opcode() == HloOpcode::kTuple && inst->users().size() == 1 && - IsCustomCallMarker(inst->users().front()); -} - -// Pass through the custom call marker and get the actual user. -inline HloInstruction* PassThroughCustomCallMarkerUser( - HloInstruction* raw_user, const HloInstruction* inst) { - if (!IsCustomCallMarkerTuple(raw_user)) { - return raw_user; - } - - const HloInstruction* custom_call = raw_user->users().front(); - - int index = -1; - for (int i = 0; i < raw_user->operand_count(); i++) { - if (raw_user->operand(i) == inst) { - index = i; - break; - } - } - CHECK_NE(index, -1); - - HloInstruction* ret = nullptr; - for (HloInstruction* user : custom_call->users()) { - CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement); - if (user->tuple_index() == index) { - CHECK_EQ(ret, nullptr); - ret = user; - } - } - - return ret == nullptr ? raw_user : ret; -} - // Return the users of an instruction and its alias, // excluding the final output tuple. inline InstructionSet UsersWithAlias(const HloInstruction* inst, @@ -289,8 +230,7 @@ inline InstructionSet UsersWithAlias(const HloInstruction* inst, const HloInstruction* output) { InstructionSet users; for (HloInstruction* user : inst->users()) { - HloInstruction* pass_through_user = - PassThroughCustomCallMarkerUser(user, inst); + HloInstruction* pass_through_user = user; if (pass_through_user == output) { continue; } @@ -300,8 +240,7 @@ inline InstructionSet UsersWithAlias(const HloInstruction* inst, auto iter = alias_map.find(inst); if (iter != alias_map.end()) { for (HloInstruction* user : iter->second->users()) { - HloInstruction* pass_through_user = - PassThroughCustomCallMarkerUser(user, iter->second); + HloInstruction* pass_through_user = user; if (pass_through_user == output) { continue; } @@ -356,10 +295,6 @@ std::optional GetInputSharding(const HloInstruction* ins, const xla::CallGraph& call_graph, int64_t num_devices); -// Return whether the instruction is an activation from another pipeline stage. -bool IsActivationFromAnotherStage(const HloInstruction* inst, - const InstructionBatchDimMap& batch_dim_map); - // Depth analysis (breadth first search) that compute the depth of each // instruction. We also assign a much larger distance to heavy operators (e.g., // dot, convolution). @@ -442,63 +377,29 @@ int64_t NumTileDimensions(const HloSharding& spec); // When fixing mixed mesh resharding (see below), compute the correct // intermediate shape in order to insert copies. -absl::StatusOr ComputeIntermediateShape( - const HloSharding& src_sharding, const HloSharding& dst_sharding, - const Shape& shape, const Array& device_mesh); +absl::StatusOr ComputeIntermediateShape(const HloSharding& src_sharding, + const HloSharding& dst_sharding, + const Shape& shape, + const DeviceMesh& device_mesh); // Forcibly set the sharding of the operand of inst. // Also fix the resharding between 1d and 2d logical mesh. absl::Status FixMixedMeshShapeReshardingGetTupleElement( HloInstruction* inst, const HloSharding& dst_sharding, - const Array& device_mesh, + const DeviceMesh& device_mesh, absl::flat_hash_map>& preserve_shardings); absl::Status FixMixedMeshShapeReshardingGetTupleElementWithTupleOutput( HloInstruction* inst, const std::vector>& dst_sharding, - const Array& device_mesh); + const DeviceMesh& device_mesh); absl::Status FixMixedMeshShapeResharding(HloInstruction* inst, int operand_num, const HloSharding& dst_sharding, - const Array& device_mesh, + const DeviceMesh& device_mesh, ReshardingCache* resharding_cache); -/* - * Gradient accumulation - */ -// Find all instructions that compute gradients in gradient accumulation. -// This is done by using the hint from pipeline_marker (gradient marker). -inline std::vector GetGradientComputationInstructions( - const std::vector& instructions) { - std::vector ret; - - for (size_t i = 0; i < instructions.size(); ++i) { - const HloInstruction* ins = instructions[i]; - if (ins->IsCustomCall(kPipelineMarker) && - (absl::StrContains(ins->metadata().op_name(), "compute_grad") || - absl::StrContains(ins->metadata().op_name(), "backward")) && - ins->metadata().op_type() == kPipelineMarkerEndType) { - const HloInstruction* tuple = ins->operand(0); - for (size_t j = 0; j < tuple->operand_count(); ++j) { - const HloInstruction* add = tuple->operand(j); - while (add->opcode() == HloOpcode::kAdd) { - ret.push_back(add->operand(0)); - ret.push_back(add->operand(1)); - - if (add->operand(0)->opcode() == HloOpcode::kAdd) { - add = add->operand(0); - } else { - add = add->operand(1); - } - } - } - } - } - - return ret; -} - // Gets the mapping vector from dim_from to dim_to. // Example: GetDimensionMapping([2], 3) = [0, 1, -1] std::vector GetDimensionMapping( @@ -511,7 +412,7 @@ bool IsDivisible(int64_t numerator, int64_t denominator); // be any number of dimensions. |communication_dim| has to be one of // |device_mesh|'s dimension. std::vector> GetReplicaGroupsAlongOneDimension( - const Array& device_mesh, int32_t communication_dim); + const DeviceMesh& device_mesh, int32_t communication_dim); // Gets values in |array| along |dim| while keeping indices at other // dimensions at 0, e.g., array is 2D and dim = 1, this returns array[0, 1], @@ -525,8 +426,7 @@ absl::StatusOr CheckArithmeticSequence( // Checks if the number of sharded dimensions in the tile assignment matches the // device mesh. -bool TileAssignmentMatchesMesh(const HloSharding& spec, - const Array& mesh); +bool TileAssignmentMatchesMesh(const HloSharding& spec, const DeviceMesh& mesh); // Get the mapped mesh dimension for every tensor dimension. // The returned value maps ith tensor dim to one mesh dim. -1 means the tensor @@ -535,18 +435,21 @@ bool TileAssignmentMatchesMesh(const HloSharding& spec, // mesh dim, and 1st tensor dim maps to the 2nd mesh dim. std::vector GetTensorDimToMeshDim( int64_t tensor_shape_rank, const HloSharding& spec, - const Array& device_mesh, - bool consider_reverse_device_meshes = false); + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes = false); absl::StatusOr> GetTensorDimToMeshDimNoCrash( int64_t tensor_shape_rank, const HloSharding& spec, - const Array& device_mesh, - bool consider_reverse_device_meshes = false); + const DeviceMesh& device_mesh, bool consider_reverse_device_meshes = false); + +HloSharding Tile(const Shape& tensor_shape, + absl::Span tensor_dims, + const std::vector>& mesh_dims, + const DeviceMesh& device_mesh); HloSharding Tile(const Shape& tensor_shape, absl::Span tensor_dims, absl::Span mesh_dims, - const Array& device_mesh); + const DeviceMesh& device_mesh); AliasMap BuildAliasMap(const HloModule* module, const HloInputOutputAliasConfig& alias_config); @@ -629,10 +532,13 @@ inline bool AdjustShardingsWithPartialMeshShape( // Decompose mesh shapes into partial mesh shapes so that we can solve the auto // sharding problem iteratively. Returns partial mesh shapes with larger -// dimensions first. For example, input [1, 4, 2] returns [1, 4, 1] and [1, 4, -// 2]; input [4, 8, 2] returns [1, 8, 1], [4, 8, 1] and [ 4, 8, 2]. +// dimensions and more expensive collective costs first. For example, if all +// mesh axes all have collective costs, input [1, 4, 2] returns [1, 4, 1] and +// [1, 4, 2]; input [4, 8, 2] returns [1, 8, 1], [4, 8, 1] and [ 4, 8, 2]. std::vector> DecomposeMeshShapes( - std::vector mesh_shape); + const std::vector& mesh_shape, + const std::vector& mesh_alpha, + const std::vector& mesh_beta); bool OutputInputSameShapes(const HloInstruction* ins); diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc index c2d82d1766e5c5..42402e39a1496f 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.cc @@ -195,7 +195,7 @@ double ClusterEnvironment::CollectivePermuteCost( // operation as an all-gather on all mesh dimensions. double ClusterEnvironment::OverestimateReplicationCost( const Shape& shape, const HloSharding& src_spec, - const Array& device_mesh) const { + const DeviceMesh& device_mesh) const { if (src_spec.IsTileMaximal() || src_spec.IsManual()) { // TODO(b/238210866) Do not use kInfinityCost. return kInfinityCost; diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h index 19736d19e25f0a..d17b026dd8ffb4 100644 --- a/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h +++ b/third_party/xla/xla/hlo/experimental/auto_sharding/cluster_environment.h @@ -17,16 +17,23 @@ limitations under the License. #define XLA_HLO_EXPERIMENTAL_AUTO_SHARDING_CLUSTER_ENVIRONMENT_H_ #include +#include +#include #include #include #include #include #include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/types/span.h" +#include "xla/hlo/experimental/auto_sharding/auto_sharding_device_mesh.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_option.h" #include "xla/hlo/experimental/auto_sharding/auto_sharding_util.h" #include "xla/hlo/experimental/auto_sharding/profiling_result.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/shape.h" namespace xla { namespace spmd { @@ -38,8 +45,8 @@ namespace spmd { // the real profiling result. class ClusterEnvironment { public: - ClusterEnvironment(const Array& original_device_mesh, - const Array& device_mesh, + ClusterEnvironment(const DeviceMesh& original_device_mesh, + const DeviceMesh& device_mesh, absl::Span mesh_alpha, absl::Span mesh_beta, const ProfilingResult& prof_result, @@ -121,6 +128,14 @@ class ClusterEnvironment { return tensor_dim_to_mesh_dim; } + double GetDefaultReplicatedPenalty() const { + double replicated_penalty = 0; + for (int i = 0; i < device_mesh_.num_dimensions(); ++i) { + replicated_penalty += AllReduceCost(1, i); + } + return std::round(replicated_penalty); + } + double AllGatherCost(double num_bytes, int mesh_dim) const; double AllReduceCost(double num_bytes, int32_t mesh_dim, @@ -146,7 +161,7 @@ class ClusterEnvironment { // shape `shape` sharded according to `src_spec`. double OverestimateReplicationCost(const Shape& shape, const HloSharding& src_spec, - const Array& device_mesh) const; + const DeviceMesh& device_mesh) const; double ReshardingCost(const Shape& shape, const HloSharding& src_spec, const HloSharding& dst_spec) const; @@ -162,11 +177,11 @@ class ClusterEnvironment { } // The original, complete device mesh shape that describes the hardware. - const Array original_device_mesh_; + const DeviceMesh original_device_mesh_; // When solve_nd_sharding_iteratively is true, it is a partial mesh shape from // the original_device_mesh_. When solve_nd_sharding_iteratively is false, it // is the same as original_device_mesh_. - const Array device_mesh_; + const DeviceMesh device_mesh_; // Bandwidth of the device mesh const std::vector mesh_alpha_; const std::vector mesh_beta_; @@ -176,11 +191,11 @@ class ClusterEnvironment { // Cache a flatten 1d version of the device mesh. // Used for mixed mesh shape strategies. - Array device_mesh_1d_; + DeviceMesh device_mesh_1d_; // Cache a flatten 1d version of the original device mesh. // Used for mixed mesh shape strategies. - Array original_device_mesh_1d_; + DeviceMesh original_device_mesh_1d_; // The option may override the cost of communication primitives const AutoShardingOption& auto_sharding_option_; diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index 18e806490fc740..e65c48d982d89d 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -25,7 +25,6 @@ cc_library( "dfs_hlo_visitor.cc", "dynamic_parameter_binding.cc", "hlo_computation.cc", - "hlo_frontend_attributes.cc", "hlo_input_output_alias_config.cc", "hlo_instruction.cc", "hlo_instructions.cc", @@ -33,6 +32,7 @@ cc_library( "hlo_module_metadata.cc", "hlo_op_metadata.cc", "hlo_opcode.cc", + "hlo_original_value.cc", "hlo_schedule.cc", "hlo_sharding.cc", "hlo_sharding_metadata.cc", @@ -46,7 +46,6 @@ cc_library( "hlo_clone_context.h", "hlo_computation.h", "hlo_domain_metadata.h", - "hlo_frontend_attributes.h", "hlo_input_output_alias_config.h", "hlo_instruction.h", "hlo_instructions.h", @@ -54,6 +53,7 @@ cc_library( "hlo_module_metadata.h", "hlo_op_metadata.h", "hlo_opcode.h", + "hlo_original_value.h", "hlo_schedule.h", "hlo_sharding.h", "hlo_sharding_metadata.h", @@ -70,6 +70,7 @@ cc_library( "//xla:protobuf_util", "//xla:shape_tree", "//xla:shape_util", + "//xla:sort_json", "//xla:status_macros", "//xla:types", "//xla:util", @@ -77,6 +78,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", "//xla/service:compilation_environments", + "//xla/service:computation_layout", "//xla/service:computation_placer_hdr", "//xla/service:hlo_lexer", "//xla/service:hlo_module_config", @@ -98,13 +100,13 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/lib/gtl:iterator_range", "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", - "@local_tsl//tsl/platform:human_readable_json", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", @@ -134,9 +136,9 @@ xla_cc_test( deps = [ ":backend_config", "//xla/service/gpu:backend_configs_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], diff --git a/third_party/xla/xla/hlo/ir/backend_config_test.cc b/third_party/xla/xla/hlo/ir/backend_config_test.cc index 5ffe3ae98b8d6c..09b56347e450ed 100644 --- a/third_party/xla/xla/hlo/ir/backend_config_test.cc +++ b/third_party/xla/xla/hlo/ir/backend_config_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.cc b/third_party/xla/xla/hlo/ir/hlo_computation.cc index 025b1ce4f4388e..4fbf057f6eb9a0 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.cc +++ b/third_party/xla/xla/hlo/ir/hlo_computation.cc @@ -1103,6 +1103,19 @@ HloInstruction* HloComputation::CreateCallInstruction( return call_instruction; } +HloInstruction* HloComputation::CreateCompositeCallInstruction( + absl::Span instructions_to_call, + const std::string& name, const std::string& attributes, int64_t version) { + HloInstruction* root = instructions_to_call.front(); + HloInstruction* call_instruction = + AddInstruction(HloInstruction::CreateCompositeCall( + root->shape(), root, name, attributes, version), + root->name()); + AppendInstructionsIntoCalledComputation(instructions_to_call, + call_instruction); + return call_instruction; +} + absl::StatusOr HloComputation::CreateAsyncInstructions( HloInstruction* instruction, absl::Span context_shapes, absl::string_view async_execution_thread, bool replace, diff --git a/third_party/xla/xla/hlo/ir/hlo_computation.h b/third_party/xla/xla/hlo/ir/hlo_computation.h index 956cf1abe1ede2..3e73a68762e74f 100644 --- a/third_party/xla/xla/hlo/ir/hlo_computation.h +++ b/third_party/xla/xla/hlo/ir/hlo_computation.h @@ -17,18 +17,20 @@ limitations under the License. #define XLA_HLO_IR_HLO_COMPUTATION_H_ #include -#include #include #include #include #include #include +#include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/log/check.h" #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/cord.h" #include "absl/strings/string_view.h" @@ -42,9 +44,14 @@ limitations under the License. #include "xla/printer.h" #include "xla/service/hlo.pb.h" #include "xla/service/name_uniquer.h" +#include "xla/shape.h" #include "xla/shape_tree.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/lib/gtl/iterator_range.h" +#include "tsl/platform/errors.h" namespace xla { @@ -465,7 +472,7 @@ class HloComputation { absl::Span instructions_to_fuse, HloInstruction::FusionKind fusion_kind); - // Creates a call instruction containing the given instructions. Instructions + // Creates a call instruction containing the given instructions. Instructions // must be in reverse topological order (root of the called computation // first). Replaces all uses of the original root instruction with the call // instruction. The original instructions are removed if they have no uses @@ -473,6 +480,16 @@ class HloComputation { HloInstruction* CreateCallInstruction( absl::Span instructions_to_call); + // Creates a composite call instruction containing the given instructions. + // Instructions must be in reverse topological order (root of the called + // computation first). Replaces all uses of the original root instruction with + // the composite call instruction. The original instructions are removed if + // they have no uses after creating the composite call (this is necessarily + // true for at least the root). + HloInstruction* CreateCompositeCallInstruction( + absl::Span instructions_to_call, + const std::string& name, const std::string& attributes, int64_t version); + // Creates an async start/done instruction pair where instruction is wrapped // inside an asynchronous computation. The context shapes are appended to the // output tuple of the asynchronous start which is backend specific. Returns diff --git a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc b/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc deleted file mode 100644 index 347edcec61f393..00000000000000 --- a/third_party/xla/xla/hlo/ir/hlo_frontend_attributes.cc +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/hlo/ir/hlo_frontend_attributes.h" - -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" - -namespace xla { - -std::string FrontendAttributesToString( - const FrontendAttributes& frontend_attributes) { - std::vector> sorted_attributes( - frontend_attributes.map().begin(), frontend_attributes.map().end()); - absl::c_sort(sorted_attributes); - // Frontend attribute is a comma-separated list of attribute="value" pairs, - // e.g., frontend_attributes={name="value_a",type="int32_t"}. - const auto formatter = [](std::string* out, - const std::pair& item) { - absl::StrAppend(out, item.first, "=\"", item.second, "\""); - }; - return absl::StrFormat("{%s}", - absl::StrJoin(sorted_attributes, ",", formatter)); -} - -} // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.cc b/third_party/xla/xla/hlo/ir/hlo_instruction.cc index 6ccbfcdb63a1a5..37d7a39d8ee0e0 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.cc @@ -55,11 +55,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_domain_metadata.h" -#include "xla/hlo/ir/hlo_frontend_attributes.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_op_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_original_value.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/hlo/ir/ptrvec.h" @@ -74,14 +74,16 @@ limitations under the License. #include "xla/service/name_uniquer.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/sort_json.h" #include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/gtl/iterator_range.h" #include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" -#include "tsl/platform/human_readable_json.h" #include "tsl/platform/logging.h" // IWYU pragma: keep +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -1158,12 +1160,43 @@ absl::StatusOr> HloInstruction::CreateFromProto( << instruction->opcode() << proto.name(); TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode(); - auto call_instruction = new HloCallInstruction( - shape, all_operands(), - computation_map.at(proto.called_computation_ids()[0])); - call_instruction->set_output_to_operand_aliasing( - output_to_operand_aliasing()); - instruction = absl::WrapUnique(call_instruction); + if (proto.is_composite()) { + TF_RET_CHECK(proto.has_frontend_attributes()) + << "A composite call op must have frontend attributes"; + auto map = proto.frontend_attributes().map(); + auto name = map.find("composite.name"); + TF_RET_CHECK(name != map.end() && !name->second.empty()) + << "A composite call op must have frontend attributes with key " + "composite.name whose value is non-empty"; + + auto attributes = map.find("composite.attributes"); + TF_RET_CHECK(attributes == map.end() || !attributes->second.empty()) + << "A composite call op must have frontend attributes with key " + "composite.attributes whose value is default: {} or non-empty"; + + auto version_str = map.find("composite.version"); + int64_t version = 0; + TF_RET_CHECK( + version_str == map.end() || + (absl::SimpleAtoi(version_str->second, &version) && version >= 0)) + << "A composite call op must have frontend attributes with a " + "composite.version whose value is a non-negative integer but " + "got: " + << version_str->second; + + instruction = CreateCompositeCall( + shape, all_operands(), + computation_map.at(proto.called_computation_ids()[0]), name->second, + attributes == map.end() ? "{}" : attributes->second, version); + instruction->set_output_to_operand_aliasing( + output_to_operand_aliasing()); + } else { + instruction = std::make_unique( + shape, all_operands(), + computation_map.at(proto.called_computation_ids()[0])); + instruction->set_output_to_operand_aliasing( + output_to_operand_aliasing()); + } break; } default: { @@ -1182,6 +1215,9 @@ absl::StatusOr> HloInstruction::CreateFromProto( for (const int64_t computation_id : proto.called_computation_ids()) { instruction->AppendComputation(computation_map.at(computation_id)); } + if (instruction->opcode() == HloOpcode::kWhile) { + instruction->while_body()->SetWhileCallInstruction(instruction.get()); + } TF_RET_CHECK(!proto.has_precision_config()) << instruction->opcode() << proto.DebugString(); @@ -1225,6 +1261,19 @@ absl::StatusOr> HloInstruction::CreateFromProto( instruction->set_statistics_viz(proto.statistics_viz()); } + if (proto.has_original_value()) { + const xla::OriginalValueProto& original_value_proto = + proto.original_value(); + auto original_value = std::make_shared(shape); + + for (const auto& leaf : original_value_proto.leaves()) { + *original_value->mutable_element(ShapeIndex(leaf.leaf_shape_index())) = { + leaf.instruction_name(), ShapeIndex(leaf.shape_index())}; + } + + instruction->set_original_value(original_value); + } + return std::move(instruction); } @@ -2237,6 +2286,27 @@ bool HloInstruction::HasSideEffect() const { return std::make_unique(shape, operands, computation); } +/* static */ std::unique_ptr +HloInstruction::CreateCompositeCall(const Shape& shape, + HloInstruction* decomposition_root, + const std::string& name, + const std::string& attributes, + int64_t version) { + return std::make_unique(shape, decomposition_root, name, + attributes, version); +} + +/* static */ std::unique_ptr +HloInstruction::CreateCompositeCall(const Shape& shape, + absl::Span operands, + HloComputation* decomposition, + const std::string& name, + const std::string& attributes, + int64_t version) { + return std::make_unique(shape, operands, decomposition, + name, attributes, version); +} + /* static */ std::unique_ptr HloInstruction::CreateCustomCall( const Shape& shape, absl::Span operands, absl::string_view custom_call_target, std::string opaque, @@ -2544,6 +2614,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateWhile(shape, while_condition(), while_body(), new_operands[0]); + // Repoint the while body back at the original while instruction. + // If a context was passed, the body will be cloned and the clone will + // point to the copied instruction. + while_body()->SetWhileCallInstruction(const_cast(this)); break; case HloOpcode::kConditional: CHECK_EQ(new_operands.size(), branch_count() + 1); @@ -2587,6 +2661,9 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( ? context->module()->DeepCloneComputation(callee, context) : callee; }); + if (opcode() == HloOpcode::kWhile) { + clone->while_body()->SetWhileCallInstruction(clone.get()); + } } if (!suffix.empty()) { @@ -3603,6 +3680,12 @@ void HloInstruction::PrintWithCanonicalNameMap( }); PrintExtraAttributes(attr_printer, options); + if (original_value_) { + printer->Append(", original_value={"); + printer->Append(OriginalValueToString(*original_value())); + printer->Append("}"); + } + if (options.print_metadata() && (!metadata_->op_type().empty() || !metadata_->op_name().empty() || !metadata_->source_file().empty() || @@ -3614,6 +3697,13 @@ void HloInstruction::PrintWithCanonicalNameMap( } if (options.print_backend_config() && !backend_config_.empty()) { absl::string_view config = backend_config_.GetRawString(); + std::string sorted_config; + if (options.sort_backend_config()) { + // Use `value_or` below, because the backend config string isn't + // guaranteed to be a JSON string. + sorted_config = SortJson(config).value_or(std::string(config)); + config = sorted_config; + } printer->Append(", backend_config="); // In the common case that the backend-config is valid-ish JSON, the parser // doesn't need it delimited by quotes, so we can print it without @@ -3761,6 +3851,10 @@ void HloInstruction::PrintExtraAttributes( PrintNameInternal(printer, to_apply()->name(), options); }); } + if (opcode() == HloOpcode::kCall && is_composite()) { + printer.Next( + [](Printer* printer) { printer->Append("is_composite=true"); }); + } } else if (opcode() == HloOpcode::kCustomCall) { if (!called_computations().empty()) { printer.Next([this, &options](Printer* printer) { @@ -3857,6 +3951,10 @@ void HloInstruction::PrintExtraAttributes( to_apply()->Print(printer, new_options); }); } + if (opcode() == HloOpcode::kCall && is_composite()) { + printer.Next( + [](Printer* printer) { printer->Append("is_composite=true"); }); + } break; default: if (!called_computations().empty()) { @@ -3879,13 +3977,18 @@ void HloInstruction::PrintExtraAttributes( sharding().Print(printer, options.print_metadata()); }); } - if (!rare()->frontend_attributes.map().empty()) { + if (!frontend_attributes().map().empty()) { printer.Next([this](Printer* printer) { AppendCat(printer, "frontend_attributes=", - FrontendAttributesToString(rare()->frontend_attributes)); + FrontendAttributesToString(frontend_attributes())); }); } + if (opcode() != HloOpcode::kCall) { + CHECK(!is_composite()) + << "Only kCall instructions should have is_composite set"; + } + if (options.print_control_dependencies() && !control_predecessors().empty()) { printer.Next([this, &options](Printer* printer) { printer->Append("control-predecessors={"); @@ -3931,6 +4034,23 @@ std::vector HloInstruction::ExtraAttributesToString( return std::move(multi_string_printer).ConsumeStrings(); } +std::string FrontendAttributesToString( + const FrontendAttributes& frontend_attributes) { + std::vector> sorted_attributes( + frontend_attributes.map().begin(), frontend_attributes.map().end()); + absl::c_sort(sorted_attributes); + const auto formatter = [](std::string* out, + const std::pair& item) { + if (LexesAsJsonDict(item.second)) { + absl::StrAppend(out, item.first, "=", item.second); + } else { + absl::StrAppend(out, item.first, "=\"", item.second, "\""); + } + }; + return absl::StrFormat("{%s}", + absl::StrJoin(sorted_attributes, ",", formatter)); +} + std::string HloInstruction::ToShortString() const { return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(", StrJoin(operands_, ", ", @@ -3969,9 +4089,27 @@ HloInstructionProto HloInstruction::ToProto() const { } *proto.mutable_frontend_attributes() = frontend_attributes(); + proto.set_is_composite(is_composite()); *proto.mutable_statistics_viz() = statistics_viz(); + if (original_value_) { + xla::OriginalValueProto* original_value_proto = + proto.mutable_original_value(); + for (const auto& leaf : original_value_->leaves()) { + OriginalArrayProto* original_array_proto = + original_value_proto->add_leaves(); + for (const auto& index : leaf.first) { + original_array_proto->add_leaf_shape_index(index); + } + *original_array_proto->mutable_instruction_name() = + leaf.second->instruction_name; + for (const auto& index : leaf.second->shape_index) { + original_array_proto->add_shape_index(index); + } + } + } + return proto; } @@ -5479,4 +5617,13 @@ void HloInstruction::set_output_to_operand_aliasing( std::move(aliasing)); } +std::shared_ptr HloInstruction::original_value() const { + return original_value_; +} + +void HloInstruction::set_original_value( + std::shared_ptr original_value) { + original_value_ = original_value; +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_instruction.h b/third_party/xla/xla/hlo/ir/hlo_instruction.h index 337e9ff534eb84..a98f9963b9c2d4 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instruction.h +++ b/third_party/xla/xla/hlo/ir/hlo_instruction.h @@ -25,6 +25,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -52,6 +53,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_original_value.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/ptrvec.h" #include "xla/layout.h" @@ -97,6 +99,7 @@ class HloPrintOptions { print_metadata_(true), print_metadata_only_op_name_(false), print_backend_config_(true), + sort_backend_config_(false), print_infeed_outfeed_config_(true), compact_operands_(false), include_layout_in_shapes_(true), @@ -217,6 +220,14 @@ class HloPrintOptions { return *this; } + // If true, will attempt to sort the backend config's json representation + // before printing it. If the backend config is a raw string that is not json, + // it will be printed as is, without sorting. + HloPrintOptions& set_sort_backend_config(bool value) { + sort_backend_config_ = value; + return *this; + } + // If true, infeed_config and outfeed_config will be printed. HloPrintOptions& set_print_infeed_outfeed_config(bool value) { print_infeed_outfeed_config_ = value; @@ -381,6 +392,7 @@ class HloPrintOptions { return print_metadata_only_op_name_; } bool print_backend_config() const { return print_backend_config_; } + bool sort_backend_config() const { return sort_backend_config_; } bool print_infeed_outfeed_config() const { return print_infeed_outfeed_config_; } @@ -421,6 +433,7 @@ class HloPrintOptions { bool print_metadata_; bool print_metadata_only_op_name_; bool print_backend_config_; + bool sort_backend_config_; bool print_infeed_outfeed_config_; bool compact_operands_; bool include_layout_in_shapes_; @@ -1350,6 +1363,17 @@ class HloInstruction { const Shape& shape, absl::Span operands, HloComputation* computation); + // Creates a composite call instruction that applies the given computation on + // the given operands. "shape" is the resultant shape. + static std::unique_ptr CreateCompositeCall( + const Shape& shape, HloInstruction* decomposition_root, + const std::string& name, const std::string& attributes, int64_t version); + + static std::unique_ptr CreateCompositeCall( + const Shape& shape, absl::Span operands, + HloComputation* decomposition, const std::string& name, + const std::string& attributes, int64_t version); + // Creates a custom call instruction that applies the given custom call target // to the given operands. "opaque" can be an arbitrary string with a // backend-specific interpretation. "shape" is the resultant shape. @@ -2092,6 +2116,9 @@ class HloInstruction { mutable_rare()->frontend_attributes = std::move(frontend_attributes); } + // Appends the given frontend attributes to the existing ones. If existing + // frontend attributes are empty, then create it and set it to the provided + // one. void add_frontend_attributes(FrontendAttributes frontend_attributes) { if (!frontend_attributes.map().empty()) { mutable_rare()->frontend_attributes.mutable_map()->insert( @@ -2099,10 +2126,25 @@ class HloInstruction { } } + bool has_frontend_attributes() const { + return has_rare() && !rare()->frontend_attributes.map().empty(); + } + const FrontendAttributes& frontend_attributes() const { return rare()->frontend_attributes; } + void set_is_composite(bool is_composite) { + if (!has_rare() && !is_composite) { + return; + } + mutable_rare()->is_composite = is_composite; + } + + // Return the is_composite attribute. This attribute is only relevant for + // kCall instructions used as a Composite op. + bool is_composite() const { return has_rare() && rare()->is_composite; } + void add_single_statistic(Statistic statistic) { *mutable_rare()->statistics_viz.add_statistics() = std::move(statistic); } @@ -2198,8 +2240,8 @@ class HloInstruction { void set_metadata_preserve_layout(bool preserve_layout) { metadata_->set_preserve_layout(preserve_layout); } - void set_metadata_scheduling_name(const std::string& name) { - metadata_->set_scheduling_name(name); + void set_metadata_scheduling_name(absl::string_view name) { + metadata_->set_scheduling_name(std::string(name)); } const OpMetadata& metadata() const { return *metadata_; } @@ -2549,6 +2591,9 @@ class HloInstruction { HloInstruction(const HloInstruction&) = delete; HloInstruction& operator=(const HloInstruction&) = delete; + std::shared_ptr original_value() const; + void set_original_value(std::shared_ptr original_value); + protected: // Internal constructor for a given opcode/shape, other fields must be filled // by factory methods. @@ -2684,6 +2729,9 @@ class HloInstruction { // z' = const(20), frontend_attributes={?} FrontendAttributes frontend_attributes; + // Used by kCall to determine if the Call instruction is a composite. + bool is_composite; + // Used to render an HLO graph when tracking the propagation desired values // through it. StatisticsViz statistics_viz; @@ -2799,6 +2847,10 @@ class HloInstruction { // String identifier for instruction. std::string name_; + // Original value this instruction corresponds to in the unoptimized HLO + // graph. + std::shared_ptr original_value_ = nullptr; + // Metadata for debugging. Allocate it on heap, so that it does not increase // the memory footprint of HloInstruction. std::unique_ptr metadata_ = std::make_unique(); @@ -2819,6 +2871,14 @@ absl::StatusOr StringToFusionKind( // Custom (de)stringification functions for protos that live inside // HloInstruction. std::string PaddingConfigToString(const PaddingConfig& padding); + +// Returns string representation of frontend attributes. +// Frontend attribute is a list of attribute= pairs where value is either +// a "string" or a JSON-like dict surrounded in {}. Similar to custom_call +// backend config, this can be used to store stringified MLIR-dictionaries with +// pretty printing. +std::string FrontendAttributesToString( + const FrontendAttributes& frontend_attributes); std::string StatisticsVizToString(const StatisticsViz& statistics_viz); std::string RandomAlgorithmToString(const RandomAlgorithm& algorithm); std::string RandomDistributionToString(const RandomDistribution& distribution); diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.cc b/third_party/xla/xla/hlo/ir/hlo_instructions.cc index d9801cf3681963..cff0907ba534d5 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.cc +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.cc @@ -1877,6 +1877,35 @@ HloCallableInstruction::HloCallableInstruction( } } +HloCallableInstruction::HloCallableInstruction(HloOpcode opcode, + const Shape& shape, + const std::string& name, + const std::string& attributes, + int64_t version) + : HloInstruction(opcode, shape) { + auto frontend_attributes = + BuildFrontendAttributesForComposite(name, attributes, version); + add_frontend_attributes(frontend_attributes); + set_is_composite(true); +} + +HloCallableInstruction::HloCallableInstruction( + HloOpcode opcode, const Shape& shape, + absl::Span operands, HloComputation* decomposition, + const std::string& name, const std::string& attributes, int64_t version) + : HloInstruction(opcode, shape) { + for (auto operand : operands) { + AppendOperand(operand); + } + SetAndSanitizeName(HloOpcodeString(opcode)); + AppendComputation(decomposition); + + auto frontend_attributes = + BuildFrontendAttributesForComposite(name, attributes, version); + add_frontend_attributes(frontend_attributes); + set_is_composite(true); +} + HloCallableInstruction::~HloCallableInstruction() { ClearCalledComputations(); } HloComputation* HloCallableInstruction::called_computation() const { @@ -1924,7 +1953,7 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation( return u->opcode() == HloOpcode::kGetTupleElement; }); if (called_computations().empty()) { - // New fusion instruction. It should not be a multioutput instruction. + // New fusion instruction. It should not be a multi-output instruction. CHECK(!add_output); auto builder = HloComputation::Builder(default_called_computation_name()); builder.AddInstruction(instruction_to_append->Clone(/*suffix=*/"")); @@ -2552,6 +2581,47 @@ HloCallInstruction::HloCallInstruction( : HloCallableInstruction(HloOpcode::kCall, shape, operands, called_computation) {} +HloCallInstruction::HloCallInstruction(const Shape& shape, + HloInstruction* decomposition_root, + const std::string& name, + const std::string& attributes, + int64_t version) + : HloCallableInstruction(HloOpcode::kCall, shape, name, attributes, + version) { + CHECK(decomposition_root != nullptr); + SetAndSanitizeName(HloOpcodeString(opcode())); + + FrontendAttributes frontend_attributes; + frontend_attributes.mutable_map()->insert({"composite.name", name}); + frontend_attributes.mutable_map()->insert( + {"composite.attributes", attributes}); + frontend_attributes.mutable_map()->insert( + {"composite.version", std::to_string(version)}); + + add_frontend_attributes(frontend_attributes); + set_is_composite(true); + set_parent(decomposition_root->parent()); + set_metadata(decomposition_root->metadata()); + CloneAndAppendInstructionIntoCalledComputation(decomposition_root); +} + +HloCallInstruction::HloCallInstruction( + const Shape& shape, absl::Span operands, + HloComputation* decomposition, const std::string& name, + const std::string& attributes, int64_t version) + : HloCallableInstruction(HloOpcode::kCall, shape, operands, decomposition, + name, attributes, version) { + FrontendAttributes frontend_attributes; + frontend_attributes.mutable_map()->insert({"composite.name", name}); + frontend_attributes.mutable_map()->insert( + {"composite.attributes", attributes}); + frontend_attributes.mutable_map()->insert( + {"composite.version", std::to_string(version)}); + + add_frontend_attributes(frontend_attributes); + set_is_composite(true); +} + HloRngInstruction::HloRngInstruction( const Shape& shape, RandomDistribution distribution, absl::Span parameters) @@ -3500,13 +3570,23 @@ HloGatherInstruction::HloGatherInstruction( AppendJoin(printer, dim_numbers.collapsed_slice_dims(), ","); printer->Append("}, start_index_map={"); AppendJoin(printer, dim_numbers.start_index_map(), ","); + if (dim_numbers.operand_batching_dims_size()) { + printer->Append("}, operand_batching_dims={"); + AppendJoin(printer, dim_numbers.operand_batching_dims(), ","); + } + if (dim_numbers.start_indices_batching_dims_size()) { + printer->Append("}, start_indices_batching_dims={"); + AppendJoin(printer, dim_numbers.start_indices_batching_dims(), ","); + } AppendCat(printer, "}, index_vector_dim=", dim_numbers.index_vector_dim()); } /* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( absl::Span offset_dims, absl::Span collapsed_slice_dims, - absl::Span start_index_map, int64_t index_vector_dim) { + absl::Span start_index_map, int64_t index_vector_dim, + absl::Span operand_batching_dims, + absl::Span start_indices_batching_dims) { GatherDimensionNumbers gather_dim_numbers; for (int64_t output_window_dim : offset_dims) { gather_dim_numbers.add_offset_dims(output_window_dim); @@ -3517,6 +3597,13 @@ HloGatherInstruction::HloGatherInstruction( for (int64_t gather_dim_to_input_dim : start_index_map) { gather_dim_numbers.add_start_index_map(gather_dim_to_input_dim); } + for (int64_t operand_batching_dim : operand_batching_dims) { + gather_dim_numbers.add_operand_batching_dims(operand_batching_dim); + } + for (int64_t start_indices_batching_dim : start_indices_batching_dims) { + gather_dim_numbers.add_start_indices_batching_dims( + start_indices_batching_dim); + } gather_dim_numbers.set_index_vector_dim(index_vector_dim); return gather_dim_numbers; @@ -3601,6 +3688,14 @@ HloScatterInstruction::HloScatterInstruction( AppendJoin(printer, dim_numbers.inserted_window_dims(), ","); printer->Append("}, scatter_dims_to_operand_dims={"); AppendJoin(printer, dim_numbers.scatter_dims_to_operand_dims(), ","); + if (dim_numbers.input_batching_dims_size()) { + printer->Append("}, input_batching_dims={"); + AppendJoin(printer, dim_numbers.input_batching_dims(), ","); + } + if (dim_numbers.scatter_indices_batching_dims_size()) { + printer->Append("}, scatter_indices_batching_dims={"); + AppendJoin(printer, dim_numbers.scatter_indices_batching_dims(), ","); + } AppendCat(printer, "}, index_vector_dim=", dim_numbers.index_vector_dim()); } @@ -3609,7 +3704,8 @@ HloScatterInstruction::MakeScatterDimNumbers( absl::Span update_window_dims, absl::Span inserted_window_dims, absl::Span scatter_dims_to_operand_dims, - int64_t index_vector_dim) { + int64_t index_vector_dim, absl::Span input_batching_dims, + absl::Span scatter_indices_batching_dims) { ScatterDimensionNumbers scatter_dim_numbers; for (int64_t update_window_dim : update_window_dims) { scatter_dim_numbers.add_update_window_dims(update_window_dim); @@ -3621,6 +3717,13 @@ HloScatterInstruction::MakeScatterDimNumbers( scatter_dim_numbers.add_scatter_dims_to_operand_dims( scatter_dim_to_operand_dim); } + for (int64_t input_batching_dim : input_batching_dims) { + scatter_dim_numbers.add_input_batching_dims(input_batching_dim); + } + for (int64_t scatter_indices_batching_dim : scatter_indices_batching_dims) { + scatter_dim_numbers.add_scatter_indices_batching_dims( + scatter_indices_batching_dim); + } scatter_dim_numbers.set_index_vector_dim(index_vector_dim); return scatter_dim_numbers; } diff --git a/third_party/xla/xla/hlo/ir/hlo_instructions.h b/third_party/xla/xla/hlo/ir/hlo_instructions.h index b0e337adfc5061..c0f03248dbf772 100644 --- a/third_party/xla/xla/hlo/ir/hlo_instructions.h +++ b/third_party/xla/xla/hlo/ir/hlo_instructions.h @@ -19,13 +19,13 @@ limitations under the License. #define XLA_HLO_IR_HLO_INSTRUCTIONS_H_ #include -#include #include #include #include #include #include +#include "absl/base/attributes.h" #include "absl/container/inlined_vector.h" #include "absl/functional/function_ref.h" #include "absl/status/status.h" @@ -38,7 +38,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_domain_metadata.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/iterator_util.h" #include "xla/layout.h" #include "xla/literal.h" #include "xla/printer.h" @@ -1343,6 +1342,15 @@ class HloCallableInstruction : public HloInstruction { absl::Span operands, absl::Span called_computations); + HloCallableInstruction(HloOpcode opcode, const Shape& shape, + const std::string& name, const std::string& attributes, + int64_t version); + + HloCallableInstruction(HloOpcode opcode, const Shape& shape, + absl::Span operands, + HloComputation* decomposition, const std::string& name, + const std::string& attributes, int64_t version); + ~HloCallableInstruction() override; // Adds a new operand to the callable instruction. @@ -1402,6 +1410,21 @@ class HloCallableInstruction : public HloInstruction { output_to_operand_aliasing_ = std::move(aliasing); } + FrontendAttributes BuildFrontendAttributesForComposite( + const std::string& name, + std::optional attributes = std::nullopt, + std::optional version = std::nullopt) { + FrontendAttributes frontend_attributes; + frontend_attributes.mutable_map()->insert({"composite.name", name}); + frontend_attributes.mutable_map()->insert( + {"composite.attributes", + attributes.has_value() ? std::string(*attributes) : "{}"}); + frontend_attributes.mutable_map()->insert( + {"composite.version", + version.has_value() ? std::to_string(*version) : "0"}); + return frontend_attributes; + } + protected: // Returns the default called computation name. virtual std::string default_called_computation_name() const = 0; @@ -1450,7 +1473,7 @@ class HloFusionInstruction : public HloCallableInstruction { void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge); // Merges the fused instructions from instruction_to_merge into the fused - // instruction set of 'this' and generates multioutput fusion instructions. + // instruction set of 'this' and generates multi-output fusion instructions. // All the users of instruction_to_merge will be redirected to 'this' // instruction. instruction_to_merge will be removed from its parent // computation. @@ -1555,6 +1578,15 @@ class HloCallInstruction : public HloCallableInstruction { absl::Span operands, HloComputation* called_computation); + HloCallInstruction(const Shape& shape, HloInstruction* decomposition_root, + const std::string& name, const std::string& attributes, + int64_t version); + + HloCallInstruction(const Shape& shape, + absl::Span operands, + HloComputation* decomposition, const std::string& name, + const std::string& attributes, int64_t version); + static bool ClassOf(const HloInstruction* hlo) { return hlo->opcode() == HloOpcode::kCall; } @@ -2313,7 +2345,9 @@ class HloGatherInstruction : public HloInstruction { static GatherDimensionNumbers MakeGatherDimNumbers( absl::Span offset_dims, absl::Span collapsed_slice_dims, - absl::Span start_index_map, int64_t index_vector_dim); + absl::Span start_index_map, int64_t index_vector_dim, + absl::Span operand_batching_dims = {}, + absl::Span start_indices_batching_dims = {}); // Returns the dump string of the given gather dimension numbers. static std::string GatherDimensionNumbersToString( const GatherDimensionNumbers& dim_numbers); @@ -2378,7 +2412,9 @@ class HloScatterInstruction : public HloInstruction { absl::Span update_window_dims, absl::Span inserted_window_dims, absl::Span scatter_dims_to_operand_dims, - int64_t index_vector_dim); + int64_t index_vector_dim, + absl::Span input_batching_dims = {}, + absl::Span scatter_indices_batching_dims = {}); // Returns the dump string of the given scatter dimension numbers. static std::string ScatterDimensionNumbersToString( const ScatterDimensionNumbers& dim_numbers); diff --git a/third_party/xla/xla/hlo/ir/hlo_module.cc b/third_party/xla/xla/hlo/ir/hlo_module.cc index 0711d49ef63e16..cc8dda9a321ee5 100644 --- a/third_party/xla/xla/hlo/ir/hlo_module.cc +++ b/third_party/xla/xla/hlo/ir/hlo_module.cc @@ -17,8 +17,8 @@ limitations under the License. #include #include +#include #include -#include #include #include #include @@ -30,24 +30,36 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/cord.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_frontend_attributes.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" +#include "xla/hlo/ir/hlo_sharding.h" #include "xla/map_util.h" #include "xla/printer.h" #include "xla/service/compilation_environments.h" +#include "xla/service/computation_layout.h" #include "xla/service/computation_placer.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/mapped_ptr_container_sorter.h" +#include "xla/service/name_uniquer.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/gtl/map_util.h" +#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" @@ -405,8 +417,8 @@ void HloModule::Print(Printer* printer, const HloPrintOptions& options) const { ? MakeComputationSorted() : MakeComputationPostOrder(); for (const HloComputation* computation : computations) { - // Don't print async computations when the sytax sugar is enabled since that - // is redundant information. + // Don't print async computations when the syntax sugar is enabled since + // that is redundant information. if (options.syntax_sugar_async_ops() && computation->IsAsyncComputation() && computation->CanExpandIntoSingleInstruction()) { continue; @@ -848,7 +860,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( outlined_instruction); // Mark instruction_to_outline an output if it is used outside the - // subcomputation or is the output of the original computation (i.e. used + // sub-computation or is the output of the original computation (i.e. used // externally). if (instruction_to_outline->user_count() == 0 || IsUsedOutsideSubcomputation(*instruction_to_outline, @@ -917,7 +929,7 @@ std::vector HloModule::MakeComputationPostOrder( if (computations_.empty()) { return {}; } - // First determine all root computations by building a set of nonroot + // First determine all root computations by building a set of non-root // computations (computations which are called by an instruction in the // module). absl::flat_hash_set nonroot_computations; diff --git a/third_party/xla/xla/hlo/ir/hlo_original_value.cc b/third_party/xla/xla/hlo/ir/hlo_original_value.cc new file mode 100644 index 00000000000000..789978d74cbf39 --- /dev/null +++ b/third_party/xla/xla/hlo/ir/hlo_original_value.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/hlo/ir/hlo_original_value.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "xla/shape.h" +#include "xla/shape_util.h" + +namespace xla { + +std::string OriginalValueToStringHelper(const OriginalValue& original_value, + const Shape& shape, + std::vector& shape_index) { + std::string result; + if (shape.IsTuple()) { + if (shape.tuple_shapes().empty()) { + return "()"; + } + absl::StrAppend(&result, "("); + shape_index.push_back(0); + absl::StrAppend(&result, + OriginalValueToStringHelper( + original_value, shape.tuple_shapes(0), shape_index)); + shape_index.pop_back(); + for (int64_t i = 1; i < shape.tuple_shapes().size(); ++i) { + absl::StrAppend(&result, ", "); + shape_index.push_back(i); + absl::StrAppend(&result, + OriginalValueToStringHelper( + original_value, shape.tuple_shapes(i), shape_index)); + shape_index.pop_back(); + } + absl::StrAppend(&result, ")"); + return result; + } + + const auto& leaf = original_value.element(shape_index); + absl::StrAppend( + &result, "{", "\"", leaf->instruction_name, "\"", + (leaf->shape_index.empty() ? "" : " " + leaf->shape_index.ToString()), + "}"); + return result; +} + +std::string OriginalValueToString(const OriginalValue& original_value) { + std::vector shape_index; + return OriginalValueToStringHelper(original_value, original_value.shape(), + shape_index); +} +} // namespace xla diff --git a/third_party/xla/xla/hlo/ir/hlo_original_value.h b/third_party/xla/xla/hlo/ir/hlo_original_value.h new file mode 100644 index 00000000000000..a77bc8a13460c7 --- /dev/null +++ b/third_party/xla/xla/hlo/ir/hlo_original_value.h @@ -0,0 +1,37 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_HLO_IR_HLO_ORIGINAL_VALUE_H_ +#define XLA_HLO_IR_HLO_ORIGINAL_VALUE_H_ + +#include +#include + +#include "xla/shape_tree.h" +#include "xla/shape_util.h" + +namespace xla { +// Stores information of original values. +struct OriginalArray { + std::string instruction_name; + ShapeIndex shape_index; +}; + +using OriginalValue = ShapeTree>; + +std::string OriginalValueToString(const OriginalValue& original_value); +} // namespace xla + +#endif // XLA_HLO_IR_HLO_ORIGINAL_VALUE_H_ diff --git a/third_party/xla/xla/hlo/ir/hlo_sharding.h b/third_party/xla/xla/hlo/ir/hlo_sharding.h index a15d3b33e4c44f..5a7c49e9265899 100644 --- a/third_party/xla/xla/hlo/ir/hlo_sharding.h +++ b/third_party/xla/xla/hlo/ir/hlo_sharding.h @@ -138,6 +138,11 @@ class HloSharding { static HloSharding Tuple(const Shape& tuple_shape, absl::Span shardings); + // Creates a new sharding for a flat tuple type. + static HloSharding FlatTuple(std::vector sub_shardings) { + return HloSharding(std::move(sub_shardings)); + } + // Creates a new sharding for a tuple type, with a single input sharding // repeated on each leaf. static HloSharding SingleTuple(const Shape& tuple_shape, diff --git a/third_party/xla/xla/hlo/transforms/BUILD b/third_party/xla/xla/hlo/transforms/BUILD index 93ffe889d7c187..a3aae57cce70d4 100644 --- a/third_party/xla/xla/hlo/transforms/BUILD +++ b/third_party/xla/xla/hlo/transforms/BUILD @@ -49,7 +49,7 @@ xla_cc_test( "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/hlo/transforms/hlo_constant_splitter_test.cc b/third_party/xla/xla/hlo/transforms/hlo_constant_splitter_test.cc index 58a25ef26d0aac..c7ebf8459502e8 100644 --- a/third_party/xla/xla/hlo/transforms/hlo_constant_splitter_test.cc +++ b/third_party/xla/xla/hlo/transforms/hlo_constant_splitter_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/utils/BUILD b/third_party/xla/xla/hlo/utils/BUILD index 63db0ae7d5ef6e..8f20f63bfc9b2a 100644 --- a/third_party/xla/xla/hlo/utils/BUILD +++ b/third_party/xla/xla/hlo/utils/BUILD @@ -54,8 +54,8 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -158,6 +158,8 @@ cc_library( "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", ], ) diff --git a/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc b/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc index b4155b103cfc11..64e4ab5ee37d62 100644 --- a/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_live_range_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_value.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/hlo/utils/hlo_query.cc b/third_party/xla/xla/hlo/utils/hlo_query.cc index 85e41fff68a149..147f54822aef97 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query.cc +++ b/third_party/xla/xla/hlo/utils/hlo_query.cc @@ -17,8 +17,10 @@ limitations under the License. #include #include +#include #include "absl/algorithm/container.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -268,5 +270,46 @@ HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, return gte; } +HloComputation* FindComputation(HloModule* module, absl::string_view name) { + auto computations = module->computations(); + auto it = absl::c_find_if( + computations, [&](HloComputation* c) { return c->name() == name; }); + if (it == computations.end()) { + return nullptr; + } + return *it; +} + +std::pair FindFirstInstruction( + const HloComputation* computation, absl::string_view name) { + int current_index = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->name() == name) { + return {instruction, current_index}; + break; + } + current_index++; + } + return {nullptr, -1}; +} + +std::pair FindFirstInstruction( + const HloComputation* computation, HloOpcode opcode) { + int current_index = 0; + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == opcode) { + return {instruction, current_index}; + break; + } + current_index++; + } + return {nullptr, -1}; +} + +bool IsBeforeInComputation(const HloComputation* computation, + absl::string_view inst1, absl::string_view inst2) { + return FindFirstInstruction(computation, inst1).second < + FindFirstInstruction(computation, inst2).second; +} } // namespace hlo_query } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_query.h b/third_party/xla/xla/hlo/utils/hlo_query.h index cda265362d452b..ec5c0b25804d10 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query.h +++ b/third_party/xla/xla/hlo/utils/hlo_query.h @@ -17,8 +17,10 @@ limitations under the License. #define XLA_HLO_UTILS_HLO_QUERY_H_ #include +#include #include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" @@ -152,6 +154,25 @@ bool HasX64TransformedHostTransfer(const HloModule& module); HloInstruction* GetUniqueGteInstruction(const HloInstruction* operand, int64_t index); +// Gets the computation from the given module with the given name. +HloComputation* FindComputation(HloModule* module, absl::string_view name); +// Gets the first instruction and its index from the given computation with the +// given instruction name. The function returns {nullptr, -1} if the instruction +// cannot be found. +std::pair FindFirstInstruction( + const HloComputation* computation, absl::string_view name); +// Gets the first instruction and its index from the given computation with the +// given instruction opcode. The function returns {nullptr, -1} if the +// instruction cannot be found. +std::pair FindFirstInstruction( + const HloComputation* computation, HloOpcode opcode); + +// Check that one instruction comes before another one for a given computation. +// The function returns true if the first instruction comes before the second +// one, and false otherwise. This is useful for partial checks on the +// transformed IR without going through a full file check. +bool IsBeforeInComputation(const HloComputation* computation, + absl::string_view inst1, absl::string_view inst2); } // namespace hlo_query } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_query_test.cc b/third_party/xla/xla/hlo/utils/hlo_query_test.cc index acefa21aa9e2f4..e4dad1007fa685 100644 --- a/third_party/xla/xla/hlo/utils/hlo_query_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_query_test.cc @@ -40,6 +40,14 @@ int CountInstructions(Hlo& module, HloOpcode opcode) { return counter; } +constexpr absl::string_view kConstantAdditionHloString = R"( +HloModule test +ENTRY main { + zero = f32[] constant(0) + five = f32[] constant(5) + ROOT out = f32[] add(zero, five) +})"; + TEST_F(HloQueryTest, GetInstructionWithOpCodeReturnsMatchingInstructionForModule) { constexpr absl::string_view kHloString = R"( @@ -132,5 +140,66 @@ TEST_F(HloQueryTest, GetUniqueGteTest) { EXPECT_EQ(gte2, nullptr); } +TEST_F(HloQueryTest, FindComputationTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + EXPECT_NE(hlo_query::FindComputation(module.get(), "main"), nullptr); + EXPECT_EQ(hlo_query::FindComputation(module.get(), "foo"), nullptr); +} + +TEST_F(HloQueryTest, FindInstructionUsingNameTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); + EXPECT_NE(hlo_query::FindFirstInstruction(main, "zero").first, nullptr); + EXPECT_NE(hlo_query::FindFirstInstruction(main, "five").first, nullptr); + EXPECT_NE(hlo_query::FindFirstInstruction(main, "out").first, nullptr); + EXPECT_EQ(hlo_query::FindFirstInstruction(main, "foo").first, nullptr); +} + +TEST_F(HloQueryTest, FindInstructionUsingOpcodeTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); + EXPECT_NE( + hlo_query::FindFirstInstruction(main, StringToHloOpcode("add").value()) + .first, + nullptr); + EXPECT_NE(hlo_query::FindFirstInstruction( + main, StringToHloOpcode("constant").value()) + .first, + nullptr); + EXPECT_EQ( + hlo_query::FindFirstInstruction(main, StringToHloOpcode("select").value()) + .first, + nullptr); +} + +TEST_F(HloQueryTest, FindInstructionDoesNotExistTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); + EXPECT_NE(main, nullptr); + auto find_beef = hlo_query::FindFirstInstruction(main, "deadbeef"); + auto find_nothing = hlo_query::FindFirstInstruction(main, ""); + EXPECT_EQ(find_beef.first, nullptr); + EXPECT_EQ(find_beef.second, -1); + EXPECT_EQ(find_nothing.first, nullptr); + EXPECT_EQ(find_nothing.second, -1); +} + +TEST_F(HloQueryTest, IsBeforeInComputationTest) { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnUnverifiedModule(kConstantAdditionHloString)); + const HloComputation* main = hlo_query::FindComputation(module.get(), "main"); + EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "zero", "five")); + EXPECT_TRUE(hlo_query::IsBeforeInComputation(main, "five", "out")); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc index 20213b587ec37f..d6fe4946fbf237 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.cc @@ -786,8 +786,11 @@ std::optional ReshapeSharding(const Shape& source_shape, sharding_tile_dims_stack.pop_back(); } - if (s_partitions > 1 && s_size % s_partitions == 0 && - t_size % s_partitions == 0) { + if (s_size == t_size) { + // Same dimension. + append_sharding_dim(s_partitions); + } else if (s_partitions > 1 && s_size % s_partitions == 0 && + t_size % s_partitions == 0) { // If s_partitions evenly divides both s_size and t_size, we can add this // sharding dim and work on shard sized shapes in the next iteration. source_dims_stack.push_back(s_size / s_partitions); @@ -795,9 +798,6 @@ std::optional ReshapeSharding(const Shape& source_shape, sharding_tile_dims_stack.push_back(1); append_sharding_dim(s_partitions); inplace_add_sharding_dim = true; - } else if (s_size == t_size) { - // Same dimension. - append_sharding_dim(s_partitions); } else if (t_size == 1) { // Trivial dimension added. append_sharding_dim(1); @@ -2118,7 +2118,7 @@ std::optional TransposeShardingWithCollapsedDims( << "Sharding transpose should not move subgroup dims before data dims."; perm[src_to_tgt[i] - skipped_tgt_dims + skipped_src_dims] = i; } - auto tgt_sharding = hlo_sharding_util::TransposeSharding(source, perm); + auto tgt_sharding = TransposeSharding(source, perm); DimensionVector tgt_tiles(tgt_to_src.size(), 1); for (int64_t i = 0; i < tgt_tiles.size(); ++i) { if (tgt_to_src[i] >= 0) { @@ -2247,7 +2247,6 @@ std::optional GetGatherScatterBatchParallelDims( // %indices = concatenate(..., %iota.1, ...) // ... = gather(..., %indices) // is common for tf.reverse_sequence and would match this case. - absl::InlinedVector iotas; const int num_indices = index_map.size(); std::vector index_parallel_in_dim(num_indices, -1); @@ -2508,8 +2507,8 @@ HloSharding InferGatherScatterParallelShardingFromOperandSharding( operand_sharding.tile_assignment().dim(operand_idx); } HloSharding replicate_non_parallel_dims = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - operand_sharding, operand_non_parallel_dims); + PartiallyReplicateTiledShardingOnDims(operand_sharding, + operand_non_parallel_dims); if (replicate_non_parallel_dims.IsTileMaximal()) { return replicate_non_parallel_dims; } @@ -2733,21 +2732,21 @@ GroupedSharding GroupShardingOnReplicatedDim( // 2. Try borrow dimensions from replicable_dims in order, and group sharding. if (sharding.IsTiled()) { - int64_t max_replicable_dimensions = + const int64_t reps_on_last_tile_dim = sharding.ReplicateOnLastTileDim() ? sharding.tile_assignment().dimensions().back() : 1; - max_replicable_dimensions = absl::c_accumulate( - replicable_dims, max_replicable_dimensions, + + const int64_t max_replicable_dimensions = absl::c_accumulate( + replicable_dims, reps_on_last_tile_dim, [&](int64_t product, int64_t dim) { return product * sharding.tile_assignment().dim(dim); }); - if (max_replicable_dimensions % num_groups == 0) { + + if (max_replicable_dimensions % num_groups == 0 && + num_groups % reps_on_last_tile_dim == 0) { auto tile_assignment = [&]() -> std::optional { - int dimensions_to_borrow = - num_groups / (sharding.ReplicateOnLastTileDim() - ? sharding.tile_assignment().dimensions().back() - : 1); + int dimensions_to_borrow = num_groups / reps_on_last_tile_dim; DimensionVector tile_dims( sharding.tile_assignment().dimensions().begin(), sharding.tile_assignment().dimensions().end()); @@ -3336,15 +3335,13 @@ std::optional ReturnImprovedShardingImpl( return std::nullopt; } int64_t sharding_tiles = from.NumTiles(); - if (hlo_sharding_util::MergeSharding(*to_improved, &from, - may_combine_partial_sharding)) { + if (MergeSharding(*to_improved, &from, may_combine_partial_sharding)) { // Override existing tiled sharding only when the new sharding is compatible // with the existing one. This avoids unexpected resharding when `sharding` // just has more tiles than existing sharding but they are not mergeable. if (!allow_aggressive_resharding && to_improved_shape.IsArray() && !to_improved->IsTileMaximal() && from.NumTiles() == sharding_tiles) { - if (!hlo_sharding_util::IsSubTilingOrEqualSharding(to_improved_shape, - from, *to_improved)) { + if (!IsSubTilingOrEqualSharding(to_improved_shape, from, *to_improved)) { VLOG(10) << "Not merging because of different device distribution"; VLOG(10) << "Instr sharding: " << to_improved->ToString(); VLOG(10) << "New sharding " << from.ToString(); @@ -3357,16 +3354,13 @@ std::optional ReturnImprovedShardingImpl( } HloSharding InferDotOperandSharding( - const HloInstruction* dot, int64_t operand_index, + const HloSharding* dot_sharding, const HloSharding* other_operand_sharding, + int64_t operand_index, const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, bool consider_other_operand, bool may_combine_partial_sharding) { - CHECK(dot->opcode() == HloOpcode::kDot || - dot->opcode() == HloOpcode::kConvolution); CHECK(operand_index == 0 || operand_index == 1); CHECK(dnums.conv_spatial_dims.empty()); - auto operand = dot->operand(operand_index); - auto other = dot->operand(1 - operand_index); std::vector output_dims_to_replicate; std::vector other_operand_dims_to_replicate; for (const auto& dim : operand_index == 0 ? dnums.rhs_non_contracting_dims @@ -3391,33 +3385,47 @@ HloSharding InferDotOperandSharding( other_operand_dims_to_replicate.push_back(other_dim); } } - HloSharding output_other_dims_replicated = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - dot->sharding(), output_dims_to_replicate); - std::vector output_to_operand_dims(dot->shape().rank(), -1); - std::vector operand_to_output_dims(operand->shape().rank(), -1); - for (const auto& dim : dnums.batch_dims) { - output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; - operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; - } - for (const auto& dim : operand_index == 0 ? dnums.lhs_non_contracting_dims - : dnums.rhs_non_contracting_dims) { - output_to_operand_dims[dim.output] = operand_index == 0 ? dim.lhs : dim.rhs; - operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = dim.output; + int64_t operand_shape_rank = + operand_index == 0 ? dnums.lhs_shape_rank : dnums.rhs_shape_rank; + int64_t other_shape_rank = + operand_index == 0 ? dnums.rhs_shape_rank : dnums.lhs_shape_rank; + + HloSharding sharding = HloSharding::Replicate(); + + if (dot_sharding != nullptr) { + HloSharding output_other_dims_replicated = + PartiallyReplicateTiledShardingOnDims(*dot_sharding, + output_dims_to_replicate); + + std::vector output_to_operand_dims(dnums.output_shape_rank, -1); + std::vector operand_to_output_dims(operand_shape_rank, -1); + for (const auto& dim : dnums.batch_dims) { + output_to_operand_dims[dim.output] = + operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = + dim.output; + } + for (const auto& dim : operand_index == 0 + ? dnums.lhs_non_contracting_dims + : dnums.rhs_non_contracting_dims) { + output_to_operand_dims[dim.output] = + operand_index == 0 ? dim.lhs : dim.rhs; + operand_to_output_dims[operand_index == 0 ? dim.lhs : dim.rhs] = + dim.output; + } + sharding = std::move(*TransposeShardingWithCollapsedDims( + output_other_dims_replicated, output_to_operand_dims, + operand_to_output_dims)); } - auto sharding = *hlo_sharding_util::TransposeShardingWithCollapsedDims( - output_other_dims_replicated, output_to_operand_dims, - operand_to_output_dims); - if (consider_other_operand && - hlo_sharding_util::IsSpatiallyPartitioned(other)) { - auto other_operand_dims_replicated = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - other->sharding(), other_operand_dims_to_replicate); + if (consider_other_operand && other_operand_sharding != nullptr && + IsSpatiallyPartitioned(*other_operand_sharding)) { + auto other_operand_dims_replicated = PartiallyReplicateTiledShardingOnDims( + *other_operand_sharding, other_operand_dims_to_replicate); - std::vector other_to_operand_dims(other->shape().rank(), -1); - std::vector operand_to_other_dims(operand->shape().rank(), -1); + std::vector other_to_operand_dims(other_shape_rank, -1); + std::vector operand_to_other_dims(operand_shape_rank, -1); for (const auto& dim : dnums.batch_dims) { other_to_operand_dims[operand_index == 0 ? dim.rhs : dim.lhs] = operand_index == 0 ? dim.lhs : dim.rhs; @@ -3430,12 +3438,11 @@ HloSharding InferDotOperandSharding( operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = operand_index == 0 ? dim.rhs : dim.lhs; } - HloSharding sharding_from_other = - *hlo_sharding_util::TransposeShardingWithCollapsedDims( - other_operand_dims_replicated, other_to_operand_dims, - operand_to_other_dims); - if (hlo_sharding_util::MergeSharding(sharding, &sharding_from_other, - may_combine_partial_sharding)) { + HloSharding sharding_from_other = *TransposeShardingWithCollapsedDims( + other_operand_dims_replicated, other_to_operand_dims, + operand_to_other_dims); + if (MergeSharding(sharding, &sharding_from_other, + may_combine_partial_sharding)) { sharding = std::move(sharding_from_other); } } @@ -3443,5 +3450,20 @@ HloSharding InferDotOperandSharding( return sharding; } +HloSharding InferDotOperandSharding( + const HloInstruction* dot, int64_t operand_index, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, + bool consider_other_operand, bool may_combine_partial_sharding) { + CHECK(dot->opcode() == HloOpcode::kDot || + dot->opcode() == HloOpcode::kConvolution); + + const HloInstruction* other_operand = dot->operand(1 - operand_index); + return InferDotOperandSharding( + dot->has_sharding() ? &dot->sharding() : nullptr, + other_operand->has_sharding() ? &other_operand->sharding() : nullptr, + operand_index, dnums, consider_other_operand, + may_combine_partial_sharding); +} + } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h index 1df5aebf107829..335cb6b53fe46b 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util.h +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util.h @@ -539,6 +539,14 @@ HloSharding InferDotOperandSharding( const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, bool consider_other_operand, bool may_combine_partial_sharding); +// Same as above, but takes the sharding of the dot and the other operand as +// input. +HloSharding InferDotOperandSharding( + const HloSharding* dot_sharding, const HloSharding* other_operand_sharding, + int64_t operand_index, + const dot_as_convolution_util::DotConvolutionDimsInfo& dnums, + bool consider_other_operand, bool may_combine_partial_sharding); + } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc index 9015726ffd6c86..fcbc4a4cd4bbdf 100644 --- a/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc +++ b/third_party/xla/xla/hlo/utils/hlo_sharding_util_test.cc @@ -295,6 +295,18 @@ TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne3) { EXPECT_EQ(result.value(), output_sharding); } +TEST(HloShardingUtilTest, ReshapeShardingSuffixShapeSizeOne4) { + Shape input_shape = ShapeUtil::MakeShape(F32, {4, 2, 1}); + Shape output_shape = ShapeUtil::MakeShape(F32, {4, 2}); + HloSharding input_sharding = HloSharding::IotaTile({4, 2, 4}); + HloSharding output_sharding = + HloSharding::PartialTile(TileAssignment({4, 2, 4})); + std::optional result = + ReshapeSharding(input_shape, output_shape, input_sharding); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), output_sharding); +} + TEST(HloShardingUtilTest, ReshapeShardingPrefixShapeSizeOne1) { Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 64}); Shape output_shape = ShapeUtil::MakeShape(F32, {1, 64}); @@ -1017,7 +1029,7 @@ TEST(HloShardingUtilTest, UntileShape) { using HloShardingUtilTestWithHlo = HloTestBase; -TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest) { +TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest1) { absl::string_view hlo_string = R"( HloModule module @@ -1061,6 +1073,55 @@ TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest) { } } +TEST_F(HloShardingUtilTestWithHlo, InferDotOperandShardingTest2) { + absl::string_view hlo_string = R"( + HloModule module + + ENTRY %main.7 { + %p0 = bf16[32,64,128,512] parameter(0), sharding={devices=[8,1,1,4]<=[32]} + %p1 = bf16[32,64,256,512] parameter(1), sharding={devices=[1,1,1,2,16]<=[8,2,2]T(1,0,2) last_tile_dim_replicate} + ROOT %dot.3 = bf16[32,64,128,256] dot(%p0, %p1), lhs_batch_dims={0,1}, rhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_contracting_dims={3}, sharding={devices=[2,2,2,2,2]<=[32] last_tile_dim_replicate} + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + const HloInstruction* dot = module->entry_computation()->root_instruction(); + auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(dot); + + const HloSharding& lhs_sharding = dot->operand(0)->sharding(); + const HloSharding& rhs_sharding = dot->operand(1)->sharding(); + const HloSharding& dot_sharding = dot->sharding(); + + bool may_combine_partial_sharding = true; + for (int64_t i = 0; i < 2; ++i) { + EXPECT_EQ(InferDotOperandSharding(nullptr, nullptr, i, dnums, true, + may_combine_partial_sharding), + HloSharding::Replicate()); + } + + // If the other_operand_sharding is missing (nullptr), we only infer the + // result from the result. + for (int64_t i = 0; i < 2; ++i) { + EXPECT_EQ(InferDotOperandSharding(&dot_sharding, nullptr, i, dnums, true, + may_combine_partial_sharding), + InferDotOperandSharding(dot, i, dnums, false, + may_combine_partial_sharding)); + } + + EXPECT_EQ(InferDotOperandSharding(nullptr, &rhs_sharding, 0, dnums, true, + may_combine_partial_sharding), + rhs_sharding); + EXPECT_EQ(InferDotOperandSharding(nullptr, &lhs_sharding, 1, dnums, true, + may_combine_partial_sharding), + lhs_sharding); + + EXPECT_EQ(InferDotOperandSharding(nullptr, &rhs_sharding, 0, dnums, false, + may_combine_partial_sharding), + HloSharding::Replicate()); + EXPECT_EQ(InferDotOperandSharding(nullptr, &lhs_sharding, 1, dnums, false, + may_combine_partial_sharding), + HloSharding::Replicate()); +} + } // namespace } // namespace hlo_sharding_util } // namespace xla diff --git a/third_party/xla/xla/lit.bzl b/third_party/xla/xla/lit.bzl index d6ec58096671f3..bbee57e4246e46 100644 --- a/third_party/xla/xla/lit.bzl +++ b/third_party/xla/xla/lit.bzl @@ -1,7 +1,7 @@ """Helper rules for writing LIT tests.""" load("@bazel_skylib//lib:paths.bzl", "paths") -load("//xla/tsl:tsl.bzl", "if_oss") +load("//xla/tsl:tsl.bzl", "if_hermetic_cuda_tools", "if_oss") def enforce_glob(files, **kwargs): """A utility to enforce that a list matches a glob expression. @@ -50,6 +50,7 @@ def lit_test_suite( timeout = None, default_tags = None, tags_override = None, + hermetic_cuda_data_dir = None, **kwargs): """Creates one lit test per source file and a test suite that bundles them. @@ -74,6 +75,8 @@ def lit_test_suite( timeout: timeout argument passed to the individual tests. default_tags: string list. Tags applied to all tests. tags_override: string_dict. Tags applied in addition to only select tests. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. **kwargs: additional keyword arguments to pass to all generated rules. See https://llvm.org/docs/CommandGuide/lit.html for details on lit @@ -105,6 +108,7 @@ def lit_test_suite( env = env, timeout = timeout, tags = default_tags + tags_override.get(test_file, []), + hermetic_cuda_data_dir = hermetic_cuda_data_dir, **kwargs ) @@ -114,6 +118,23 @@ def lit_test_suite( **kwargs ) +def lit_script_with_xla_gpu_cuda_data_dir( + name, + input_file, + output_file, + xla_gpu_cuda_data_dir): + """Adds a line to the LIT script to set the XLA_FLAGS environment variable.""" + return native.genrule( + name = name, + srcs = [input_file], + outs = [output_file], + cmd = if_hermetic_cuda_tools( + """echo -e '// RUN: export XLA_FLAGS=\"--xla_gpu_cuda_data_dir={}\"' > $@; +cat $< >> $@;""".format(xla_gpu_cuda_data_dir), + "cat $< >> $@;", + ), + ) + def lit_test( name, test_file, @@ -124,6 +145,7 @@ def lit_test( visibility = None, env = None, timeout = None, + hermetic_cuda_data_dir = None, **kwargs): """Runs a single test file with LLVM's lit tool. @@ -146,6 +168,8 @@ def lit_test( env: string_dict. Environment variables available during test execution. See the common Bazel test attribute. timeout: bazel test timeout string, as per common bazel definitions. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. **kwargs: additional keyword arguments to pass to all generated rules. See https://llvm.org/docs/CommandGuide/lit.html for details on lit @@ -170,12 +194,19 @@ def lit_test( tools_on_path_target_name, "lit_bin", ) + lib_dir = paths.join( + native.package_name(), + tools_on_path_target_name, + "lit_lib", + ) _tools_on_path( name = tools_on_path_target_name, testonly = True, srcs = tools, bin_dir = bin_dir, + lib_dir = lib_dir, + deps = ["//xla/stream_executor/cuda:all_runtime"], visibility = ["//visibility:private"], **kwargs ) @@ -195,6 +226,18 @@ def lit_test( ) # copybara:comment_end + + if hermetic_cuda_data_dir: + output_file = "with_xla_gpu_cuda_data_dir_{}".format(test_file) + rule_name = "script_{}".format(output_file) + lit_script_with_xla_gpu_cuda_data_dir( + rule_name, + test_file, + output_file, + hermetic_cuda_data_dir, + ) + test_file = output_file + native_test( name = name, src = lit_name, @@ -275,6 +318,22 @@ def _tools_on_path_impl(ctx): " {} and {} conflict".format(runfiles_symlinks[bin_path], exe)) runfiles_symlinks[bin_path] = exe + # The loop below symlinks the libraries that are used by the tools. + for dep in ctx.attr.deps: + linker_inputs = dep[CcInfo].linking_context.linker_inputs.to_list() + for linker_input in linker_inputs: + if len(linker_input.libraries) == 0: + continue + lib = linker_input.libraries[0].dynamic_library + if not lib: + continue + lib_path = paths.join(ctx.attr.lib_dir, lib.basename) + if lib_path in runfiles_symlinks: + fail("All libs used by lit tests must have unique basenames, as" + + " they are added to the path." + + " {} and {} conflict".format(runfiles_symlinks[lib_path], lib)) + runfiles_symlinks[lib_path] = lib + return [ DefaultInfo(runfiles = ctx.runfiles( symlinks = runfiles_symlinks, @@ -286,6 +345,8 @@ _tools_on_path = rule( attrs = { "srcs": attr.label_list(allow_files = True, mandatory = True), "bin_dir": attr.string(mandatory = True), + "lib_dir": attr.string(mandatory = True), + "deps": attr.label_list(), }, doc = "Symlinks srcs into a single lit_bin directory. All basenames must be unique.", ) diff --git a/third_party/xla/xla/literal.cc b/third_party/xla/xla/literal.cc index 34d70133ab2411..4ce52706ed97a1 100644 --- a/third_party/xla/xla/literal.cc +++ b/third_party/xla/xla/literal.cc @@ -252,7 +252,7 @@ Literal::Literal(const Shape& shape) void Literal::SetShape(const Shape& shape) { Shape shape_storage; const Shape* shape_ptr = &shape; - if (LayoutUtil::HasCustomElementSizeInBits(shape)) { + if (shape.IsArray() && LayoutUtil::HasCustomElementSizeInBits(shape)) { shape_storage = shape; shape_storage.mutable_layout()->set_element_size_in_bits(0); shape_ptr = &shape_storage; diff --git a/third_party/xla/xla/literal_comparison_test.cc b/third_party/xla/xla/literal_comparison_test.cc index 241baf6e9eb84f..893820780276fe 100644 --- a/third_party/xla/xla/literal_comparison_test.cc +++ b/third_party/xla/xla/literal_comparison_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/literal_util.h" #include "xla/test_helpers.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/ml_dtypes.h" namespace xla { diff --git a/third_party/xla/xla/literal_test.cc b/third_party/xla/xla/literal_test.cc index cddd1212bfee20..42b4340d2ddf82 100644 --- a/third_party/xla/xla/literal_test.cc +++ b/third_party/xla/xla/literal_test.cc @@ -46,10 +46,10 @@ limitations under the License. #include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/test.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "tsl/platform/macros.h" @@ -2583,6 +2583,14 @@ TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArrays) { EXPECT_FALSE(c1.IsKnown()); } +TEST_F(LiteralUtilTest, CreateFromShapeWithUnknownLeafArraysS4Tuple) { + auto inner_shape = ShapeUtil::MakeShape(S4, {4, 4}); + inner_shape.mutable_layout()->set_element_size_in_bits(4); + Literal c1 = Literal::CreateFromShapeWithUnknownLeafArrays( + ShapeUtil::MakeTupleShape({inner_shape})); + EXPECT_FALSE(c1.IsKnown()); +} + TEST_F(LiteralUtilTest, CreatePartiallyKnownTuple) { Literal c1 = Literal::CreateFromShapeWithUnknownLeafArrays( ShapeUtil::MakeShape(F32, {4, 4})); diff --git a/third_party/xla/xla/literal_util.h b/third_party/xla/xla/literal_util.h index f6e5f581802480..dc282774a83f5a 100644 --- a/third_party/xla/xla/literal_util.h +++ b/third_party/xla/xla/literal_util.h @@ -533,7 +533,7 @@ template template /* static */ Literal LiteralUtil::MakeScalarMatrixR2(int64_t size, NativeT scalar) { - Array2D array(size, size, 0); + Array2D array(size, size, NativeT(0)); for (int64_t i = 0; i < size; ++i) { array(i, i) = scalar; } @@ -542,7 +542,7 @@ template template /* static */ Literal LiteralUtil::MakeIdentityR2(int64_t size) { - return MakeScalarMatrixR2(size, 1); + return MakeScalarMatrixR2(size, NativeT(1)); } template @@ -550,7 +550,7 @@ template NativeT scale) { NativeT row_factor = log10(m) + 1; NativeT col_factor = log10(n) + 1; - Array2D array(m, n, 0); + Array2D array(m, n, NativeT(0)); for (int64_t i = 0; i < m; ++i) { for (int64_t j = 0; j < n; ++j) { array(i, i) = scale * (row_factor * i + col_factor * j); diff --git a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc index a190c13e5a4ac9..6716b1660fa960 100644 --- a/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc +++ b/third_party/xla/xla/mlir/tools/mlir_interpreter/dialects/vector.cc @@ -230,7 +230,7 @@ InterpreterValue MaskImpl(mlir::Operation* op, ArrayRef mask_sizes) { } InterpreterValue ConstantMask(InterpreterState&, vector::ConstantMaskOp mask) { - return MaskImpl(mask, ExtractVector(mask.getMaskDimSizes())); + return MaskImpl(mask, mask.getMaskDimSizes()); } // TODO(jreiffers): Support masked contractions. @@ -553,7 +553,7 @@ InterpreterValue MultiReduction(InterpreterState& state, const InterpreterValue& acc) { auto element_ty = getElementTypeOrSelf(reduction->getResultTypes()[0]); return {ReductionImpl(state, source, &acc, reduction.getKind(), - ExtractVector(reduction.getReductionDims()), + SmallVector(reduction.getReductionDims()), element_ty)}; } @@ -634,7 +634,7 @@ InterpreterValue Shuffle(InterpreterState& state, vector::ShuffleOp shuffle, auto& result_view = result.View(); result_view.is_vector = true; - auto mask = ExtractVector(shuffle.getMask()); + auto mask = shuffle.getMask(); bool is_zero_dim = v0.View().Rank() == 0; int64_t size0 = is_zero_dim ? 1 : v0.View().sizes[0]; for (auto [dst_index, src_index] : llvm::enumerate(mask)) { diff --git a/third_party/xla/xla/mlir/utils/BUILD b/third_party/xla/xla/mlir/utils/BUILD index 4267e525a8aaf1..bffefdb7229718 100644 --- a/third_party/xla/xla/mlir/utils/BUILD +++ b/third_party/xla/xla/mlir/utils/BUILD @@ -30,11 +30,11 @@ cc_test( srcs = ["error_util_test.cc"], deps = [ ":error_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test_main", ], diff --git a/third_party/xla/xla/mlir/utils/error_util_test.cc b/third_party/xla/xla/mlir/utils/error_util_test.cc index f325cd070f7f52..23f214f9658b26 100644 --- a/third_party/xla/xla/mlir/utils/error_util_test.cc +++ b/third_party/xla/xla/mlir/utils/error_util_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "llvm/ADT/Twine.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace mlir { diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index 0de939289537b9..df3f11a2fc03f1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -415,7 +415,7 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp) //===----------------------------------------------------------------------===// // Follow async operation use-def chain to find the start of the async chain. -AsyncStartOp findAsyncChainStart(Operation* op) { +static AsyncStartOp findAsyncChainStart(Operation* op) { Operation* start = op; while (start != nullptr && !isa(start)) { start = start->getOperand(0).getDefiningOp(); @@ -423,8 +423,8 @@ AsyncStartOp findAsyncChainStart(Operation* op) { return dyn_cast_or_null(start); } -Type maybeTupleFromTypes(MLIRContext* ctx, ArrayRef types, - bool expectsTuple = false) { +static Type maybeTupleFromTypes(MLIRContext* ctx, ArrayRef types, + bool expectsTuple = false) { if (!expectsTuple && types.size() == 1 && !isa(types[0])) return types[0]; return TupleType::get(ctx, TypeRange(types)); @@ -903,13 +903,31 @@ LogicalResult DotOp::verify() { //===----------------------------------------------------------------------===// LogicalResult DotGeneralOp::verify() { + bool isDefaultPrecisionConfig = + !getPrecisionConfig().has_value() || + llvm::all_of(getPrecisionConfig().value(), [](Attribute attr) { + return cast(attr).getValue() == Precision::DEFAULT; + }); + bool hasAlgorithmSpecified = getAlgorithm().has_value(); + if (hasAlgorithmSpecified) { + DotAlgorithmAttr attr = getAlgorithm().value(); + if (failed(DotAlgorithmAttr::verify( + [&] { return this->emitError(); }, attr.getLhsPrecisionType(), + attr.getRhsPrecisionType(), attr.getAccumulationType(), + attr.getLhsComponentCount(), attr.getRhsComponentCount(), + attr.getNumPrimitiveOperations(), + attr.getAllowImpreciseAccumulation()))) + return failure(); + } + return hlo::verifyDotGeneralOp( getLoc(), getLhs(), getRhs(), getDotDimensionNumbersAttr().getLhsBatchingDimensions(), getDotDimensionNumbersAttr().getRhsBatchingDimensions(), getDotDimensionNumbersAttr().getLhsContractingDimensions(), getDotDimensionNumbersAttr().getRhsContractingDimensions(), - getPrecisionConfig(), getResult()); + getPrecisionConfig(), isDefaultPrecisionConfig, hasAlgorithmSpecified, + getResult()); } LogicalResult DotGeneralOp::reifyReturnTypeShapes( @@ -949,6 +967,17 @@ LogicalResult DotGeneralOp::reifyReturnTypeShapes( return success(); } +LogicalResult DotAlgorithmAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, + Type lhsPrecisionType, Type rhsPrecisionType, Type accumulationType, + int64_t lhsComponentCount, int64_t rhsComponentCount, + int64_t numPrimitiveOperations, bool allowImpreciseAccumulation) { + return hlo::verifyDotAlgorithmAttr( + emitError, lhsPrecisionType, rhsPrecisionType, accumulationType, + lhsComponentCount, rhsComponentCount, numPrimitiveOperations, + allowImpreciseAccumulation); +} + //===----------------------------------------------------------------------===// // SparseDotOp //===----------------------------------------------------------------------===// @@ -1002,8 +1031,9 @@ LogicalResult SparseDotOp::verify() { //===----------------------------------------------------------------------===// // FftOp //===----------------------------------------------------------------------===// -LogicalResult verify1dTensor(std::optional loc, - DenseIntElementsAttr attr, std::string attrName) { +static LogicalResult verify1dTensor(std::optional loc, + DenseIntElementsAttr attr, + std::string attrName) { auto rank = attr.getType().getRank(); if (rank != 1) { return emitOptionalError(loc, attrName, " has rank ", rank, @@ -1221,8 +1251,8 @@ LogicalResult GatherOp::inferReturnTypeComponents( //===----------------------------------------------------------------------===// // Canonicalize mhlo.dynamic_gather to mhlo.gather when slice_sizes is constant. -LogicalResult simplifyDynamicGatherToGather(DynamicGatherOp op, - PatternRewriter& rewriter) { +static LogicalResult simplifyDynamicGatherToGather(DynamicGatherOp op, + PatternRewriter& rewriter) { DenseIntElementsAttr dynamicGatherSliceSizes; if (!matchPattern(op.getSliceSizes(), m_Constant(&dynamicGatherSliceSizes))) { return failure(); @@ -1633,7 +1663,7 @@ struct ConvolutionIsDot : public OpRewritePattern { op.getContext(), {}, {}, {lhsContractDim}, {rhsContractDim}); auto dotOp = rewriter.create( op.getLoc(), op.getType(), lhs, rhs, dotNums, - op.getPrecisionConfig().value_or(nullptr)); + op.getPrecisionConfig().value_or(nullptr), DotAlgorithmAttr{}); rewriter.replaceOp(op, dotOp.getResult()); return success(); @@ -1669,7 +1699,7 @@ struct ConvolutionIsDot : public OpRewritePattern { {lhsContractDim + 1}, {rhsContractDim == 0 ? 2 : 0}); auto dotOp = rewriter.create( op.getLoc(), dotTy, lhs, rhs, dotNums, - op.getPrecisionConfig().value_or(nullptr)); + op.getPrecisionConfig().value_or(nullptr), DotAlgorithmAttr{}); llvm::SmallVector perms; perms.resize(3, dNums.getOutputBatchDimension() == 0 ? 0 : 2); @@ -3371,7 +3401,7 @@ Operation* ReduceWindowOp::getReductionOp(int resultIndex) { return nullptr; } -bool isSplatZero(SplatElementsAttr attr) { +static bool isSplatZero(SplatElementsAttr attr) { if (!attr) return false; if (isa(attr.getElementType())) { return attr.getSplatValue().isZero(); @@ -3609,7 +3639,7 @@ LogicalResult ReduceOp::fold(FoldAdaptor /*adaptor*/, return failure(); } -bool hasSameOperandAndResultTypes(Operation& op) { +static bool hasSameOperandAndResultTypes(Operation& op) { Type expected; if (op.getNumResults() != 0) expected = op.getResult(0).getType(); if (op.getNumOperands() != 0) expected = op.getOperand(0).getType(); @@ -4588,9 +4618,9 @@ struct Abs { } }; -double rsqrt(double d) { return 1.0 / std::sqrt(d); } +static double rsqrt(double d) { return 1.0 / std::sqrt(d); } -double logistic(double d) { return 1.0 / (1.0 + std::exp(-d)); } +static double logistic(double d) { return 1.0 / (1.0 + std::exp(-d)); } // NOLINTBEGIN(bugprone-macro-parentheses) #define UNARY_FOLDER(Op, Func) \ @@ -4828,7 +4858,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { return {}; } -bool isSplatOne(SplatElementsAttr attr) { +static bool isSplatOne(SplatElementsAttr attr) { if (!attr) return false; if (isa(attr.getElementType())) { return attr.getSplatValue().convertToDouble() == 1.0; @@ -5756,8 +5786,8 @@ LogicalResult ScatterOp::verify() { getScatterDimensionNumbers().getIndexVectorDim(), getUpdateComputation()); } -llvm::SmallVector evaluateMhloRegion(Region& region, - ArrayRef inputs) { +static llvm::SmallVector evaluateMhloRegion( + Region& region, ArrayRef inputs) { if (region.getNumArguments() != inputs.size()) return {}; llvm::DenseMap values; @@ -6950,8 +6980,8 @@ static LogicalResult verifyArgResultAliasAttr(StringAttr attrName, // Each CrossProgramPrefetchAttr specifies a parameter and a ShapeIndex // (1) the parameter must be valid // (2) there must be a subshape at the given indices -LogicalResult verifyCrossProgramPrefetchAttr(CrossProgramPrefetchAttr cpp, - ModuleOp module) { +static LogicalResult verifyCrossProgramPrefetchAttr( + CrossProgramPrefetchAttr cpp, ModuleOp module) { func::FuncOp main = module.lookupSymbol("main"); if (cpp.getParameter() >= main.getNumArguments() || cpp.getParameter() < 0) return module->emitOpError() @@ -7055,7 +7085,7 @@ Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, return builder.create(loc, type, elementsAttr); } -int64_t getNumLeafBuffers(Type type) { +static int64_t getNumLeafBuffers(Type type) { if (auto tuple = dyn_cast(type)) { auto ans = 0; for (auto type : tuple.getTypes()) ans += getNumLeafBuffers(type); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td index bda156bbbdc5b2..3b68ad70a332e1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -2418,7 +2418,7 @@ def MHLO_ConvolutionOp : MHLO_Op<"convolution", [Pure]> { MHLO_ConvDimensionNumbers:$dimension_numbers, ConfinedAttr:$feature_group_count, ConfinedAttr:$batch_group_count, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config ); let results = (outs MHLO_Tensor); @@ -2608,7 +2608,7 @@ def MHLO_DotOp: MHLO_Op<"dot", [Pure]> { let arguments = ( ins MHLO_Tensor:$lhs, MHLO_Tensor:$rhs, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config ); let results = (outs MHLO_Tensor); // Dot op required custom exporter to pass the preferred element type @@ -2643,7 +2643,8 @@ def MHLO_DotGeneralOp: MHLO_ShapedInterfaceOp<"dot_general", [Pure]> { MHLO_Tensor:$lhs, MHLO_Tensor:$rhs, MHLO_DotDimensionNumbers:$dot_dimension_numbers, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config, + OptionalAttr:$algorithm ); let results = (outs MHLO_Tensor); @@ -2667,7 +2668,7 @@ def MHLO_SparseDotOp: MHLO_Op<"sparse_dot", [Pure]> { OptionalAttr:$lhs_sparsity, OptionalAttr:$rhs_sparsity, MHLO_DotDimensionNumbers:$dot_dimension_numbers, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config ); let results = (outs MHLO_Tensor); // SparseDot op required custom exporter to pass the preferred element type @@ -3850,7 +3851,7 @@ def MHLO_DynamicConvOp : MHLO_Op<"dynamic_conv", [Pure]> { MHLO_ConvDimensionNumbers:$dimension_numbers, ConfinedAttr:$feature_group_count, ConfinedAttr:$batch_group_count, - MHLO_PrecisionConfigAttr:$precision_config + OptionalAttr:$precision_config ); let results = (outs MHLO_Tensor); let hasCanonicalizer = 1; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td index c43d89a34709e8..229e0d72e0437f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_attrs.td @@ -53,6 +53,32 @@ def MHLO_GatherDimensionNumbers : AttrDef { + let mnemonic = "dot_algorithm"; + let summary = "Attribute that models the algorithm constraints to use for computing dot."; + let parameters = (ins + "Type":$lhsPrecisionType, + "Type":$rhsPrecisionType, + "Type":$accumulationType, + "int64_t":$lhsComponentCount, + "int64_t":$rhsComponentCount, + "int64_t":$numPrimitiveOperations, + "bool":$allowImpreciseAccumulation + ); + let assemblyFormat = [{ + `<` + `lhs_precision_type` `=` $lhsPrecisionType `,` + `rhs_precision_type` `=` $rhsPrecisionType `,` + `accumulation_type` `=` $accumulationType `,` + `lhs_component_count` `=` $lhsComponentCount `,` + `rhs_component_count` `=` $rhsComponentCount `,` + `num_primitive_operations` `=` $numPrimitiveOperations `,` + `allow_imprecise_accumulation` `=` $allowImpreciseAccumulation + `>` + }]; + let genVerifyDecl = 1; +} + def MHLO_DotDimensionNumbers : AttrDef { let mnemonic = "dot"; let summary = "Attribute that models the dimension information for dot."; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td index 25375ac741da18..3e4039ef9598ad 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_enums.td @@ -40,8 +40,7 @@ def MHLO_PrecisionAttr : EnumAttr; // TODO(b/129153247) See if it's possible to also validate the size. def MHLO_PrecisionConfigAttr: - OptionalAttr< - TypedArrayAttrBase>; + TypedArrayAttrBase; //===----------------------------------------------------------------------===// // Custom call schedule hints diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt index 8131c0caab9571..60afe10e64759a 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/CMakeLists.txt @@ -175,6 +175,7 @@ add_mlir_library(ChloPasses MLIRhlo_opsIncGen MLIRChloLegalizeToHloIncGen MLIRMhloPassIncGen + PassesIncGen LINK_COMPONENTS Core diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 159a95463fa72f..196f65d068365f 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -258,6 +258,13 @@ Attribute convertAttr(Attribute hloAttr) { } // NOTE: We cannot process CustomCallApiVersionAttr here because // `dyn_cast()` succeeds for IntegerAttr too. + if (auto attr = mlir::dyn_cast(hloAttr)) { + return stablehlo::DotAlgorithmAttr::get( + attr.getContext(), attr.getLhsPrecisionType(), + attr.getRhsPrecisionType(), attr.getAccumulationType(), + attr.getLhsComponentCount(), attr.getRhsComponentCount(), + attr.getNumPrimitiveOperations(), attr.getAllowImpreciseAccumulation()); + } if (auto attr = mlir::dyn_cast(hloAttr)) { return stablehlo::DotDimensionNumbersAttr::get( attr.getContext(), attr.getLhsBatchingDimensions(), diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc index bfeeaed83f89d1..e986bdc5ad694c 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_to_dot_general/legalize_dot_to_dot_general.cc @@ -55,9 +55,9 @@ struct DotToDotGeneralPattern : public OpRewritePattern { /*lhsContractingDimensions=*/{lhs.getType().getRank() - 1}, /*rhsContractingDimensions=*/{0}); - rewriter.replaceOpWithNewOp(dotOp, dotOp.getType(), lhs, rhs, - dotDimensionNumbers, - dotOp.getPrecisionConfigAttr()); + rewriter.replaceOpWithNewOp( + dotOp, dotOp.getType(), lhs, rhs, dotDimensionNumbers, + dotOp.getPrecisionConfigAttr(), DotAlgorithmAttr{}); return success(); } }; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc index f8c0f9eafd7c83..c35ce560146dcb 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc @@ -159,7 +159,7 @@ struct EinsumToDotGeneralPattern : public OpRewritePattern { auto dotGeneralOp = rewriter.create( einsum.getLoc(), dotGeneralResultType, einsum.getLhs(), einsum.getRhs(), dimNumbers, - /*precision_config=*/ArrayAttr{}); + /*precision_config=*/ArrayAttr{}, /*dot_algorithm=*/DotAlgorithmAttr{}); if (isNaturalOrder) { // The dot_general is already in an appropriate result order. diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc index 0ab6c240c1e790..dfd370298bd862 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/prepare_for_export/prepare_for_export.cc @@ -87,56 +87,6 @@ static void prepareConstantOp(Operation *op, SplatElementsAttr attr) { op->erase(); } -// Ensure that there aren't any implicit capture before exporting. -static void prepareWhileOp(WhileOp whileOp) { - llvm::SetVector implicitInputs; - getUsedValuesDefinedAbove(whileOp->getRegions(), implicitInputs); - if (implicitInputs.empty()) return; - // Each captured value has to be passed as operand to the while, become then - // an operand to the condition region and the body region, and an extra - // operand to the return op in the body. It also becomes an extra result for - // the while operation, even if it is unused. - // We'll process the captured values one at a time and patch the body and - // condition regions as we go, but we'll accumulate the new operands and - // result type and recreate a new while op to replace the existing one at the - // end. - SmallVector returnedTypes(whileOp->getResultTypes().begin(), - whileOp->getResultTypes().end()); - SmallVector operands(whileOp->getOperands().begin(), - whileOp->getOperands().end()); - Region &condRegion = whileOp.getCond(); - Region &bodyRegion = whileOp.getBody(); - - for (Value input : implicitInputs) { - returnedTypes.push_back(input.getType()); - operands.push_back(input); - - Value condArg = - condRegion.front().addArgument(input.getType(), input.getLoc()); - Value bodyArg = - bodyRegion.front().addArgument(input.getType(), input.getLoc()); - for (OpOperand &operand : llvm::make_early_inc_range(input.getUses())) { - if (condRegion.isAncestor(operand.getOwner()->getParentRegion())) - operand.set(condArg); - else if (bodyRegion.isAncestor(operand.getOwner()->getParentRegion())) - operand.set(bodyArg); - } - auto returnOp = cast(bodyRegion.front().back()); - returnOp->insertOperands(returnOp->getNumOperands(), bodyArg); - } - OpBuilder builder(whileOp); - auto newWhileOp = - builder.create(whileOp.getLoc(), returnedTypes, operands); - newWhileOp.getCond().getBlocks().clear(); - newWhileOp.getCond().takeBody(whileOp.getCond()); - newWhileOp.getBody().getBlocks().clear(); - newWhileOp.getBody().takeBody(whileOp.getBody()); - for (auto zippedResults : - llvm::zip_first(whileOp.getResults(), newWhileOp.getResults())) - std::get<0>(zippedResults).replaceAllUsesWith(std::get<1>(zippedResults)); - whileOp->erase(); -} - static void prepareBroadcastInDim(BroadcastInDimOp bcast) { DenseIntElementsAttr dims = bcast.getBroadcastDimensions(); // If dimensions aren't sorted, there is a transpose fused into the op, which @@ -200,7 +150,6 @@ void PrepareForExportPass::runOnOperation() { mlir::SplatElementsAttr attr; if (matchPattern(op, m_Constant(&attr))) return prepareConstantOp(op, attr); - if (auto whileOp = dyn_cast(op)) return prepareWhileOp(whileOp); if (auto bcastOp = dyn_cast(op)) return prepareBroadcastInDim(bcastOp); // IfOp, CaseOp, WhileOp are already being handled during diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index cd94cb58733b33..7570d34ace0bc1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -88,6 +88,13 @@ Attribute convertAttr(Attribute stablehloAttr) { mlir::dyn_cast(stablehloAttr)) { RETURN_CONVERTED_ENUM_ATTR(CustomCallApiVersion); } + if (auto attr = mlir::dyn_cast(stablehloAttr)) { + return mhlo::DotAlgorithmAttr::get( + attr.getContext(), attr.getLhsPrecisionType(), + attr.getRhsPrecisionType(), attr.getAccumulationType(), + attr.getLhsComponentCount(), attr.getRhsComponentCount(), + attr.getNumPrimitiveOperations(), attr.getAllowImpreciseAccumulation()); + } if (auto attr = mlir::dyn_cast(stablehloAttr)) { return mhlo::DotDimensionNumbersAttr::get( diff --git a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp index fb5728f5ab05f9..68ec9699e9030d 100644 --- a/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp +++ b/third_party/xla/xla/mlir_hlo/stablehlo_ext/transforms/chlo_recompose_ops.cpp @@ -104,7 +104,7 @@ struct TopKOpRecomposePattern auto res = verifyCustomCallOpAttributes( op, rewriter, [&](NamedAttribute attr) -> LogicalResult { if (attr.getName() != "largest") return success(); - if (cast(attr.getValue()).getValue() == false) + if (!cast(attr.getValue()).getValue()) return rewriter.notifyMatchFailure( op, "largest = false is not supported."); return success(); diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir index 4324c8e7731b2b..49f1de75619860 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/chlo/chlo_legalize_to_mhlo.mlir @@ -2802,7 +2802,7 @@ func.func @tan_f32(%arg : tensor) -> tensor { // CHECK-SAME: (%[[ARG:.*]]: tensor<16x16xf32>) func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32>) { // CHECK-HIGH-LEVEL: mhlo.topk - // CHECK: %values, %indices = mhlo.topk(%arg0, k = 8, largest = true) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) + // CHECK: %values, %indices = mhlo.topk(%arg0, k = 8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) %1:2 = chlo.top_k(%arg, k=8) : tensor<16x16xf32> -> (tensor<16x8xf32>, tensor<16x8xi32>) func.return %1#0, %1#1 : tensor<16x8xf32>, tensor<16x8xi32> } @@ -2814,7 +2814,7 @@ func.func @top_k(%arg : tensor<16x16xf32>) -> (tensor<16x8xf32>, tensor<16x8xi32 // CHECK-SAME: -> (tensor, tensor) func.func @dyn_top_k(%arg0: tensor) -> (tensor, tensor) { // CHECK-HIGH-LEVEL: mhlo.topk - // CHECK: %values, %indices = mhlo.topk(%arg0, k = 2, largest = true) : tensor -> (tensor, tensor) + // CHECK: %values, %indices = mhlo.topk(%arg0, k = 2) : tensor -> (tensor, tensor) %values, %indices = chlo.top_k(%arg0, k = 2) : tensor -> (tensor, tensor) return %values, %indices : tensor, tensor } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index c25e8ff27fe486..90965f06086831 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -802,6 +802,45 @@ func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) func.return %0 : tensor<8x8x8xf32> } +// CHECK-LABEL: "op_dot_general_algorithm" +func.func @op_dot_general_algorithm(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "stablehlo.dot_general"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ + // CHECK-SAME: algorithm = #stablehlo.dot_algorithm< + // CHECK-SAME: lhs_precision_type = tf32, + // CHECK-SAME: rhs_precision_type = tf32, + // CHECK-SAME: accumulation_type = f32, + // CHECK-SAME: lhs_component_count = 1, + // CHECK-SAME: rhs_component_count = 1, + // CHECK-SAME: num_primitive_operations = 1, + // CHECK-SAME: allow_imprecise_accumulation = false + // CHECK-SAME: >, + // CHECK-SAME: dot_dimension_numbers = #stablehlo.dot< + // CHECK-SAME: lhs_batching_dimensions = [0], + // CHECK-SAME: rhs_batching_dimensions = [0], + // CHECK-SAME: lhs_contracting_dimensions = [2], + // CHECK-SAME: rhs_contracting_dimensions = [1] + // CHECK-SAME: > + // CHECK-SAME: }> : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + %0 = "mhlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #mhlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + algorithm = #mhlo.dot_algorithm< + lhs_precision_type = tf32, + rhs_precision_type = tf32, + accumulation_type = f32, + lhs_component_count = 1, + rhs_component_count = 1, + num_primitive_operations = 1, + allow_imprecise_accumulation = false + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + // CHECK-LABEL: "op_dot" func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { // CHECK: "stablehlo.dot"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir index 6944e796b73f65..a214de54fe8b50 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/prepare-for-export.mlir @@ -22,94 +22,6 @@ func.func @splat_constant_complex_float() -> tensor<128x1014x508xcomplex> { // ----- -// CHECK-LABEL: @while_without_implicit_capture -func.func @while_without_implicit_capture(%arg0: tensor) -> tensor { - // CHECK: mhlo.while - // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0) - // CHECK-SAME: {mhlo.sharding = "{{\{}}{replicated},{replicated}}"} - %0:2 = "mhlo.while"(%arg0, %arg0) ({ - ^bb0(%arg1: tensor, %arg2: tensor): - %1 = "mhlo.compare"(%arg1, %arg2) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }, { - ^bb0(%arg1: tensor, %arg2: tensor): - %2 = mhlo.add %arg1, %arg1 : tensor - "mhlo.return"(%2, %arg2) : (tensor, tensor) -> () - }) {mhlo.sharding = "{{replicated},{replicated}}"} : (tensor, tensor) -> (tensor, tensor) - func.return %0#0 : tensor -} - -// ----- - -// CHECK-LABEL: @while_with_implicit_arg_capture -func.func @while_with_implicit_arg_capture(%arg0: tensor) -> tensor { - // CHECK: mhlo.while - // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0) - %0 = "mhlo.while"(%arg0) ({ - ^bb0(%arg1: tensor): - // CHECK: mhlo.compare - // CHECK-SAME: %[[ARG2]], %[[ARG1]] - %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }, { - ^bb0(%arg1: tensor): - // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG1]], %[[ARG1]] - %2 = mhlo.add %arg1, %arg1 : tensor - // CHECK: mhlo.return - // CHECK-SAME: %[[ADD]], %[[ARG2]] - "mhlo.return"(%2) : (tensor) -> () - }) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - -// CHECK-LABEL: @while_with_implicit_capture -// func @while_with_implicit_capture(%arg0 : tuple, tensor<5xi32>>) -> tuple, tensor<5xi32>> { -func.func @while_with_implicit_capture(%arg0 : tensor, %arg1 : tensor<5xi32>) -> tuple, tensor<5xi32>> { - %0 = mhlo.constant dense<0> : tensor - %1 = mhlo.constant dense : tensor - // Check that the iota implicit capture is made explicit - // CHECK: %[[IOTA:.*]] = "mhlo.iota - %2 = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<5xi32> - // CHECK: mhlo.while{{.*}} %[[IOTA]]) - %3:2 = "mhlo.while"(%arg0, %arg1) ({ - ^bb0(%arg2: tensor, %arg3 : tensor<5xi32>): - "mhlo.return"(%arg2) : (tensor) -> () - }, { - ^bb0(%arg2: tensor, %arg3 : tensor<5xi32>): - "mhlo.return"(%arg2, %2) : (tensor, tensor<5xi32>) -> () - }) : (tensor, tensor<5xi32>) -> (tensor, tensor<5xi32>) - %4 = "mhlo.tuple"(%3#0, %3#1) : (tensor, tensor<5xi32>) -> tuple, tensor<5xi32>> - func.return %4 : tuple, tensor<5xi32>> - } - -// ----- - -// Verifies that a value captured multiple times gets all of its uses updated. -// CHECK-LABEL: @while_with_multiple_capture -func.func @while_with_multiple_capture(%arg0: tensor) -> tensor { - // CHECK: mhlo.while - // CHECK-SAME: (%[[ARG1:.*]] = %arg0, %[[ARG2:.*]] = %arg0) - %0 = "mhlo.while"(%arg0) ({ - ^bb0(%arg1: tensor): - // CHECK: mhlo.compare - // CHECK-SAME: %[[ARG2]], %[[ARG1]] - %1 = "mhlo.compare"(%arg0, %arg1) {comparison_direction = #mhlo} : (tensor, tensor) -> tensor - "mhlo.return"(%1) : (tensor) -> () - }, { - ^bb0(%arg1: tensor): - // CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG1]] - %2 = mhlo.add %arg0, %arg1 : tensor - // CHECK: mhlo.return - // CHECK-SAME: %[[ADD]], %[[ARG2]] - "mhlo.return"(%2) : (tensor) -> () - }) : (tensor) -> tensor - func.return %0 : tensor -} - -// ----- - // CHECK-LABEL: @broadcast_in_dim_dimension_unsorted func.func @broadcast_in_dim_dimension_unsorted(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> { // Unfuse the transpose from the broadcastInDim before export. diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index 3c47a056eb638d..0f2e1b108a710f 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -786,6 +786,45 @@ func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) func.return %0 : tensor<8x8x8xf32> } +// CHECK-LABEL: "op_dot_general_algorithm" +func.func @op_dot_general_algorithm(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "mhlo.dot_general"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ + // CHECK-SAME: algorithm = #mhlo.dot_algorithm< + // CHECK-SAME: lhs_precision_type = tf32, + // CHECK-SAME: rhs_precision_type = tf32, + // CHECK-SAME: accumulation_type = f32, + // CHECK-SAME: lhs_component_count = 1, + // CHECK-SAME: rhs_component_count = 1, + // CHECK-SAME: num_primitive_operations = 1, + // CHECK-SAME: allow_imprecise_accumulation = false + // CHECK-SAME: >, + // CHECK-SAME: dot_dimension_numbers = #mhlo.dot< + // CHECK-SAME: lhs_batching_dimensions = [0], + // CHECK-SAME: rhs_batching_dimensions = [0], + // CHECK-SAME: lhs_contracting_dimensions = [2], + // CHECK-SAME: rhs_contracting_dimensions = [1] + // CHECK-SAME: > + // CHECK-SAME: }> : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + algorithm = #stablehlo.dot_algorithm< + lhs_precision_type = tf32, + rhs_precision_type = tf32, + accumulation_type = f32, + lhs_component_count = 1, + rhs_component_count = 1, + num_primitive_operations = 1, + allow_imprecise_accumulation = false + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + // CHECK-LABEL: "op_dot" func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { // CHECK: "mhlo.dot"([[ARG0:%arg[0-9]+]], [[ARG1:%arg[0-9]+]]) <{ diff --git a/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc b/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc index 2562c060e82f9b..0c8a4de9e98454 100644 --- a/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc +++ b/third_party/xla/xla/mlir_hlo/utils/codegen_utils.cc @@ -100,9 +100,10 @@ SmallVector calcMultiDimIndex(OpBuilder& b, Location loc, return calcMultiDimIndex(b, loc, linearIndex, shapeVec); } -SmallVector calcMultiDimIndexForFirstOperand(OpBuilder& b, Location loc, - Value linearIndex, - Operation* op) { +static SmallVector calcMultiDimIndexForFirstOperand(OpBuilder& b, + Location loc, + Value linearIndex, + Operation* op) { assert(op->getDialect()->getNamespace() == "lmhlo"); Value operandMemref = op->getOperand(0); return calcMultiDimIndex(b, loc, linearIndex, operandMemref); diff --git a/third_party/xla/xla/pjrt/BUILD b/third_party/xla/xla/pjrt/BUILD index 67a03dde3e592d..adf30e0833ba0f 100644 --- a/third_party/xla/xla/pjrt/BUILD +++ b/third_party/xla/xla/pjrt/BUILD @@ -185,8 +185,8 @@ xla_cc_test( ":pjrt_api", "//xla/pjrt/c:pjrt_c_api_hdrs", "//xla/pjrt/c:pjrt_c_api_wrapper_impl", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], @@ -219,6 +219,8 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -226,6 +228,7 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -308,6 +311,10 @@ cc_library( deps = [ ":pjrt_common", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", ], ) @@ -471,6 +478,7 @@ cc_library( deps = [ ":event_pool", ":host_callback", + ":host_memory_spaces", ":local_device_state", ":metrics", ":mlir_to_hlo", @@ -495,7 +503,6 @@ cc_library( "//xla/client:local_client", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", - "//xla/pjrt:host_memory_spaces", "//xla/pjrt/distributed:protocol_proto_cc", "//xla/service:compiler", "//xla/service:computation_layout", @@ -516,6 +523,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -556,10 +564,10 @@ xla_cc_test( "//xla/service:cpu_plugin", "//xla/service:platform_util", "//xla/tsl/concurrency:async_value", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], @@ -610,6 +618,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:statusor", + "@shardy//shardy/dialect/sdy/ir:dialect", "@stablehlo//:chlo_ops", "@stablehlo//:register", "@stablehlo//:stablehlo_ops", @@ -831,9 +840,9 @@ xla_cc_test( "//xla/client:xla_builder", "//xla/pjrt/c:pjrt_c_api_cpu_internal", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -846,12 +855,28 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":pjrt_client", + ":pjrt_common", + ":pjrt_compiler", + ":pjrt_executable", ":pjrt_future", + "//xla:literal", + "//xla:shape_util", + "//xla:util", + "//xla/client:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/service:computation_placer_hdr", + "//xla/service:hlo_cost_analysis", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:errors", ], ) @@ -896,9 +921,9 @@ xla_cc_test( ":host_callback", ":pjrt_client", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/xla/pjrt/c/BUILD b/third_party/xla/xla/pjrt/c/BUILD index dc44b6623ee7a9..1c854c092f24a5 100644 --- a/third_party/xla/xla/pjrt/c/BUILD +++ b/third_party/xla/xla/pjrt/c/BUILD @@ -410,13 +410,13 @@ xla_cc_test( "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_common", "//xla/pjrt/distributed:in_memory_key_value_store", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc index eee88adae5a78d..e17b04de73cec0 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_gpu_test.cc @@ -77,8 +77,8 @@ class PjrtCApiGpuTest : public PjrtCApiTestBase { TEST_F(PjrtCApiGpuTest, CreateViewOfDeviceBuffer) { // Prepares a device memory ptr on GPU. - std::unique_ptr buffer = - create_buffer().first; + auto [buffer, buffer_future] = create_buffer(); + TF_CHECK_OK(buffer_future.Await()); PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args device_buffer_ptr_args; device_buffer_ptr_args.struct_size = PJRT_Buffer_OpaqueDeviceMemoryDataPointer_Args_STRUCT_SIZE; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc index fcfc9119e9ec79..b9508cf24950b4 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers.cc @@ -994,13 +994,6 @@ absl::Span DeviceDescriptions( absl::StatusOr GetCompiledMemoryStats( const PJRT_Api* api, PJRT_Executable* executable) { - // TODO(jieying): To be removed after 03/2024. - if (api->pjrt_api_version.major_version == 0 && - api->pjrt_api_version.minor_version < 40) { - return absl::UnimplementedError( - "GetCompiledMemoryStats requires a plugin with PJRT C API version >= " - "0.40"); - } PJRT_Executable_GetCompiledMemoryStats_Args args; args.struct_size = PJRT_Executable_GetCompiledMemoryStats_Args_STRUCT_SIZE; args.extension_start = nullptr; diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc index d6e240d8c5e96b..8d0a51a48bc840 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_helpers_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/pjrt_client.h" #include "xla/pjrt/pjrt_common.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc index e1ba7c832f314d..54b8dbb6514350 100644 --- a/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc +++ b/third_party/xla/xla/pjrt/c/pjrt_c_api_wrapper_impl.cc @@ -2129,7 +2129,7 @@ PJRT_Error* PJRT_Layouts_MemoryLayout_Serialize( PJRT_Layouts_MemoryLayout_Serialize_Args_STRUCT_SIZE, args->struct_size)); PJRT_Layouts_SerializedLayout* s_layout = new PJRT_Layouts_SerializedLayout{ - .serialized = args->layout->layout->Serialize()}; + /* .serialized = */ args->layout->layout->Serialize()}; args->serialized_layout = s_layout; args->serialized_bytes = s_layout->serialized.data(); args->serialized_bytes_size = s_layout->serialized.size(); diff --git a/third_party/xla/xla/pjrt/cpu/BUILD b/third_party/xla/xla/pjrt/cpu/BUILD index 5405a7d3885ec0..c8896b77f4d019 100644 --- a/third_party/xla/xla/pjrt/cpu/BUILD +++ b/third_party/xla/xla/pjrt/cpu/BUILD @@ -149,6 +149,9 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/backends/cpu/runtime:buffer_allocations", + "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:thunk_executor", "//xla/client:executable_build_options", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", @@ -185,12 +188,10 @@ cc_library( "//xla/service/cpu:cpu_runtime", "//xla/service/cpu:cpu_xfeed", "//xla/service/cpu:simple_orc_jit", - "//xla/service/cpu/runtime:buffer_allocations", - "//xla/service/cpu/runtime:thunk", - "//xla/service/cpu/runtime:thunk_executor", "//xla/stream_executor", "//xla/tsl/concurrency:async_value", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:dynamic_annotations", @@ -208,7 +209,6 @@ cc_library( "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # TODO(zhangqiaorjc): Remove if use TFRT threadpool. "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:denormal", "@local_tsl//tsl/platform:env", @@ -242,11 +242,11 @@ xla_cc_test( "//xla/service:hlo_proto_cc", "//xla/tests:literal_test_util", "//xla/tests:test_utils", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", @@ -306,7 +306,7 @@ cc_library( xla_cc_test( name = "gloo_collectives_test", srcs = ["gloo_collectives_test.cc"], - tags = ["nomac"], + linkstatic = True, deps = [ ":gloo_collectives", ":gloo_kv_store", @@ -317,17 +317,25 @@ xla_cc_test( "//xla/service:collective_ops_utils", "//xla/service:global_device_id", "//xla/service/cpu:collectives_interface", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@gloo//:transport_tcp", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", - ], + ] + select({ + # Gloo's transport_tcp is not available on MacOS + "//xla/tsl:macos": [ + "@gloo//:transport_uv", + ], + "//conditions:default": [ + "@gloo//:transport_tcp", + ], + }), ) cc_library( diff --git a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h index 293bc4016ed803..7c06b26c8aa785 100644 --- a/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/abstract_tfrt_cpu_buffer.h @@ -62,8 +62,8 @@ class MarkEventReadyOnExit { MarkEventReadyOnExit(const MarkEventReadyOnExit&) = delete; MarkEventReadyOnExit& operator=(const MarkEventReadyOnExit&) = delete; - MarkEventReadyOnExit(MarkEventReadyOnExit&&) = default; - MarkEventReadyOnExit& operator=(MarkEventReadyOnExit&&) = default; + MarkEventReadyOnExit(MarkEventReadyOnExit&&) noexcept = default; + MarkEventReadyOnExit& operator=(MarkEventReadyOnExit&&) noexcept = default; ~MarkEventReadyOnExit() { if (event_) event_.SetStateConcrete(); @@ -163,7 +163,7 @@ class AbstractTfrtCpuBuffer : public PjRtBuffer { DonationTransaction(const DonationTransaction&) = delete; DonationTransaction& operator=(const DonationTransaction&) = delete; DonationTransaction(DonationTransaction&&) = default; - DonationTransaction& operator=(DonationTransaction&& other) { + DonationTransaction& operator=(DonationTransaction&& other) noexcept { Abort(); buffer_ = other.buffer_; diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.cc b/third_party/xla/xla/pjrt/cpu/cpu_client.cc index a832ab568d3408..65ed8589f4f2b6 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.cc @@ -47,6 +47,9 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "mlir/IR/BuiltinOps.h" #include "xla/array.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/client/executable_build_options.h" #include "xla/client/xla_computation.h" #include "xla/debug_options_flags.h" @@ -83,9 +86,6 @@ limitations under the License. #include "xla/service/cpu/cpu_executable_run_options.h" #include "xla/service/cpu/cpu_runtime.h" #include "xla/service/cpu/cpu_xfeed.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" @@ -103,10 +103,10 @@ limitations under the License. #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/async_value_ref.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/casts.h" #include "tsl/platform/denormal.h" #include "tsl/platform/env.h" @@ -398,6 +398,11 @@ absl::StatusOr> GetTfrtCpuClient( num_threads, options.asynchronous)); } +// An upper bound on the number of threads to use for intra-op parallelism. It +// is nearly impossible to utilize efficiently more than 256 threads for compute +// intensive operations that are supposed to run inside the intra-op threadpool. +static const size_t kMaxIntraOpThreads = 256; + static tsl::ThreadOptions GetThreadOptions() { tsl::ThreadOptions thread_options; // On Mac OS the default stack size is 512KiB, which is too small for some @@ -415,16 +420,17 @@ TfrtCpuClient::TfrtCpuClient( : process_index_(process_index), owned_devices_(std::move(devices)), computation_placer_(std::make_unique()), + eigen_intraop_pool_(new tsl::thread::ThreadPool( + tsl::Env::Default(), "XLAEigen", + std::min(num_threads, kMaxIntraOpThreads))), + eigen_intraop_device_( + new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(), + eigen_intraop_pool_->NumThreads())), pjrt_client_thread_pool_( new tsl::thread::ThreadPool(tsl::Env::Default(), GetThreadOptions(), "XLATfrtCpuClient", num_threads)), async_work_runner_(std::make_unique( pjrt_client_thread_pool_.get())), - eigen_intraop_pool_(new tsl::thread::ThreadPool(tsl::Env::Default(), - "XLAEigen", num_threads)), - eigen_intraop_device_( - new Eigen::ThreadPoolDevice(eigen_intraop_pool_->AsEigenThreadPool(), - eigen_intraop_pool_->NumThreads())), last_collective_launch_event_( tsl::MakeAvailableAsyncValueRef()), transpose_cache_(1024), @@ -463,10 +469,10 @@ TfrtCpuClient::TfrtCpuClient( owned_memory_spaces_.push_back(std::move(memory_space)); } - LOG(INFO) << "TfrtCpuClient created."; + VLOG(1) << "TfrtCpuClient created."; } -TfrtCpuClient::~TfrtCpuClient() { LOG(INFO) << "TfrtCpuClient destroyed."; } +TfrtCpuClient::~TfrtCpuClient() { VLOG(1) << "TfrtCpuClient destroyed."; } absl::StatusOr TfrtCpuClient::LookupDevice( xla::PjRtGlobalDeviceId global_device_id) const { @@ -857,10 +863,7 @@ absl::StatusOr> TfrtCpuClient::Compile( TF_RETURN_IF_ERROR(MlirToXlaComputation( module, xla_computation, /*use_tuple_args=*/options.parameter_is_tupled_arguments, - /*return_tuple=*/false, - exec_build_options.has_debug_options() - ? exec_build_options.debug_options().xla_use_shardy() - : false)); + /*return_tuple=*/false, exec_build_options.use_shardy_partitioner())); return Compile(xla_computation, options); } diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client.h b/third_party/xla/xla/pjrt/cpu/cpu_client.h index 54f75f3e90f55a..ba4426b73a1c98 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client.h +++ b/third_party/xla/xla/pjrt/cpu/cpu_client.h @@ -446,14 +446,14 @@ class TfrtCpuClient final : public PjRtClient { // Pointers to `owned_memory_spaces_`. std::vector memory_spaces_; - // Thread pool for running PjRtClient tasks. - std::unique_ptr pjrt_client_thread_pool_; - std::unique_ptr async_work_runner_; - // TODO(zhangqiaorjc): Use tsl::compat::EigenHostContextThreadPool. std::unique_ptr eigen_intraop_pool_; std::unique_ptr eigen_intraop_device_; + // Thread pool for running PjRtClient tasks. + std::unique_ptr pjrt_client_thread_pool_; + std::unique_ptr async_work_runner_; + // Launching collectives are prone to deadlock when we use fixed-sized // threadpools since ExecuteHelper will block until all replicas reach the // barrier. We ensure that @@ -589,7 +589,7 @@ class TfrtCpuExecutable final : public PjRtLoadedExecutable { } memory_stats.serialized_hlo_proto = proto->SerializeAsString(); memory_stats.PopulateBufferStatsFromAllocations( - cpu_executable_.get()->GetAllocations()); + cpu_executable_->GetAllocations()); return memory_stats; } diff --git a/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc index 641222e91ce21a..52ca154759e81a 100644 --- a/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc +++ b/third_party/xla/xla/pjrt/cpu/cpu_client_test.cc @@ -48,8 +48,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" diff --git a/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc b/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc index 9301cf0a23d094..b8bb7810dd3909 100644 --- a/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc +++ b/third_party/xla/xla/pjrt/cpu/gloo_collectives_test.cc @@ -25,8 +25,12 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" #include "absl/types/span.h" +#if defined(__linux__) #include "gloo/transport/tcp/attr.h" #include "gloo/transport/tcp/device.h" +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#endif // defined(__linux__) #include "xla/executable_run_options.h" #include "xla/pjrt/cpu/gloo_kv_store.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" @@ -34,8 +38,8 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/cpu/collectives_interface.h" #include "xla/service/global_device_id.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -57,7 +61,11 @@ absl::StatusOr> GetCommunicator( const std::shared_ptr& kv_store, int rank) { auto collectives = std::make_shared( std::make_unique(kv_store), +#if defined(__linux__) gloo::transport::tcp::CreateDevice(gloo::transport::tcp::attr())); +#elif defined(__APPLE__) + gloo::transport::uv::CreateDevice(gloo::transport::uv::attr())); +#endif // defined(__linux__) return collectives->GetCommunicator(global_devices, rank); } diff --git a/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h index 4bfc1c57aed269..8d22bd891e6faf 100644 --- a/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h +++ b/third_party/xla/xla/pjrt/cpu/tracked_tfrt_cpu_device_buffer.h @@ -128,8 +128,9 @@ class TrackedTfrtCpuDeviceBuffer { absl::AnyInvocable on_delete_callback = nullptr); // Move-only. - TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) = default; - TrackedTfrtCpuDeviceBuffer& operator=(TrackedTfrtCpuDeviceBuffer&&) = default; + TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) noexcept = default; + TrackedTfrtCpuDeviceBuffer& operator=(TrackedTfrtCpuDeviceBuffer&&) noexcept = + default; TrackedTfrtCpuDeviceBuffer(const TrackedTfrtCpuDeviceBuffer&) = delete; TrackedTfrtCpuDeviceBuffer& operator=(const TrackedTfrtCpuDeviceBuffer&) = delete; diff --git a/third_party/xla/xla/pjrt/distributed/BUILD b/third_party/xla/xla/pjrt/distributed/BUILD index be69c677d7011c..e0165c8f3e02fc 100644 --- a/third_party/xla/xla/pjrt/distributed/BUILD +++ b/third_party/xla/xla/pjrt/distributed/BUILD @@ -50,9 +50,9 @@ xla_cc_test( ":protocol_proto_cc", ":topology_util", "//xla:test_helpers", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -145,13 +145,13 @@ xla_cc_test( "//xla:protobuf_util", "//xla:status_macros", "//xla/tsl/distributed_runtime/coordination:coordination_service_agent", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", diff --git a/third_party/xla/xla/pjrt/distributed/client.cc b/third_party/xla/xla/pjrt/distributed/client.cc index d0d96c6c511d9f..ede5e27b860f0d 100644 --- a/third_party/xla/xla/pjrt/distributed/client.cc +++ b/third_party/xla/xla/pjrt/distributed/client.cc @@ -92,6 +92,8 @@ DistributedRuntimeCoordinationServiceClient:: absl::ToInt64Milliseconds(options.shutdown_timeout)); config.set_agent_destruction_without_shutdown( !options.shutdown_on_destruction); + config.set_poll_for_error_from_service_at_startup( + options.poll_for_error_from_service_at_startup); auto error_fn = [timeout_fn = options.missed_heartbeat_callback]( const absl::Status& status) { LOG(ERROR) << "Coordination service agent in error status: " << status; diff --git a/third_party/xla/xla/pjrt/distributed/client.h b/third_party/xla/xla/pjrt/distributed/client.h index 79973124485452..2387fe6dd452f5 100644 --- a/third_party/xla/xla/pjrt/distributed/client.h +++ b/third_party/xla/xla/pjrt/distributed/client.h @@ -101,6 +101,12 @@ class DistributedRuntimeClient { // For testing. Should the client explicitly Shutdown() on destruction? bool shutdown_on_destruction = true; + + // Whether the client should send a request to wait for error from the + // coordination service at the startup. + // TODO(b/355706798): Enable this by default once we confirm this works for + // all cases and eventually remove this option. + bool poll_for_error_from_service_at_startup = false; }; virtual ~DistributedRuntimeClient() = default; diff --git a/third_party/xla/xla/pjrt/distributed/client_server_test.cc b/third_party/xla/xla/pjrt/distributed/client_server_test.cc index 5ccbf232dd07a6..dfd46be79b29bc 100644 --- a/third_party/xla/xla/pjrt/distributed/client_server_test.cc +++ b/third_party/xla/xla/pjrt/distributed/client_server_test.cc @@ -45,7 +45,7 @@ limitations under the License. #include "xla/protobuf_util.h" #include "xla/status_macros.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_agent.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -424,6 +424,116 @@ TEST_F(ClientServerTest, ClientsTerminateShutdownIfAnyClientGoesAway) { } } +TEST_F(ClientServerTest, + ClientsTerminateShutdownIfAnyClientGoesAway_WithErrorPolling) { + int num_nodes = 3; + StartService(num_nodes); + + auto thread_fn = [&](int node_id) -> absl::Status { + DistributedRuntimeClient::Options client_options; + client_options.shutdown_on_destruction = node_id != 0; + client_options.missed_heartbeat_callback = + [&](absl::Status status, bool coordinator_initiated) {}; + client_options.poll_for_error_from_service_at_startup = true; + auto client = GetClient(node_id, client_options); + + TF_RETURN_IF_ERROR(client->Connect()); + + if (node_id == 0) { + return absl::OkStatus(); + } + + // The call to Shutdown() should be interrupted if a worker stops issuing + // heartbeats. + return client->Shutdown(); + }; + + std::vector statuses(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", + num_nodes); + for (int i = 0; i < num_nodes; ++i) { + thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); }); + } + } + TF_EXPECT_OK(statuses[0]); + for (int i = 1; i < num_nodes; ++i) { + // The error type depends on whether the node turns into ERROR state during + // or before the shutdown call. + EXPECT_TRUE(absl::IsInternal(statuses[i]) || + absl::IsFailedPrecondition(statuses[i])); + } +} + +TEST_F(ClientServerTest, ClientsShutdownSuccessfully_WithErrorPolling) { + int num_nodes = 3; + StartService(num_nodes); + + auto thread_fn = [&](int node_id) -> absl::Status { + DistributedRuntimeClient::Options client_options; + client_options.shutdown_on_destruction = true; + client_options.missed_heartbeat_callback = + [&](absl::Status status, bool coordinator_initiated) {}; + client_options.poll_for_error_from_service_at_startup = true; + auto client = GetClient(node_id, client_options); + + TF_RETURN_IF_ERROR(client->Connect()); + return client->Shutdown(); + // The error polling request will be cancelled automatically when the + // client is shutting down. + }; + + std::vector statuses(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", + num_nodes); + for (int i = 0; i < num_nodes; ++i) { + thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); }); + } + } + for (int i = 0; i < num_nodes; ++i) { + TF_EXPECT_OK(statuses[i]); + } +} + +TEST_F(ClientServerTest, + MissedHeartbeatCallbackIsExecutedIfAnyClientGoesAway_WithErrorPolling) { + int num_nodes = 3; + StartService(num_nodes); + + auto thread_fn = [&](int node_id) -> absl::Status { + DistributedRuntimeClient::Options client_options; + client_options.shutdown_on_destruction = (node_id != 0); + absl::Notification shutdown; + client_options.missed_heartbeat_callback = [&](absl::Status status, + bool coordinator_initiated) { + shutdown.Notify(); + }; + client_options.poll_for_error_from_service_at_startup = true; + auto client = GetClient(node_id, client_options); + + TF_RETURN_IF_ERROR(client->Connect()); + + if (node_id == 0) { + return absl::OkStatus(); + } + shutdown.WaitForNotification(); + return absl::OkStatus(); + }; + + std::vector statuses(num_nodes); + { + tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "test_threads", + num_nodes); + for (int i = 0; i < num_nodes; ++i) { + thread_pool.Schedule([&, i]() { statuses[i] = thread_fn(i); }); + } + } + for (int i = 0; i < num_nodes; ++i) { + TF_EXPECT_OK(statuses[i]); + } +} + TEST_F(ClientServerTest, ClientsReceiveMissedHeartbeatIfAnyClientGoesAway) { int num_nodes = 3; StartService(num_nodes); diff --git a/third_party/xla/xla/pjrt/distributed/service.cc b/third_party/xla/xla/pjrt/distributed/service.cc index 238e146bf044ec..6a8a77a5fca534 100644 --- a/third_party/xla/xla/pjrt/distributed/service.cc +++ b/third_party/xla/xla/pjrt/distributed/service.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include "absl/time/clock.h" #include "absl/time/time.h" diff --git a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc index aaf859c658e157..193dae87ca1a0b 100644 --- a/third_party/xla/xla/pjrt/distributed/topology_util_test.cc +++ b/third_party/xla/xla/pjrt/distributed/topology_util_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/pjrt/distributed/in_memory_key_value_store.h" #include "xla/pjrt/distributed/protocol.pb.h" #include "xla/test_helpers.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/pjrt/gpu/BUILD b/third_party/xla/xla/pjrt/gpu/BUILD index 4e1f1181aae214..6d02226add5ff8 100644 --- a/third_party/xla/xla/pjrt/gpu/BUILD +++ b/third_party/xla/xla/pjrt/gpu/BUILD @@ -91,6 +91,7 @@ cc_library( "//xla/tsl/framework:bfc_allocator", "//xla/tsl/framework:device_id", "//xla/tsl/framework:device_id_impl", + "//xla/tsl/lib/strings:proto_serialization", "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", @@ -108,7 +109,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", @@ -152,6 +152,7 @@ xla_cc_test( "//xla:shape_util", "//xla:status_macros", "//xla:test", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/client:xla_computation", "//xla/ffi", @@ -168,13 +169,14 @@ xla_cc_test( "//xla/service:platform_util", "//xla/stream_executor", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc index cf60a8f6072c03..79f8d7db00c93f 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.cc @@ -82,7 +82,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/framework/allocator.h" -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" @@ -122,17 +122,28 @@ class AsyncHostToDeviceTransferManager : public xla::PjRtClient::AsyncHostToDeviceTransferManager { public: static absl::StatusOr> - Create(absl::Span shapes, PjRtStreamExecutorDevice* device, - PjRtStreamExecutorClient* client, PjRtMemorySpace* memory_space) { + Create(absl::Span shape_specs, + std::optional> device_layouts, + PjRtStreamExecutorDevice* device, PjRtStreamExecutorClient* client, + PjRtMemorySpace* memory_space) { + if (device_layouts != std::nullopt && + device_layouts->size() != shape_specs.size()) { + return InvalidArgument( + "Number of layouts %d does not match the number of shapes %d", + device_layouts->size(), shape_specs.size()); + } absl::InlinedVector, 4> buffers; absl::InlinedVector, 4> buffer_ptrs; absl::InlinedVector, 4> definition_events; - buffers.reserve(shapes.size()); - buffer_ptrs.reserve(shapes.size()); - definition_events.reserve(shapes.size()); - for (const auto& shape : shapes) { - if (shape.IsTuple()) { + absl::InlinedVector device_shapes; + buffers.reserve(shape_specs.size()); + buffer_ptrs.reserve(shape_specs.size()); + definition_events.reserve(shape_specs.size()); + device_shapes.reserve(shape_specs.size()); + for (int i = 0; i < shape_specs.size(); ++i) { + const PjRtClient::ShapeSpec& shape_spec = shape_specs[i]; + if (shape_spec.element_type == TUPLE) { return Unimplemented( "Async buffer transfer of tuples not implemented."); } @@ -140,16 +151,22 @@ class AsyncHostToDeviceTransferManager // event will block the buffer usage until the transfer is done. definition_events.push_back( std::make_shared(client->thread_pool())); - TF_ASSIGN_OR_RETURN(Shape compact_shape, - client->client() - ->backend() - .transfer_manager() - ->ChooseCompactLayoutForShape(shape)); + Shape& device_shape = device_shapes.emplace_back( + ShapeUtil::MakeShape(shape_spec.element_type, shape_spec.dims)); + if (device_layouts == std::nullopt) { + TF_ASSIGN_OR_RETURN(device_shape, + client->client() + ->backend() + .transfer_manager() + ->ChooseCompactLayoutForShape(device_shape)); + } else { + *device_shape.mutable_layout() = (*device_layouts)[i]; + } LocalDeviceState* local_device = device->local_device_state(); se::Stream* h2d_stream = local_device->host_to_device_stream(); TF_ASSIGN_OR_RETURN(auto buffer, AllocateDestinationBuffer( - compact_shape, device, local_device, h2d_stream, + device_shape, device, local_device, h2d_stream, /*is_uninitialized_create=*/true, client, definition_events.back(), memory_space)); // Get a temporary hold just so we can fish out a shared_ptr to the @@ -167,7 +184,7 @@ class AsyncHostToDeviceTransferManager return std::make_unique( std::move(buffers), std::move(buffer_ptrs), - std::move(definition_events), device); + std::move(definition_events), std::move(device_shapes), device); } AsyncHostToDeviceTransferManager( @@ -175,10 +192,12 @@ class AsyncHostToDeviceTransferManager absl::InlinedVector, 4> buffer_ptrs, absl::InlinedVector, 4> definition_events, + absl::InlinedVector device_shapes, PjRtStreamExecutorDevice* device) : buffers_(std::move(buffers)), buffer_ptrs_(std::move(buffer_ptrs)), definition_events_(std::move(definition_events)), + device_shapes_(std::move(device_shapes)), remaining_buffer_count_(buffer_ptrs_.size()), transfers_in_flight_(0), device_(device) { @@ -229,9 +248,6 @@ class AsyncHostToDeviceTransferManager TransferManager* transfer_manager = se_client->client()->backend().transfer_manager(); - TF_ASSIGN_OR_RETURN( - Shape compact_shape, - transfer_manager->ChooseCompactLayoutForShape(literal.shape())); std::shared_ptr buffer; { @@ -256,16 +272,6 @@ class AsyncHostToDeviceTransferManager } DCHECK_EQ(buffer->device_memory().size(), 1); - auto& buffer_memory = buffer->device_memory()[0]; - if (transfer_manager->GetByteSizeRequirement(compact_shape) != - buffer_memory.size()) { - return InvalidArgument( - "TransferLiteralToBuffer shape %s has size %lld " - "but buffer has size %lld", - ShapeUtil::HumanStringWithLayout(compact_shape), - transfer_manager->GetByteSizeRequirement(compact_shape), - buffer_memory.size()); - } ++transfers_in_flight_; } @@ -274,7 +280,7 @@ class AsyncHostToDeviceTransferManager // TODO(misard) assess if it would be preferable to introduce a heuristic to // put the transfer into the calling thread for small literals. auto transfer_h2d = [this, buffer_index, stream, transfer_manager, literal, - device_buffer = buffer.get(), compact_shape, + device_buffer = buffer.get(), local_device = std::move(device_->local_device_state()), on_done = std::move(on_done)]() mutable { @@ -285,7 +291,8 @@ class AsyncHostToDeviceTransferManager auto event = local_device->event_pool().AllocateEvent(stream->parent()); // Initiate linearization and transfer of the buffer on the stream. - ShapedBuffer buffer = device_buffer->AsShapedBuffer(compact_shape); + ShapedBuffer buffer = + device_buffer->AsShapedBuffer(device_shapes_[buffer_index]); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( stream, literal, buffer)); local_device->event_pool().ThenRecordEvent(stream, event.value()); @@ -449,6 +456,8 @@ class AsyncHostToDeviceTransferManager // corresponding buffer transfer has completed. absl::InlinedVector, 4> definition_events_ ABSL_GUARDED_BY(mu_); + // Device shapes for all buffers with either compact or custom layout. + const absl::InlinedVector device_shapes_; // Count of buffers that have not yet been fully transferred. size_t remaining_buffer_count_ ABSL_GUARDED_BY(mu_); // Count of transfers that have been started but have not yet called cleanup. @@ -544,22 +553,56 @@ absl::string_view StreamExecutorGpuClient::platform_version() const { absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( - absl::Span shapes, PjRtDevice* device) { + absl::Span shape_specs, + std::optional> device_layouts, + PjRtDevice* device) { auto* stream_executor_device = tensorflow::down_cast(device); return xla::AsyncHostToDeviceTransferManager::Create( - shapes, stream_executor_device, this, /*memory_space=*/nullptr); + shape_specs, std::move(device_layouts), stream_executor_device, this, + /*memory_space=*/nullptr); } absl::StatusOr> StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( - absl::Span shapes, PjRtMemorySpace* memory_space) { + absl::Span shapes, PjRtDevice* device) { + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(PjRtClient::ShapeSpec{ + shape.element_type(), + DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, device); +} + +absl::StatusOr> +StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtMemorySpace* memory_space) { CHECK_EQ(memory_space->devices().size(), 1); PjRtDevice* device = memory_space->devices()[0]; auto* stream_executor_device = tensorflow::down_cast(device); return xla::AsyncHostToDeviceTransferManager::Create( - shapes, stream_executor_device, this, memory_space); + shape_specs, std::move(device_layouts), stream_executor_device, this, + memory_space); +} + +absl::StatusOr> +StreamExecutorGpuClient::CreateBuffersForAsyncHostToDevice( + absl::Span shapes, PjRtMemorySpace* memory_space) { + absl::InlinedVector shape_specs; + shape_specs.reserve(shapes.size()); + for (const auto& shape : shapes) { + shape_specs.emplace_back(PjRtClient::ShapeSpec{ + shape.element_type(), + DimensionVector(shape.dimensions().begin(), shape.dimensions().end())}); + } + return CreateBuffersForAsyncHostToDevice( + shape_specs, /*device_layouts=*/std::nullopt, memory_space); } absl::StatusOr @@ -1013,7 +1056,8 @@ absl::StatusOr BuildDistributedDevices( ordinal_and_device.second->executor()->GetPlatform(); TF_ASSIGN_OR_RETURN( std::unique_ptr desc, - platform->DescriptionForDevice(ordinal_and_device.first)); + platform->DescriptionForDevice( + ordinal_and_device.second->local_hardware_id().value())); DeviceProto* device_proto = local_topology.add_devices(); device_proto->set_local_device_ordinal(ordinal_and_device.first); device_proto->set_name(desc->name()); @@ -1206,7 +1250,6 @@ absl::StatusOr> GetStreamExecutorGpuClient( auto host_memory_allocator = GetGpuHostAllocator(local_device_states.begin()->second->executor()); - std::vector> devices; auto gpu_run_options = std::make_unique(); if (options.enable_mock_nccl) { gpu_run_options->set_enable_mock_nccl_collectives(); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h index afb624b248f863..a481e9a59ea73d 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client.h @@ -207,11 +207,22 @@ class StreamExecutorGpuClient : public xla::PjRtStreamExecutorClient { int num_replicas, int num_partitions) const override; absl::string_view platform_version() const override; + absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtDevice* device) override; absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtDevice* device) override; + absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtMemorySpace* memory_space) override; + absl::StatusOr> CreateBuffersForAsyncHostToDevice(absl::Span shapes, PjRtMemorySpace* memory_space) override; diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc index a664a11352a023..54bfaf5c4b61d0 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc @@ -35,9 +35,11 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "absl/time/clock.h" #include "absl/time/time.h" +#include "absl/types/span.h" #include "xla/client/xla_computation.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/pjrt/distributed/in_memory_key_value_store.h" @@ -57,8 +59,9 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/test.h" #include "xla/tests/literal_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" @@ -405,6 +408,54 @@ TEST(StreamExecutorGpuClientTest, ToLiteralAsync) { literal->Relayout(src_literal.shape().layout()).data()); } +TEST(StreamExecutorGpuClientTest, ToLiteralAsyncWithNonCompactLayout) { + TF_ASSERT_OK_AND_ASSIGN(auto client, + GetStreamExecutorGpuClient(GpuClientOptions())); + ASSERT_GE(client->addressable_devices().size(), 1); + + xla::Shape transposed_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( + xla::S32, {2, 3}, /*minor_to_major=*/{0, 1}); + xla::Literal src_literal = xla::LiteralUtil::CreateR2WithLayout( + {{3, 14, 25}, {36, 47, 58}}, transposed_shape.layout()); + + PjRtClient::ShapeSpec spec; + spec.element_type = src_literal.shape().element_type(); + spec.dims = DimensionVector(src_literal.shape().dimensions().begin(), + src_literal.shape().dimensions().end()); + TF_ASSERT_OK_AND_ASSIGN( + auto transfer_manager, + client->CreateBuffersForAsyncHostToDevice( + {spec}, + std::make_optional>( + {transposed_shape.layout()}), + client->addressable_devices()[0]->memory_spaces()[0])); + auto buffer = transfer_manager->RetrieveBuffer(0); + + absl::Mutex mu; + auto literal = std::make_shared( + ShapeUtil::DeviceShapeToHostShape(buffer->on_device_shape())); + bool got_literal = false; + + TF_ASSERT_OK( + transfer_manager->TransferLiteralToBuffer(0, src_literal, [&]() {})); + + buffer->ToLiteral(literal.get()).OnReady([&](absl::Status s) { + absl::MutexLock l(&mu); + TF_ASSERT_OK(s); + got_literal = true; + }); + buffer.reset(); + + { + absl::MutexLock l(&mu); + mu.Await(absl::Condition(&got_literal)); + } + + ASSERT_TRUE(ShapeUtil::Compatible(src_literal.shape(), literal->shape())); + ASSERT_EQ(src_literal.data(), + literal->Relayout(src_literal.shape().layout()).data()); +} + TEST(StreamExecutorGpuClientTest, ToLiteralAsyncBeforeBufferReady) { TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(GpuClientOptions())); diff --git a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc index 22de6c126af4ab..ea9541ce8a03b1 100644 --- a/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc +++ b/third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_compiler.cc @@ -199,13 +199,15 @@ StreamExecutorGpuCompiler::Compile(CompileOptions options, #endif } -STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { - PjRtRegisterCompiler( #if TENSORFLOW_USE_ROCM - RocmName(), +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { + PjRtRegisterCompiler(RocmName(), + std::make_unique()); +}); #else - CudaName(), -#endif - std::make_unique()); +STREAM_EXECUTOR_REGISTER_MODULE_INITIALIZER(pjrt_register_se_gpu_compiler, { + PjRtRegisterCompiler(CudaName(), + std::make_unique()); }); +#endif } // namespace xla diff --git a/third_party/xla/xla/pjrt/host_callback_test.cc b/third_party/xla/xla/pjrt/host_callback_test.cc index f443b9f8bbb524..ef9d5d9ec70c59 100644 --- a/third_party/xla/xla/pjrt/host_callback_test.cc +++ b/third_party/xla/xla/pjrt/host_callback_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/pjrt/pjrt_client.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/pjrt/mlir_to_hlo.cc b/third_party/xla/xla/pjrt/mlir_to_hlo.cc index e97c1f18b391ae..58a6b48fdc23a3 100644 --- a/third_party/xla/xla/pjrt/mlir_to_hlo.cc +++ b/third_party/xla/xla/pjrt/mlir_to_hlo.cc @@ -50,6 +50,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/Passes.h" +#include "shardy/dialect/sdy/ir/utils.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/Register.h" #include "stablehlo/dialect/Serialization.h" @@ -126,6 +127,7 @@ absl::StatusOr> ParseMlirModuleString( registry.insert(); mlir::func::registerAllExtensions(registry); mlir::mhlo::registerAllMhloDialects(registry); + mlir::sdy::loadAllRequiredDialects(&context); mlir::stablehlo::registerAllDialects(registry); context.appendDialectRegistry(registry); diff --git a/third_party/xla/xla/pjrt/pjrt_api_test.cc b/third_party/xla/xla/pjrt/pjrt_api_test.cc index b6e13ca5d14e2c..8ee9e49451a99e 100644 --- a/third_party/xla/xla/pjrt/pjrt_api_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_api_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include #include "xla/pjrt/c/pjrt_c_api.h" #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/protobuf/error_codes.pb.h" namespace { diff --git a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc index 188e159419a2e3..4dbdd5d03af4cb 100644 --- a/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_c_api_client_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/pjrt/pjrt_client.h b/third_party/xla/xla/pjrt/pjrt_client.h index d39d40cf86f50f..e8607d23dd6709 100644 --- a/third_party/xla/xla/pjrt/pjrt_client.h +++ b/third_party/xla/xla/pjrt/pjrt_client.h @@ -31,8 +31,11 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" @@ -55,6 +58,7 @@ limitations under the License. #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" // API notes: // PjRt stands for "Pretty much Just another RunTime". @@ -490,6 +494,11 @@ struct PjRtPluginAttributes { // will eventually be able to make progress. class PjRtClient { public: + struct ShapeSpec { + PrimitiveType element_type; + DimensionVector dims; + }; + PjRtClient() = default; explicit PjRtClient(std::unique_ptr host_memory_for_device_manager) @@ -743,6 +752,32 @@ class PjRtClient { virtual void AddTransferMetadata(const TransferMetadata& metadata) = 0; }; + // Returns a manager for async transfers into a set of buffers with on-host + // shapes defined by 'shape_specs' and optional `device_layouts`. The + // `device_layout` is used when non-compact layouts are preferred. + virtual absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtDevice* device) { + return absl::UnimplementedError(absl::StrCat( + "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is " + "not implemented on platform: ", + platform_name())); + } + + // Variant of CreateBuffersForAsyncHostToDevice with PjRtMemorySpace. + virtual absl::StatusOr> + CreateBuffersForAsyncHostToDevice( + absl::Span shape_specs, + std::optional> device_layouts, + PjRtMemorySpace* memory_space) { + return absl::UnimplementedError(absl::StrCat( + "CreateBuffersForAsyncHostToDevice with ShapeSpec and Layout is " + "not implemented on platform: ", + platform_name())); + } + // Returns a manager for async transfers into a set of buffers with on-host // shapes 'shapes'. virtual absl::StatusOr> diff --git a/third_party/xla/xla/pjrt/pjrt_device_description.h b/third_party/xla/xla/pjrt/pjrt_device_description.h index ed852699e404c5..77107fdc495c71 100644 --- a/third_party/xla/xla/pjrt/pjrt_device_description.h +++ b/third_party/xla/xla/pjrt/pjrt_device_description.h @@ -20,12 +20,35 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/pjrt/pjrt_common.h" namespace xla { using PjRtDeviceAttribute = PjRtValueType; +class PjRtMemorySpaceDescription { + public: + PjRtMemorySpaceDescription(absl::string_view kind, int kind_id) + : kind_(kind), kind_id_(kind_id) {} + + // A platform-dependent string that uniquely identifies the kind of the + // memory space. + absl::string_view kind() const { return kind_; } + + // An ID uniquely identifies the kind of the memory space among those attached + // to the same `PjRtClient`. The IDs assigned to a kind is implementation + // specific. + int kind_id() const { return kind_id_; } + + private: + absl::string_view kind_; + int kind_id_; +}; + class PjRtDeviceDescription { public: virtual ~PjRtDeviceDescription() = default; @@ -60,6 +83,19 @@ class PjRtDeviceDescription { // reference will remain valid for the lifetime of the PjRtDevice. virtual const absl::flat_hash_map& Attributes() const = 0; + + // Returns all memory spaces attached to this device. + // The memory spaces are in no particular order. + virtual absl::Span memory_spaces() + const { + return {}; + } + + // Returns the default memory space attached to this device. + virtual absl::StatusOr + default_memory_space() const { + return absl::UnimplementedError("default_memory_space Not implemented."); + } }; } // namespace xla diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc index 9d3c820550a3b7..7204b158f339f7 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc @@ -3524,10 +3524,7 @@ PjRtStreamExecutorClient::Compile(mlir::ModuleOp module, TF_RETURN_IF_ERROR(MlirToXlaComputation( module, xla_computation, /*use_tuple_args=*/options.parameter_is_tupled_arguments, - /*return_tuple=*/false, - exec_build_options.has_debug_options() - ? exec_build_options.debug_options().xla_use_shardy() - : false)); + /*return_tuple=*/false, exec_build_options.use_shardy_partitioner())); // If the compile options specify argument layout, then let's // fall back to using the options to determine layouts. diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h index 4e8595591c8356..1fb61152bdab61 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client.h @@ -33,6 +33,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" diff --git a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc index d34d5c3c54740f..2fa381df57290a 100644 --- a/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc +++ b/third_party/xla/xla/pjrt/pjrt_stream_executor_client_test.cc @@ -34,8 +34,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -46,10 +46,8 @@ absl::StatusOr> GetClient() { LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie(); TF_ASSIGN_OR_RETURN(se::Platform * platform, PlatformUtil::GetPlatform("Host")); - se::StreamExecutorConfig config; - config.ordinal = 0; TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - platform->GetExecutor(config)); + platform->ExecutorForDevice(0)); auto device_state = std::make_unique( executor, local_client, LocalDeviceState::kSynchronous, /*max_inflight_computations=*/32, diff --git a/third_party/xla/xla/pjrt/tf_pjrt_client.h b/third_party/xla/xla/pjrt/tf_pjrt_client.h index 363c0526f0ba56..c4299d37b0e0fe 100644 --- a/third_party/xla/xla/pjrt/tf_pjrt_client.h +++ b/third_party/xla/xla/pjrt/tf_pjrt_client.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef XLA_PJRT_TF_PJRT_CLIENT_H_ #define XLA_PJRT_TF_PJRT_CLIENT_H_ +#include +#include #include #include #include @@ -26,9 +28,26 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" +#include "absl/types/span.h" +#include "xla/client/xla_computation.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/layout.h" +#include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_common.h" +#include "xla/pjrt/pjrt_compiler.h" +#include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_future.h" +#include "xla/service/computation_placer.h" +#include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/util.h" +#include "tsl/platform/casts.h" #include "tsl/platform/errors.h" namespace xla { diff --git a/third_party/xla/xla/protobuf_util.cc b/third_party/xla/xla/protobuf_util.cc index a8d6dfa15a2a32..4c6815d9396491 100644 --- a/third_party/xla/xla/protobuf_util.cc +++ b/third_party/xla/xla/protobuf_util.cc @@ -49,20 +49,5 @@ size_t ProtobufHash(const tsl::protobuf::Message& m) { return absl::HashOf(serialized); } -absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message, - const std::string& directory, - const std::string& file_name, - std::string* full_path) { - tsl::Env* env = tsl::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); - std::string safe_file_name = SanitizeFileName(file_name) + ".pb"; - std::string full_path_impl; - if (!full_path) { - full_path = &full_path_impl; - } - *full_path = tsl::io::JoinPath(directory, safe_file_name); - return tsl::WriteBinaryProto(env, *full_path, message); -} - } // namespace protobuf_util } // namespace xla diff --git a/third_party/xla/xla/protobuf_util.h b/third_party/xla/xla/protobuf_util.h index 79f00773fb07a0..81f795287c17d8 100644 --- a/third_party/xla/xla/protobuf_util.h +++ b/third_party/xla/xla/protobuf_util.h @@ -55,17 +55,6 @@ class ProtobufHashWrapper { return ProtobufHash(m); } }; -// Writes the given message in binary proto to the path formed by joining -// 'directory/file_name.pb'. The 'directory' is recursively created if it -// doesn't already exist, and the 'file_name' is sanitized by replacing -// illegal characters with underscore '_'. -// -// If 'full_name' is not null then it is set to the name of the file the -// protobuf was written to. -absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message, - const std::string& directory, - const std::string& file_name, - std::string* full_path = nullptr); // Registers a function that may either expand a dirpath or forward the original // dirpath along as-is. diff --git a/third_party/xla/xla/python/BUILD b/third_party/xla/xla/python/BUILD index d887f10f58ac8c..d9c04626523e31 100644 --- a/third_party/xla/xla/python/BUILD +++ b/third_party/xla/xla/python/BUILD @@ -3,10 +3,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library") load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library") -load( - "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", - "if_cuda_is_configured", -) load("//xla:pytype.default.bzl", "pytype_strict_library") load("//xla:strict.default.bzl", "py_strict_library", "py_strict_test") load( @@ -368,7 +364,9 @@ cc_library( "@com_google_absl//absl/types:variant", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@nanobind", + "@shardy//shardy/dialect/sdy/ir:dialect", "@local_config_python//:python_headers", # buildcleaner: keep "//xla:comparison_util", "//xla:literal", @@ -410,8 +408,12 @@ cc_library( "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", "//xla/service:platform_util", + "//xla/service/spmd/shardy:constants", + "//xla/service/spmd/shardy:utils", + "//xla/service/spmd/shardy/sdy_round_trip:pipelines", "//xla/tsl/concurrency:ref_count", "//xla/tsl/framework:allocator", + "//xla/tsl/framework/mlir:status_scoped_diagnostic_handler", "//xla/tsl/python/lib/core:numpy", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:errors", @@ -425,9 +427,9 @@ cc_library( "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", ] + if_cuda([ "@local_config_cuda//cuda:cuda_headers", - # TODO(b/324133505): remove this dependency after JAX OSS migrates to cuda plugin. - "//xla/stream_executor/cuda:cuda_driver", ]) + if_rocm([ + # keep sorted + "@local_config_rocm//rocm:hip", "@local_config_rocm//rocm:rocm_headers", ]) + if_cuda_or_rocm([ ":py_client_gpu", # TODO(b/337876408): remove after migration to plugin @@ -752,11 +754,13 @@ cc_library( "@com_google_absl//absl/types:span", "@nanobind", "@local_config_python//:python_headers", # buildcleaner: keep + "//xla:shape_util", "//xla:util", "//xla/pjrt:exceptions", "//xla/pjrt:lru_cache", "//xla/pjrt:pjrt_client", "//xla/pjrt:pjrt_future", + "//xla/pjrt:pjrt_layout", "//xla/pjrt:status_casters", "//xla/python/ifrt", "//xla/tsl/concurrency:ref_count", @@ -1045,12 +1049,7 @@ cc_library( "@local_tsl//tsl/profiler/rpc:profiler_server_impl", "@local_tsl//tsl/profiler/rpc/client:capture_profile", "@local_tsl//tsl/profiler/rpc/client:profiler_client_impl", - ] + select({ - ":gpu_enabled": [ - "//xla/backends/profiler/gpu:device_tracer", - ], - "//conditions:default": [], - }), + ], ) cc_library( @@ -1173,7 +1172,7 @@ cc_library( "//xla/service:hlo_proto_cc", "//xla/service:name_uniquer", "//xla/service:tuple_simplifier", - "@local_tsl//tsl/lib/strings:proto_serialization", + "//xla/tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -1187,26 +1186,6 @@ tf_proto_library( cc_api_version = 2, ) -# TODO(phawkins): the configuration settings here are overly confusing. The right fix is to split -# xla_extension.so so that each backend is a separate plugin, however that must wait for a clean -# ABI separation between devices. -config_setting( - name = "link_gpu_plugin", - define_values = {"xla_python_enable_gpu": "true"}, -) - -bool_flag( - name = "enable_gpu", - build_setting_default = True, -) - -config_setting( - name = "gpu_enabled", - flag_values = { - ":enable_gpu": "True", - }, -) - # If this flag is enabled, it sets RPATH on the xla_extension to values that are suitable for # finding NVIDIA's CUDA libraries when they are installed as pip packages. bool_flag( @@ -1221,17 +1200,6 @@ config_setting( }, ) -# We cannot nest select and if_cuda_is_configured so we introduce -# a standalone cc_library target. -cc_library( - name = "gpu_plugin_deps", - deps = [ - "//xla/service:gpu_plugin", - ] + if_cuda_is_configured([ - "//xla/stream_executor:cuda_platform", - ]), -) - cc_library( name = "logging", srcs = ["logging.cc"], @@ -1294,13 +1262,6 @@ tsl_pybind_extension( "-fexceptions", "-fno-strict-aliasing", ], - defines = if_google( - [], - select({ - ":gpu_enabled": ["XLA_PYTHON_ENABLE_GPU=1"], - "//conditions:default": [], - }), - ), features = ["-use_header_modules"], linkopts = select({ ":use_jax_cuda_pip_rpaths": [ @@ -1399,8 +1360,12 @@ tsl_pybind_extension( "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform/cloud:gcs_file_system", ] + select({ - # gloo transport only builds on linux - "//xla/tsl:macos": [], + # gloo tcp transport only builds on linux + "//xla/tsl:macos": [ + "//xla/pjrt/cpu:gloo_collectives", + "//xla/pjrt/cpu:gloo_kv_store", + "@gloo//:transport_uv", + ], "//xla/tsl:windows": [], "//conditions:default": [ "//xla/pjrt/cpu:gloo_collectives", @@ -1413,20 +1378,7 @@ tsl_pybind_extension( "//conditions:default": [ "//xla/pjrt/cpu:mpi_collectives", ], - }) + if_google( - [], - select({ - ":gpu_enabled": [ - ":gpu_support", - ], - "//conditions:default": [], - }) + select({ - ":link_gpu_plugin": [ - ":gpu_plugin_deps", - ], - "//conditions:default": [], - }), - ), + }), ) cc_library( diff --git a/third_party/xla/xla/python/ifrt/BUILD b/third_party/xla/xla/python/ifrt/BUILD index b3277fc964809d..6b751b0b079533 100644 --- a/third_party/xla/xla/python/ifrt/BUILD +++ b/third_party/xla/xla/python/ifrt/BUILD @@ -117,6 +117,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@local_tsl//tsl/lib/gtl:int_type", @@ -169,10 +170,10 @@ xla_cc_test( srcs = ["future_test.cc"], deps = [ ":ifrt", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", ], ) @@ -249,11 +250,11 @@ cc_library( deps = [ ":ifrt", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -295,10 +296,10 @@ cc_library( ":ifrt", ":test_util", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -350,8 +351,8 @@ cc_library( ":ifrt", ":test_util", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -576,11 +577,11 @@ cc_library( ":test_util", "//xla:status_macros", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -640,8 +641,8 @@ xla_cc_test( ":plugin_program_serdes", ":serdes", ":serdes_proto_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_tsl//tsl/protobuf:status_proto_cc", @@ -700,11 +701,11 @@ xla_cc_test( ":ifrt", ":program_serdes", ":serdes", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc index 6d5c073c5e29ed..d5f83c8c070eb5 100644 --- a/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/array_impl_test_lib.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/python/ifrt/test_util.h" #include "xla/python/ifrt/value.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/python/ifrt/custom_call_program_serdes_test.cc b/third_party/xla/xla/python/ifrt/custom_call_program_serdes_test.cc index 31a259378695cc..332314a3b3d93d 100644 --- a/third_party/xla/xla/python/ifrt/custom_call_program_serdes_test.cc +++ b/third_party/xla/xla/python/ifrt/custom_call_program_serdes_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/shape.h" #include "xla/python/ifrt/sharding.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/python/ifrt/future_test.cc b/third_party/xla/xla/python/ifrt/future_test.cc index 650f4849c0db1f..808d9a4981494a 100644 --- a/third_party/xla/xla/python/ifrt/future_test.cc +++ b/third_party/xla/xla/python/ifrt/future_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/types/span.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" namespace xla { diff --git a/third_party/xla/xla/python/ifrt/ir/constants.h b/third_party/xla/xla/python/ifrt/ir/constants.h index 27e9d11fb6a1cf..26f8a7e999dd52 100644 --- a/third_party/xla/xla/python/ifrt/ir/constants.h +++ b/third_party/xla/xla/python/ifrt/ir/constants.h @@ -44,6 +44,15 @@ inline constexpr llvm::StringLiteral kIfrtLocalViewAttrName = "ifrt.local_view"; inline constexpr llvm::StringLiteral kIfrtCompileOptionsKey = "ifrt.compile_options_key"; +inline constexpr llvm::StringLiteral kIfrtDevicesAttrName = "ifrt.devices"; +inline constexpr llvm::StringLiteral kIfrtNumDevicesAttrName = + "ifrt.num_devices"; +inline constexpr llvm::StringLiteral kIfrtShardingAttrName = "ifrt.sharding"; +inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName = + "ifrt.entry_function"; + +inline constexpr llvm::StringLiteral kCalleeMainFuncName = "main"; + } // namespace ifrt } // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td index e46ff35490c2a1..a430bcb38f1b41 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_dialect.td @@ -95,7 +95,7 @@ def Ifrt_UnspecifiedShardingAttr : AttrDef { let mnemonic = "interval"; - let summary = [{ + let description = [{ Half-open interval attribute using the Python slice format `[start:end:step]`. Reverse iteration is not supported for simplicity. Therefore, `start` and `end` must be zero or positive, and `step` @@ -133,7 +133,7 @@ def Ifrt_MappingAttrArrayAttr : def Ifrt_ArrayMappingAttr : AttrDef { let mnemonic = "array_mapping"; - let summary = [{ + let description = [{ Mapping of shards from an input array to an output array. The shards are chosen from input array with index `in_array_index` and are used to assemble the output array with index `out_array_index`. diff --git a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc index 6821c7ef25767d..080f0faf76e725 100644 --- a/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc +++ b/third_party/xla/xla/python/ifrt/ir/ifrt_ops.cc @@ -182,6 +182,46 @@ struct IoAlias { int output_index; }; +mlir::LogicalResult VerifyElementTypeAndPerShardShapeAreEqual( + mlir::Operation* op, IfrtArrayType in, int in_index, IfrtArrayType out, + int out_index) { + if (in.getShape().getElementType() != out.getShape().getElementType()) { + return op->emitOpError() + << "can't alias input #" << in_index << " to output #" << out_index + << " with different element types: " << in << " vs " << out; + } + + absl::StatusOr> in_per_shard_shape = + in.getShardingAttr().LocalShapeFromGlobalShape(in.getShape().getShape()); + if (!in_per_shard_shape.ok()) { + return op->emitOpError() + << "unable to get per-shard shape of aliased input #" << in_index + << ": " << in_per_shard_shape.status().message(); + } + absl::StatusOr> out_per_shard_shape = + out.getShardingAttr().LocalShapeFromGlobalShape( + out.getShape().getShape()); + if (!out_per_shard_shape.ok()) { + return op->emitOpError() + << "unable to get per-shard shape of aliased output #" << out_index + << ": " << out_per_shard_shape.status().message(); + } + if (in_per_shard_shape->size() != out_per_shard_shape->size()) { + return op->emitOpError() + << "can't alias input #" << in_index << " to output #" << out_index + << " with different per-shard shapes: " << in << " vs " << out; + } + for (const auto& [in_dim, out_dim] : + llvm::zip(*in_per_shard_shape, *out_per_shard_shape)) { + if (in_dim != out_dim) { + return op->emitOpError() + << "can't alias input #" << in_index << " to output #" << out_index + << " with different per-shard shapes: " << in << " vs " << out; + } + } + return mlir::success(); +} + mlir::LogicalResult VerifyIoAlias(mlir::Operation* op, IoAlias io_alias, llvm::ArrayRef inputs, llvm::ArrayRef outputs) { @@ -198,11 +238,12 @@ mlir::LogicalResult VerifyIoAlias(mlir::Operation* op, IoAlias io_alias, << " outputs"; } if (inputs[io_alias.input_index] != outputs[io_alias.output_index]) { - return op->emitOpError() - << "can't alias input #" << io_alias.input_index << " to output #" - << io_alias.output_index - << " with different types: " << inputs[io_alias.input_index] - << " vs " << outputs[io_alias.output_index]; + // TODO(icgog): Relax this aliasing check to allow for different per-shard + // shapes as long as the byte size is the same. We cannot do this now + // because we do not have layout information. + return VerifyElementTypeAndPerShardShapeAreEqual( + op, inputs[io_alias.input_index], io_alias.input_index, + outputs[io_alias.output_index], io_alias.output_index); } return mlir::success(); } diff --git a/third_party/xla/xla/python/ifrt/ir/tests/BUILD b/third_party/xla/xla/python/ifrt/ir/tests/BUILD index e068e2b1fb008c..01ca1bff5c92e8 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/tests/BUILD @@ -12,6 +12,7 @@ lit_test_suite( [ "ifrt_duplicated_callee_elimination.mlir", "ifrt_merge_reshards.mlir", + "ifrt_outline_atom_program_to_module.mlir", "ifrt_verify_donation.mlir", "ifrt_verify_sharding_specified.mlir", "spmd_expansion.mlir", @@ -97,10 +98,10 @@ cc_library( "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc index dc5330728d067c..95d04b081ee1be 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/ir/tests/executable_impl_test_lib.cc @@ -37,7 +37,7 @@ limitations under the License. #include "xla/python/pjrt_ifrt/xla_compiler.h" #include "xla/service/computation_placer.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir new file mode 100644 index 00000000000000..c963b4ccb7a604 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_outline_atom_program_to_module.mlir @@ -0,0 +1,247 @@ +// RUN: ifrt-opt %s -ifrt-outline-atom-program-to-module -split-input-file -verify-diagnostics | FileCheck %s + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @call_hlo +module @call_hlo { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + {ifrt.compile_options_key = "fake_compile_options_key"} + : (!array) -> !array + return %0 : !array + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @calls_share_a_module +module @calls_share_a_module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUTPUT:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[MODULE:.+]]::@main(%[[OUTPUT]]) + %1, %ctrl_1 = ifrt.Call @add_one(%0) on devices [0,1] : (!array) -> !array + return %1 : !array + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @calls_with_ctrl_dep_share_a_module +module @calls_with_ctrl_dep_share_a_module { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUTPUT:.+]], %[[CTRL_0:.+]] = ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[MODULE:.+]]::@main(%[[OUTPUT]]) after %[[CTRL_0]] + %1, %ctrl_1 = ifrt.Call @add_one(%0) after %ctrl_0 on devices [0,1] + : (!array) -> !array + return %1 : !array + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +!array_unspecified = !ifrt.array, + #ifrt.sharding_unspecified, [0,1]> +// CHECK-LABEL: @call_with_diff_sharding_share_a_module +module @call_with_diff_sharding_share_a_module { + func.func @main(%arg0: !array) -> !array_unspecified + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0, 1] + : (!array) -> !array + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%[[OUT_0]]) + %1, %ctrl_1 = ifrt.Call @add_one(%0) on devices [0, 1] + : (!array) -> !array_unspecified + // CHECK: return %[[OUT_1]] + return %1 : !array_unspecified + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [2,3]> + +// CHECK-LABEL: @call_with_diff_devices_share_a_module +module @call_with_diff_devices_share_a_module { + func.func @main(%arg0: !array0, %arg1: !array1) -> (!array0, !array1) + attributes {ifrt.function} { + // CHECK: %[[OUT_0:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0, 1] + : (!array0) -> !array0 + // CHECK: %[[OUT_1:.+]], %{{.+}} = ifrt.Call @[[MODULE:.+]]::@main(%arg1) + %1, %ctrl_1 = ifrt.Call @add_one(%arg1) on devices [2, 3] + : (!array1) -> !array1 + // CHECK: return %[[OUT_0]], %[[OUT_1]] + return %0, %1 : !array0, !array1 + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +// CHECK-LABEL: @shared_func_is_cloned +module @shared_func_is_cloned { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: %[[OUT:.+]], %{{.+}} = ifrt.Call @[[MODULE1:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [0,1] + : (!array) -> !array + // CHECK: ifrt.Call @[[MODULE2:.+]]::@main(%[[OUT]]) + %1, %ctrl_1 = ifrt.Call @add_two(%0) on devices [0,1] : (!array) -> !array + return %1 : !array + } + + func.func private @add_one_internal(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + + // CHECK: module @[[MODULE1]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + // CHECK: func.func private @add_one_internal + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = func.call @add_one_internal(%arg0) : (tensor<2x2xi32>) -> (tensor<2x2xi32>) + return %0 : tensor<2x2xi32> + } + + // CHECK: module @[[MODULE2]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + // CHECK: func.func private @add_one_internal + func.func private @add_two(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = func.call @add_one_internal(%arg0) : (tensor<2x2xi32>) -> (tensor<2x2xi32>) + %1 = func.call @add_one_internal(%0) : (tensor<2x2xi32>) -> (tensor<2x2xi32>) + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +// CHECK-LABEL: @callee_with_symbol +module @callee_with_symbol { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + // CHECK: ifrt.Call @[[MODULE:.+]]::@main(%arg0) + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [2] + : (!array) -> !array + return %0 : !array + } + + // CHECK: module @[[MODULE]] + // CHECK: attributes {sym_visibility = "private"} + // CHECK: func.func @main + func.func private @add_one(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<2> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 {attr_sym = @add_two}: tensor<2x2xi32> + return %0 : tensor<2x2xi32> + } + + // CHECK: func.func private @add_two + // CHECK-NEXT: mhlo.constant + // CHECK-NEXT: mhlo.add + func.func private @add_two(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + %0 = mhlo.constant dense<2> : tensor<2x2xi32> + %1 = mhlo.add %arg0, %0 : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +module @unknown_symbol_in_callee { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [2] : (!array) -> !array + return %0 : !array + } + + func.func private @add_one(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + // expected-error @+1 {{'mhlo.add' op uses a symbol in attributes `unknown` that does not exist in the ModuleOp}} + %1 = mhlo.add %arg0, %0 {f = @unknown} : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 1>, [2]> +module @wrong_type_for_symbol_in_callee { + func.func @main(%arg0: !array) -> !array attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @add_one(%arg0) on devices [2] : (!array) -> !array + return %0 : !array + } + + func.func private @add_one(%arg0: tensor<2x2xi32>) -> (tensor<2x2xi32>) { + %0 = mhlo.constant dense<1> : tensor<2x2xi32> + // expected-error @+1 {{'mhlo.add' op uses a symbol in attributes `a_module` that is not a FuncOp. Cannot handle such cases for now}} + %1 = mhlo.add %arg0, %0 {f = @a_module} : tensor<2x2xi32> + return %1 : tensor<2x2xi32> + } + + module @a_module {} +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir index 3f6050206495cf..8c70318c03598c 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir @@ -46,7 +46,7 @@ module @donate_to_two_calls_error { attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] {io_aliases=[array]} : (!array) -> !array - // expected-error @+1 {{'ifrt.Call' op input #0 already donated.}} + // expected-error @+1 {{'ifrt.Call' op input #0 of @identity was already donated}} %1, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] {io_aliases=[array]} : (!array) -> !array return %0, %1 : !array, !array @@ -78,7 +78,7 @@ module @program_arg_not_donated_error { module @arg_both_donated_and_not_donated_error { func.func @main(%arg0: !array0 {ifrt.donated}) -> !array0 attributes {ifrt.function} { - // expected-error @+1 {{'ifrt.Call' op input #0 is both donated and not donated.}} + // expected-error @+1 {{'ifrt.Call' op input #0 of @add_two_args was already donated}} %0, %ctrl_0 = ifrt.Call @add_two_args(%arg0, %arg0) on devices [0,1] {io_aliases=[array]} : (!array0, !array0) -> !array0 return %0 : !array0 @@ -101,7 +101,23 @@ module @donate_to_two_reshards_error { func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array1, !array1) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 - // expected-error @+1 {{'ifrt.Reshard' op input #0 already donated.}} + // expected-error @+1 {{'ifrt.Reshard' op input #0 of op}} + %1, %ctrl_1 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 + return %0, %1 : !array1, !array1 + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> +module @donate_to_two_reshards_error { + func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array1, !array1) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 + // expected-error @+1 {{'ifrt.Reshard' op input #0 of op}} %1, %ctrl_1 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 return %0, %1 : !array1, !array1 } @@ -118,7 +134,7 @@ module @donate_to_reshard_and_call_error { attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] {io_aliases=[array]} : (!array0) -> !array0 - // expected-error @+1 {{'ifrt.Reshard' op input #0 already donated.}} + // expected-error @+1 {{'ifrt.Reshard' op input #0 of op}} %1, %ctrl_1 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 return %0, %1 : !array0, !array1 } @@ -138,7 +154,7 @@ module @donate_to_two_copy_arrays_error { func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array1, !array1) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.CopyArrays(%arg0) {donated=true} : (!array0) -> !array1 - // expected-error @+1 {{'ifrt.CopyArrays' op input #0 already donated.}} + // expected-error @+1 {{'ifrt.CopyArrays' op input #0 of op}} %1, %ctrl_1 = ifrt.CopyArrays(%arg0) {donated=true} : (!array0) -> !array1 return %0, %1 : !array1, !array1 } @@ -169,7 +185,7 @@ module @donate_to_reshard_and_call_error { attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] {io_aliases=[array]} : (!array) -> !array - // expected-error @+1 {{'ifrt.RemapArrays' op input #1 already donated.}} + // expected-error @+1 {{'ifrt.RemapArrays' op input #1 of op}} %1 = ifrt.RemapArrays(%0, %arg0) mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] @@ -181,3 +197,82 @@ module @donate_to_reshard_and_call_error { return %arg0 : tensor<2xi32> } } + +// ----- + +!array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @call_after_donation_error { + func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array) -> !array + // expected-error @+1 {{'ifrt.Call' op input #0 of @identity was already donated}} + %1, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] + : (!array) -> !array + return %0, %1 : !array, !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> +module @reshard_with_already_donated_array_error { + func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array0, !array1) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array0) -> !array0 + // expected-error @+1 {{'ifrt.Reshard' op input #0 of op}} + %1, %ctrl_1 = ifrt.Reshard(%arg0) : (!array0) -> !array1 + return %0, %1 : !array0, !array1 + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + +!array0 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> +module @copy_arrays_with_already_donated_array_error { + func.func @main(%arg0: !array0 {ifrt.donated}) -> (!array0, !array1) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array0) -> !array0 + // expected-error @+1 {{'ifrt.CopyArrays' op input #0 of op}} + %1, %ctrl_1 = ifrt.CopyArrays(%arg0) : (!array0) -> !array1 + return %0, %1 : !array0, !array1 + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @copy_arrays_with_already_donated_array_error { + func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array) -> !array + // expected-error @+1 {{'func.return' op result #1 of op at}} + return %0, %arg0 : !array, !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} diff --git a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir index 28a1dda2b3f77d..4fef0876dc8bb8 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/spmd_expansion.mlir @@ -3,7 +3,7 @@ #device = #ifrt #sharding = #ifrt.sharding_param<2x1 to [0] on 2> // CHECK-LABEL: @identity_axis0_sharded -module @identity_axis0_sharded attributes {ifrt.devices = #device} { +module @identity_axis0_sharded attributes {ifrt.num_devices = 2} { // CHECK-NEXT: func.func @main // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32> // CHECK-NEXT: return %[[ARG]] @@ -23,7 +23,7 @@ module @identity_axis0_sharded attributes {ifrt.devices = #device} { #sharding = #ifrt.sharding_param<1x2 to [0] on 2> // CHECK-LABEL: @identity_axis1_sharded module @identity_axis1_sharded - attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} { + attributes {ifrt.num_devices = 2, ifrt.entry_function = "entry_func"} { // CHECK-NEXT: func.func @entry_func // CHECK-SAME: %[[ARG:.*]]: tensor<2x1xi32> // CHECK-NEXT: return %[[ARG]] @@ -42,7 +42,7 @@ module @identity_axis1_sharded #device = #ifrt #sharding = #ifrt.sharding_param<3x2 to [1,0] on 2x3> // CHECK-LABEL: @identify_both_axes_sharded -module @identify_both_axes_sharded attributes {ifrt.devices = #device} { +module @identify_both_axes_sharded attributes {ifrt.num_devices = 6} { // CHECK-NEXT: func.func @main // CHECK-SAME: %[[ARG:.*]]: tensor<1x1xi32> // CHECK-NEXT: return %[[ARG]] @@ -60,7 +60,7 @@ module @identify_both_axes_sharded attributes {ifrt.devices = #device} { #device = #ifrt // CHECK-LABEL: @with_func_call -module @with_func_call attributes {ifrt.devices = #device} { +module @with_func_call attributes {ifrt.num_devices = 2} { // CHECK-NEXT: func.func @main // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32> // CHECK-SAME: tensor<1x2xi32> @@ -94,7 +94,7 @@ module @with_func_call attributes {ifrt.devices = #device} { #device = #ifrt // CHECK-LABEL: @with_nested_func_call -module @with_nested_func_call attributes {ifrt.devices = #device} { +module @with_nested_func_call attributes {ifrt.num_devices = 2} { // CHECK-NEXT: func.func @main // CHECK-SAME: %[[ARG:.*]]: tensor<1x2xi32> // CHECK-SAME: tensor<1x2xi32> @@ -139,11 +139,10 @@ module @with_nested_func_call attributes {ifrt.devices = #device} { // ----- -#device = #ifrt #sharding = #ifrt.sharding_param<1x2 to [0] on 2> // expected-error@+1 {{cannot find entry function `main`}} module @missing_main_function - attributes {ifrt.devices = #device} { + attributes {ifrt.num_devices = 2} { } // ----- @@ -152,7 +151,7 @@ module @missing_main_function #sharding = #ifrt.sharding_param<1x2 to [0] on 2> // expected-error@+1 {{cannot find entry function `entry_func`}} module @missing_entry_function - attributes {ifrt.devices = #device, ifrt.entry_function = "entry_func"} { + attributes {ifrt.num_devices = 2, ifrt.entry_function = "entry_func"} { func.func @main( %arg0: tensor<2x2xi32> {ifrt.sharding = #sharding, ifrt.devices = #device}) @@ -166,7 +165,7 @@ module @missing_entry_function #device = #ifrt #sharding = #ifrt.sharding_param<2x1 to [0] on 2> -module @non_divisible_global_shape attributes {ifrt.devices = #device} { +module @non_divisible_global_shape attributes {ifrt.num_devices = 2} { // expected-error@+1 {{Global shape is not divisible by the number of shards in dimension 0. Global size: 3, number of shards: 2}} func.func @main( %arg0: tensor<3x2xi32> {ifrt.sharding = #sharding, diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir index b318756983e497..e512b260600e73 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call.mlir @@ -355,11 +355,32 @@ func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) // ----- -func.func @io_aliases_should_have_same_type( +!array0 = !ifrt.array, + #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> +!array1 = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @io_aliases_of_different_type_but_same_per_shard_shape(%arg0: !array0) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] + {io_aliases=[array]} : (!array0) -> !array1 + return +} + +func.func @callee(%arg0: tensor<2x1xi32>) -> tensor<1x1xi32> { + %0 = mhlo.constant dense<-2147483648> : tensor + %1 = mhlo.reduce(%arg0 init: %0) applies mhlo.maximum across dimensions = [0, 1] + : (tensor<2x1xi32>, tensor) -> tensor + %2 = mhlo.reshape %1 : (tensor) -> tensor<1x1xi32> + return %2 : tensor<1x1xi32> +} + +// ----- + +func.func @io_aliases_should_alias_arrays_with_same_per_shard_shape( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #0 with different types: '!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}} + // expected-error@+1 {{'ifrt.Call' op can't alias input #0 to output #0 with different per-shard shapes: '!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}} %0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array]} : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, diff --git a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir index d5228419cb7c0f..14485f4c86a4e0 100644 --- a/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir +++ b/third_party/xla/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir @@ -215,7 +215,7 @@ func.func @io_aliases_should_have_same_type( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #0 with different types: '!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}} + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 to output #0 with different per-shard shapes: '!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0, 1]>' vs '!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0, 1]>'}} %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array]} : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD index ccd1919e3ccf5d..620362de4c1b50 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/BUILD +++ b/third_party/xla/xla/python/ifrt/ir/transforms/BUILD @@ -31,6 +31,7 @@ cc_library( srcs = [ "ifrt_duplicated_callee_elimination_pass.cc", "ifrt_merge_reshards_pass.cc", + "ifrt_outline_atom_program_to_module_pass.cc", "ifrt_verify_donation_pass.cc", "ifrt_verify_sharding_specified_pass.cc", "spmd_expandable_interface_verification_pass.cc", @@ -39,8 +40,8 @@ cc_library( hdrs = ["passes.h"], compatible_with = get_compatible_with_portable(), deps = [ - ":constants", ":passes_inc_gen", + ":utils", "//xla/python/ifrt/ir", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log:check", @@ -67,8 +68,14 @@ cc_library( ) cc_library( - name = "constants", - hdrs = ["constants.h"], + name = "utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], compatible_with = get_compatible_with_portable(), - deps = ["@llvm-project//llvm:Support"], + deps = [ + "@com_google_absl//absl/log:check", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], ) diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_outline_atom_program_to_module_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_outline_atom_program_to_module_pass.cc new file mode 100644 index 00000000000000..3074e67aebcaa3 --- /dev/null +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_outline_atom_program_to_module_pass.cc @@ -0,0 +1,181 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "xla/python/ifrt/ir/constants.h" +#include "xla/python/ifrt/ir/ifrt_ops.h" +#include "xla/python/ifrt/ir/transforms/utils.h" + +namespace xla { +namespace ifrt { + +namespace { + +#define GEN_PASS_DEF_IFRTOUTLINEATOMPROGRAMTOMODULEPASS +#include "xla/python/ifrt/ir/transforms/passes.h.inc" + +class IfrtOutlineAtomProgramToModulePass + : public impl::IfrtOutlineAtomProgramToModulePassBase< + IfrtOutlineAtomProgramToModulePass> { + public: + using impl::IfrtOutlineAtomProgramToModulePassBase< + IfrtOutlineAtomProgramToModulePass>:: + IfrtOutlineAtomProgramToModulePassBase; + + void runOnOperation() override; +}; + +void IfrtOutlineAtomProgramToModulePass::runOnOperation() { + mlir::SymbolTableCollection symbol_table; + mlir::OpBuilder builder(&getContext()); + llvm::DenseSet visited; + llvm::SmallVector to_erase; + mlir::ModuleOp module_op = getOperation(); + mlir::func::FuncOp main_func = GetMainFunction(module_op); + auto result = + main_func.walk([&](xla::ifrt::CallOp call_op) -> mlir::WalkResult { + // Maybe visited by a previous CallOp with the same callee. + if (visited.contains(call_op)) { + return mlir::WalkResult::advance(); + } + + // Find the callee. + mlir::func::FuncOp callee = call_op.getCalleeOp(symbol_table); + if (callee.getSymName() == kCalleeMainFuncName && + llvm::isa(callee->getParentOp())) { + // Atom program is already outlined in module. Do nothing. + return mlir::WalkResult::advance(); + } + + // Create a ModuleOp and clone callee into it. + builder.setInsertionPointAfter(callee); + auto callee_module = builder.create( + callee->getLoc(), callee.getSymName()); + callee_module.setVisibility(mlir::SymbolTable::Visibility::Private); + + mlir::func::FuncOp cloned_callee; + // Find all symbols directly or indirectly referenced by callee and copy + // them to the newly created module. + { + // Setup for DFS. + llvm::DenseSet visited_funcs; + llvm::SmallVector func_stack = {callee}; + while (!func_stack.empty()) { + mlir::func::FuncOp current_func = func_stack.back(); + func_stack.pop_back(); + if (!visited_funcs.insert(current_func).second) { + continue; + } + + // Copy function into the new module. + mlir::func::FuncOp cloned_func = + llvm::cast(current_func->clone()); + if (current_func == callee) { + cloned_callee = cloned_func; + cloned_func.setSymName(kCalleeMainFuncName); + cloned_func.setVisibility(mlir::SymbolTable::Visibility::Public); + } + builder.setInsertionPointToEnd(callee_module.getBody()); + builder.insert(cloned_func); + + // Check all symbols in function. + std::optional sym_uses = + mlir::SymbolTable::getSymbolUses(current_func); + if (!sym_uses.has_value()) { + continue; + } + for (const mlir::SymbolTable::SymbolUse& sym_use : *sym_uses) { + // Ensure the symbol represents a function. + mlir::Operation* sym_op = module_op.lookupSymbol( + sym_use.getSymbolRef().getRootReference()); + if (sym_op == nullptr) { + return sym_use.getUser()->emitOpError() + << "uses a symbol in attributes `" + << sym_use.getSymbolRef().getRootReference().str() + << "` that does not exist in the ModuleOp."; + } + auto func = llvm::dyn_cast(sym_op); + if (func == nullptr) { + return sym_use.getUser()->emitOpError() + << "uses a symbol in attributes `" + << sym_use.getSymbolRef().getRootReference().str() + << "` that is not a FuncOp. Cannot handle such cases " + "for now."; + } + func_stack.push_back(func); + } + } + } + + // Replace all uses of old callee. + mlir::SymbolRefAttr new_symbol = mlir::SymbolRefAttr::get( + callee_module.getSymNameAttr(), + mlir::SymbolRefAttr::get(cloned_callee.getSymNameAttr())); + // It is sufficient to get the symbols in the main func because + // ifrt.Call nested within callees are not supported. + std::optional symbol_uses = + callee.getSymbolUses(main_func); + if (symbol_uses.has_value()) { + for (const mlir::SymbolTable::SymbolUse symbol_use : *symbol_uses) { + auto user = llvm::dyn_cast(symbol_use.getUser()); + if (user == nullptr) { + return symbol_use.getUser()->emitOpError() + << "requires symbol `" << callee.getSymName() + << "` only used by ifrt.Call. Found use by `" + << user.getOperationName() << "`"; + } + user.setCalleeAttr(new_symbol); + visited.insert(user); + } + } + + // Can't erase callee yet during iteration. + to_erase.push_back(callee); + return mlir::WalkResult::advance(); + }); + + if (result.wasInterrupted()) { + signalPassFailure(); + return; + } + for (mlir::Operation* op : to_erase) { + op->erase(); + } +} + +} // namespace + +std::unique_ptr> +CreateIfrtOutlineAtomProgramToModulePass() { + return std::make_unique(); +} + +} // namespace ifrt +} // namespace xla diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc index be11610d826884..7e3492147e1665 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc @@ -16,10 +16,11 @@ limitations under the License. #include #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" @@ -62,10 +63,15 @@ class IfrtVerifyDonationPass }; void IfrtVerifyDonationPass::runOnOperation() { - mlir::ModuleOp module_op = getOperation(); - llvm::DenseSet donated_values; - mlir::WalkResult result = module_op.walk([&](mlir::Operation* op) - -> mlir::WalkResult { + mlir::func::FuncOp func_op = getOperation(); + // We only need to run this pass on IFRT functions. + if (!func_op->hasAttr(kIfrtFunctionAttrName) && + !func_op->hasAttr(kIfrtReshardFunctionAttrName)) { + return; + } + llvm::DenseMap donated_value_to_op; + mlir::WalkResult result = func_op.walk([&](mlir::Operation* op) + -> mlir::WalkResult { auto result = llvm::TypeSwitch(op) .Case( @@ -78,44 +84,74 @@ void IfrtVerifyDonationPass::runOnOperation() { io_alias.asArrayRef(); donated_input_idxs.insert(io_alias_as_array[0]); auto donated_value = op.getInputs()[io_alias_as_array[0]]; - if (!donated_values.insert(donated_value).second) { + auto donated_it = + donated_value_to_op.try_emplace(donated_value, op); + if (!donated_it.second) { op.emitOpError() << "input #" << io_alias_as_array[0] - << " already donated."; + << " of " << op.getCalleeAttr() + << " was already donated to the op at " + << donated_it.first->second->getLoc(); return mlir::failure(); } - if (mlir::failed( VerifyIfInputAndDonated(op, donated_value))) { return mlir::failure(); } } - // Verify that an input is not both donated and not donated. + // Verify non-donated inputs after donated inputs have been + // added to also catch instances such as + // `ifrt.Call(%arg0 {ifrt.donated}, %arg0})`. for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { - if (donated_values.contains(input) && - !donated_input_idxs.contains(idx)) { - op.emitOpError() << "input #" << idx - << " is both donated and not donated."; - return mlir::failure(); + if (!donated_input_idxs.contains(idx)) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() + << "input #" << idx << " of " << op.getCalleeAttr() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } } } return mlir::success(); }) .Case([&](auto& op) { + // Verify that no inputs have already been donated. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() + << "input #" << idx << " of op at " << op.getLoc() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } + } if (op.getDonated()) { - for (const auto [idx, input] : - llvm::enumerate(op.getInputs())) { - if (donated_values.contains(input)) { - op.emitOpError() << "input #" << idx << " already donated."; - return mlir::failure(); - } + // Add the donated inputs to the map and verify that all the + // donated inputs are also donated to the main func. + for (const auto input : op.getInputs()) { + donated_value_to_op.try_emplace(input, op); if (mlir::failed(VerifyIfInputAndDonated(op, input))) { return mlir::failure(); } } - donated_values.insert(op.getInputs().begin(), - op.getInputs().end()); + } + return mlir::success(); + }) + .Case([&](mlir::func::ReturnOp return_op) { + for (const auto& [idx, result] : + llvm::enumerate(return_op.getOperands())) { + auto donated_it = donated_value_to_op.find(result); + if (donated_it != donated_value_to_op.end()) { + return_op.emitOpError() + << "result #" << idx << " of op at " << return_op.getLoc() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } } return mlir::success(); }) @@ -134,7 +170,7 @@ void IfrtVerifyDonationPass::runOnOperation() { } // namespace -std::unique_ptr> +std::unique_ptr> CreateIfrtVerifyDonationPass() { return std::make_unique(); } diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h index a2cd1748a6c3b0..da7ec1ab599795 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.h @@ -20,10 +20,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" namespace xla { namespace ifrt { @@ -43,6 +40,9 @@ CreateIfrtDuplicatedCalleeEliminationPass(); std::unique_ptr> CreateIfrtMergeReshardsPass(); +std::unique_ptr> +CreateIfrtOutlineAtomProgramToModulePass(); + std::unique_ptr> CreateIfrtVerifyDonationPass(); diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td index d299c6b4786425..10215b72653e0c 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/passes.td +++ b/third_party/xla/xla/python/ifrt/ir/transforms/passes.td @@ -95,6 +95,47 @@ module attributes {ifrt.devices = #device} { let constructor = "CreateSpmdExpansionPass()"; } +def IfrtOutlineAtomProgramToModulePass : + Pass<"ifrt-outline-atom-program-to-module", "mlir::ModuleOp"> { + let summary = "Wraps every atom function with a ModuleOp with a @main FuncOp"; + let description = [{ +For every unique atom program this passes produces a ModuleOp with the same name +as the callee, clones the callee into the ModuleOp, and redirects all the +CallOps calling it to the new callee. + +This pass must be run if the compiler (e.g., the XLA compiler) expects each atom +program to be outlined in a ModuleOp with a @main FuncOp. + +For example, the following code + +```mlir +%0, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0, 1] + : (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> + +func.func private @callee(%arg0: tensor<2x2xi32>) -> (tensor<4x4xi32>) {} +``` + +will be replaced by + +```mlir +%0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0, 1] + : (!ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, + [0,1]>) + -> !ifrt.array, #ifrt.sharding_param<1x2 to [0] on 2>, + [0,1]> + +module @callee attributes {sym_visibility = "private"} { + func.func @main(%arg0: tensor<2x2xi32>) -> (tensor<4x4xi32>) {} +} +``` + }]; + + let constructor = "CreateIfrtOutlineAtomProgramToModulePass()"; +} + def IfrtDuplicatedCalleeEliminationPass : Pass<"ifrt-duplicated-callee-elimination", "mlir::ModuleOp"> { let summary = "Deduplicate callees of CallOp"; @@ -139,7 +180,8 @@ ifrt.CopyArrays, and ifrt.RemapArrays. let constructor = "CreateIfrtMergeReshardsPass()"; } -def IfrtVerifyDonationPass : Pass<"ifrt-verify-donation", "mlir::ModuleOp"> { +def IfrtVerifyDonationPass : + Pass<"ifrt-verify-donation", "mlir::func::FuncOp"> { let summary = "Verify that `!ifrt.array` are not donated more than once."; let description = [{ Verifiy that no `!ifrt.array` is donated more than once, and that all diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc index 2669dfd73d2256..13d198f2dbf8c8 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc +++ b/third_party/xla/xla/python/ifrt/ir/transforms/spmd_expansion_pass.cc @@ -35,9 +35,8 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "xla/python/ifrt/ir/ifrt_dialect.h" +#include "xla/python/ifrt/ir/constants.h" #include "xla/python/ifrt/ir/ifrt_interfaces.h" -#include "xla/python/ifrt/ir/transforms/constants.h" #include "xla/python/ifrt/ir/transforms/passes.h" namespace xla::ifrt { @@ -272,15 +271,15 @@ mlir::LogicalResult SpmdExpansionPass::spmdExpand(mlir::func::FuncOp func_op) { void SpmdExpansionPass::runOnOperation() { mlir::ModuleOp module_op = getOperation(); // Skip single-device case. - auto devices = module_op->getAttrOfType( - kIfrtDevicesAttrName); - if (devices == nullptr) { + auto num_devices = + module_op->getAttrOfType(kIfrtNumDevicesAttrName); + if (num_devices == nullptr) { module_op->emitOpError() << "`" << module_op.getName()->str() << "` requires `" - << kIfrtDevicesAttrName << "` attribute."; + << kIfrtNumDevicesAttrName << "` attribute."; return signalPassFailure(); } - if (devices.getIds().size() == 1) { + if (num_devices.getInt() == 1) { return; } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h b/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc similarity index 60% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h rename to third_party/xla/xla/python/ifrt/ir/transforms/utils.cc index fca921621ac4c1..b1cb219e5e49fe 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/utils.cc @@ -13,21 +13,22 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_ -#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_ +#include "xla/python/ifrt/ir/transforms/utils.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/OpImplementation.h" +#include "absl/log/check.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep namespace xla { -namespace gpu { +namespace ifrt { -// Custom parser to parse IndexingMapAttr. -mlir::FailureOr ParseIndexingMapAttr(mlir::AsmParser& parser); +mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module) { + mlir::func::FuncOp func = + mlir::dyn_cast_or_null(module.lookupSymbol("main")); + CHECK(func); + return func; +} -} // namespace gpu +} // namespace ifrt } // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_ATTRS_H_ diff --git a/third_party/xla/xla/python/ifrt/ir/transforms/constants.h b/third_party/xla/xla/python/ifrt/ir/transforms/utils.h similarity index 53% rename from third_party/xla/xla/python/ifrt/ir/transforms/constants.h rename to third_party/xla/xla/python/ifrt/ir/transforms/utils.h index 98bfd12e2c19b8..81528e97f418ae 100644 --- a/third_party/xla/xla/python/ifrt/ir/transforms/constants.h +++ b/third_party/xla/xla/python/ifrt/ir/transforms/utils.h @@ -1,4 +1,4 @@ -/* Copyright 2023 The OpenXLA Authors. +/* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_PYTHON_IFRT_IR_TRANSFORMS_CONSTANTS_H_ -#define XLA_PYTHON_IFRT_IR_TRANSFORMS_CONSTANTS_H_ +#ifndef XLA_PYTHON_IFRT_IR_TRANSFORMS_UTILS_H_ +#define XLA_PYTHON_IFRT_IR_TRANSFORMS_UTILS_H_ -#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" -namespace xla::ifrt { +namespace xla { +namespace ifrt { -inline constexpr llvm::StringLiteral kIfrtDevicesAttrName = "ifrt.devices"; -inline constexpr llvm::StringLiteral kIfrtShardingAttrName = "ifrt.sharding"; -inline constexpr llvm::StringLiteral kIfrtEntryFunctionAttrName = - "ifrt.entry_function"; +// Retrieves the function named "main" from the given module, if it exists, and +// fails otherwise. +mlir::func::FuncOp GetMainFunction(mlir::ModuleOp module); -} // namespace xla::ifrt +} // namespace ifrt +} // namespace xla -#endif // XLA_PYTHON_IFRT_IR_TRANSFORMS_CONSTANTS_H_ +#endif // XLA_PYTHON_IFRT_IR_TRANSFORMS_UTILS_H_ diff --git a/third_party/xla/xla/python/ifrt/memory.cc b/third_party/xla/xla/python/ifrt/memory.cc index c608950e3e8aef..c04bc0bead8ec6 100644 --- a/third_party/xla/xla/python/ifrt/memory.cc +++ b/third_party/xla/xla/python/ifrt/memory.cc @@ -19,8 +19,10 @@ limitations under the License. #include #include +#include "absl/base/thread_annotations.h" #include "absl/container/node_hash_set.h" -#include "xla/pjrt/pjrt_client.h" +#include "absl/strings/string_view.h" +#include "absl/synchronization/mutex.h" #include "xla/python/ifrt/device.h" namespace xla { @@ -52,7 +54,7 @@ MemoryKind::MemoryKind(std::optional memory_kind) { } } -std::string MemoryKind::DebugString() const { +std::string MemoryKind::ToString() const { if (memory_kind_.has_value()) { return std::string(*memory_kind_); } diff --git a/third_party/xla/xla/python/ifrt/memory.h b/third_party/xla/xla/python/ifrt/memory.h index a5f48fa6cf432c..309d49705381e3 100644 --- a/third_party/xla/xla/python/ifrt/memory.h +++ b/third_party/xla/xla/python/ifrt/memory.h @@ -62,17 +62,15 @@ class MemoryKind { template friend void AbslStringify(Sink& sink, const MemoryKind& memory_kind) { - sink.Append(memory_kind.DebugString()); + sink.Append(memory_kind.ToString()); } // Returns a platform-dependent identifier of a memory kind. std::optional memory_kind() const { return memory_kind_; } - // TODO(kedars): Rename & make private after replacing usage with - // AbslStringify. - std::string DebugString() const; - private: + std::string ToString() const; + std::optional memory_kind_; }; @@ -81,8 +79,8 @@ class MemoryKind { // indicated by the device, simply returns `MemoryKind` with no memory kind // chosen. // -// TODO(hyeontaek,yashkatariya): Harden `MemoryKind` creation paths so that -// every `MemoryKind` is canonicalized and does not require on-demand +// TODO(b/356623715): Harden `MemoryKind` creation paths so that every +// `MemoryKind` is canonicalized and does not require on-demand // canonicalization. MemoryKind CanonicalizeMemoryKind(MemoryKind memory_kind, Device* device); diff --git a/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc b/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc index 31dca456bd0ea4..4edfae40571cae 100644 --- a/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc +++ b/third_party/xla/xla/python/ifrt/plugin_program_serdes_test.cc @@ -18,7 +18,7 @@ #include "xla/python/ifrt/plugin_program.h" #include "xla/python/ifrt/serdes.h" #include "xla/python/ifrt/serdes.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/protobuf/error_codes.pb.h" #include "tsl/protobuf/status.pb.h" diff --git a/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc index a55d97d13998e4..85822b51c24e45 100644 --- a/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/remap_impl_test_lib.cc @@ -36,7 +36,7 @@ limitations under the License. #include "xla/python/ifrt/test_util.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/python/ifrt/sharding.cc b/third_party/xla/xla/python/ifrt/sharding.cc index e302535cc4f974..3cc2bcfb5668d3 100644 --- a/third_party/xla/xla/python/ifrt/sharding.cc +++ b/third_party/xla/xla/python/ifrt/sharding.cc @@ -50,6 +50,14 @@ namespace ifrt { namespace { +// Returns a canonicalized memory kind for the given devices. +// REQUIRES: !devices.empty() +MemoryKind CanonicalizeMemoryKindWithDevices(const MemoryKind& memory_kind, + const DeviceList& devices) { + CHECK(!devices.empty()); + return CanonicalizeMemoryKind(memory_kind, devices.front()); +} + // Returns if `sharding_param` indicates a fully replicated sharding. bool ComputeIsFullyReplicated(const ShardingParam& sharding_param) { return llvm::all_of(sharding_param.dim_shards(), @@ -155,6 +163,12 @@ char ShardingParamSharding::ID = 0; char DeserializeShardingOptions::ID = 0; +Sharding::Sharding(DeviceList devices, MemoryKind memory_kind, + bool is_fully_replicated) + : devices_(std::move(devices)), + memory_kind_(memory_kind), + is_fully_replicated_(is_fully_replicated) {} + bool Sharding::operator==(const Sharding& other) const { if (this == &other) { return true; @@ -184,6 +198,7 @@ std::ostream& operator<<(std::ostream& os, const Sharding& sharding) { std::unique_ptr SingleDeviceSharding::Create( Device* device, MemoryKind memory_kind) { + memory_kind = CanonicalizeMemoryKind(memory_kind, device); return std::unique_ptr( new SingleDeviceSharding(device, memory_kind)); } @@ -240,13 +255,13 @@ absl::StatusOr> SingleDeviceSharding::IndexDomains( std::string SingleDeviceSharding::DebugString() const { DCHECK(this); - return absl::StrFormat("SingleDeviceSharding(%s, memory_kind: %s)", - devices_.front()->ToString(), - memory_kind_.DebugString()); + return absl::StrFormat("SingleDeviceSharding(%s, memory_kind: %v)", + devices_.front()->ToString(), memory_kind_); } std::unique_ptr OpaqueSharding::Create(DeviceList devices, MemoryKind memory_kind) { + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr( new OpaqueSharding(std::move(devices), memory_kind)); } @@ -306,18 +321,19 @@ absl::StatusOr> OpaqueSharding::IndexDomains( std::string OpaqueSharding::DebugString() const { DCHECK(this); return absl::StrFormat( - "OpaqueSharding(devices: %s, memory_kind: %s)", + "OpaqueSharding(devices: %s, memory_kind: %v)", absl::StrJoin(devices_, ",", [](std::string* out, const Device* device) { absl::StrAppend(out, device->ToString()); }), - memory_kind_.DebugString()); + memory_kind_); } std::unique_ptr ConcreteSharding::Create( DeviceList devices, MemoryKind memory_kind, Shape shape, std::vector shard_shapes) { CHECK_EQ(devices.size(), shard_shapes.size()); + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr( new ConcreteSharding(std::move(devices), memory_kind, std::move(shape), std::move(shard_shapes))); @@ -327,6 +343,7 @@ std::unique_ptr ConcreteSharding::Create( DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape, std::vector shard_dynamic_shapes) { CHECK_EQ(devices.size(), shard_dynamic_shapes.size()); + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr(new ConcreteSharding( std::move(devices), memory_kind, std::move(dynamic_shape), std::move(shard_dynamic_shapes))); @@ -454,7 +471,7 @@ std::string ConcreteSharding::DebugString() const { [this](const auto& shape, const auto& shard_shapes) { return absl::StrFormat( "ConcreteSharding(devices: %s, shape: %s, shard_shapes: %s, " - "memory_kind: %s)", + "memory_kind: %v)", absl::StrJoin(devices_, ",", [](std::string* out, const Device* device) { absl::StrAppend(out, device->ToString()); @@ -464,7 +481,7 @@ std::string ConcreteSharding::DebugString() const { [](std::string* out, const auto& shard_shape) { absl::StrAppend(out, shard_shape.DebugString()); }), - memory_kind_.DebugString()); + memory_kind_); }, shape_, shard_shapes_); } @@ -472,6 +489,7 @@ std::string ConcreteSharding::DebugString() const { std::unique_ptr ConcreteEvenSharding::Create( DeviceList devices, MemoryKind memory_kind, Shape shape, Shape shard_shape, bool is_fully_replicated) { + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr(new ConcreteEvenSharding( std::move(devices), memory_kind, std::move(shape), std::move(shard_shape), is_fully_replicated)); @@ -565,13 +583,12 @@ std::string ConcreteEvenSharding::DebugString() const { DCHECK(this); return absl::StrFormat( "ConcreteEvenSharding(devices: %s, shape: %s, shard_shape: %s, " - "memory_kind: %s)", + "memory_kind: %v)", absl::StrJoin(devices_, ",", [](std::string* out, const Device* device) { absl::StrAppend(out, device->ToString()); }), - shape_.DebugString(), shard_shape_.DebugString(), - memory_kind_.DebugString()); + shape_.DebugString(), shard_shape_.DebugString(), memory_kind_); } absl::StatusOr> @@ -586,6 +603,7 @@ ShardingParamSharding::Create(ShardingParam sharding_param, DeviceList devices, "%d", device_count, devices.size()); } + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr(new ShardingParamSharding( std::move(sharding_param), std::move(devices), memory_kind)); } @@ -595,7 +613,8 @@ ShardingParamSharding::ShardingParamSharding(ShardingParam sharding_param, DeviceList devices, MemoryKind memory_kind) : llvm::RTTIExtends( - devices, memory_kind, ComputeIsFullyReplicated(sharding_param)), + std::move(devices), memory_kind, + ComputeIsFullyReplicated(sharding_param)), sharding_param_(sharding_param) {} absl::StatusOr>>> @@ -710,13 +729,13 @@ absl::StatusOr> ShardingParamSharding::IndexDomains( std::string ShardingParamSharding::DebugString() const { DCHECK(this); return absl::StrFormat( - "ShardingParamSharding(%s, devices: %s, memory_kind: %s)", + "ShardingParamSharding(%s, devices: %s, memory_kind: %v)", sharding_param_.DebugString(), absl::StrJoin(devices_, ",", [](std::string* out, const Device* device) { absl::StrAppend(out, device->ToString()); }), - memory_kind_.DebugString()); + memory_kind_); } } // namespace ifrt diff --git a/third_party/xla/xla/python/ifrt/sharding.h b/third_party/xla/xla/python/ifrt/sharding.h index c7fbd258cee56d..91b8b8ad1b31cb 100644 --- a/third_party/xla/xla/python/ifrt/sharding.h +++ b/third_party/xla/xla/python/ifrt/sharding.h @@ -125,10 +125,8 @@ class Sharding : public llvm::RTTIExtends { static char ID; // NOLINT protected: - Sharding(DeviceList devices, MemoryKind memory_kind, bool is_fully_replicated) - : devices_(devices), - memory_kind_(memory_kind), - is_fully_replicated_(is_fully_replicated) {} + Sharding(DeviceList devices, MemoryKind memory_kind, + bool is_fully_replicated); DeviceList devices_; MemoryKind memory_kind_; @@ -189,6 +187,7 @@ class SingleDeviceSharding final class OpaqueSharding : public llvm::RTTIExtends { public: // Creates an opaque sharding. `Disassemble()` will fail. + // REQUIRES: !devices.empty() static std::unique_ptr Create(DeviceList devices, MemoryKind memory_kind); @@ -230,6 +229,7 @@ class ConcreteSharding : public llvm::RTTIExtends { public: // Creates a concrete sharding that may contain non-identical shard shapes. // REQUIRES: `devices`.size() == `shard_shapes`.size() + // REQUIRES: !devices.empty() static std::unique_ptr Create( DeviceList devices, MemoryKind memory_kind, Shape shape, std::vector shard_shapes); @@ -237,6 +237,7 @@ class ConcreteSharding : public llvm::RTTIExtends { // Creates a concrete sharding that may contain non-identical shard dynamic // shapes. // REQUIRES: `devices`.size() == `shard_dynamic_shapes`.size() + // REQUIRES: !devices.empty() static std::unique_ptr Create( DeviceList devices, MemoryKind memory_kind, DynamicShape dynamic_shape, std::vector shard_dynamic_shapes); @@ -321,6 +322,7 @@ class ConcreteEvenSharding // Creates a concrete even sharding. // TODO(hyeontaek): Remove the default value of `is_fully_replicated` once all // callers are updated to provide it explicitly. + // REQUIRES: !devices.empty() static std::unique_ptr Create( DeviceList devices, MemoryKind memory_kind, Shape shape, Shape shard_shape, bool is_fully_replicated = false); @@ -371,6 +373,7 @@ class ConcreteEvenSharding class ShardingParamSharding : public llvm::RTTIExtends { public: + // REQUIRES: !devices.empty() static absl::StatusOr> Create( ShardingParam sharding_param, DeviceList devices, MemoryKind memory_kind); diff --git a/third_party/xla/xla/python/ifrt/support/BUILD b/third_party/xla/xla/python/ifrt/support/BUILD index b3c1b01726b35c..83c46b025b2d68 100644 --- a/third_party/xla/xla/python/ifrt/support/BUILD +++ b/third_party/xla/xla/python/ifrt/support/BUILD @@ -37,13 +37,13 @@ xla_cc_test( "//xla/python/ifrt:mock", "//xla/python/ifrt:test_util", "//xla/python/ifrt/ir:sharding_param", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc b/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc index da1b26c6bf9555..7973ec0f4abe6a 100644 --- a/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc +++ b/third_party/xla/xla/python/ifrt/support/sharding_conversions_test.cc @@ -38,8 +38,8 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/test_util.h" #include "xla/shape.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/python/ifrt/test_util.h b/third_party/xla/xla/python/ifrt/test_util.h index 45e1258e8ec0e2..cd7ffc73824806 100644 --- a/third_party/xla/xla/python/ifrt/test_util.h +++ b/third_party/xla/xla/python/ifrt/test_util.h @@ -1,4 +1,4 @@ -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" /* Copyright 2022 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/third_party/xla/xla/python/ifrt/tuple_impl_test_lib.cc b/third_party/xla/xla/python/ifrt/tuple_impl_test_lib.cc index 643421076f3a5f..5a29e6e7587f4c 100644 --- a/third_party/xla/xla/python/ifrt/tuple_impl_test_lib.cc +++ b/third_party/xla/xla/python/ifrt/tuple_impl_test_lib.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/python/ifrt/test_util.h" #include "xla/python/ifrt/tuple.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/python/ifrt_proxy/client/BUILD b/third_party/xla/xla/python/ifrt_proxy/client/BUILD index 442b8f7bf85d51..8e947a4e68beac 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/client/BUILD @@ -103,7 +103,11 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status_to_from_proto", + "@local_tsl//tsl/profiler/lib:traceme", + "@local_tsl//tsl/profiler/lib:traceme_encode", + "@local_tsl//tsl/profiler/utils:xplane_schema", ] + if_google(["@com_google_absl//absl/types:source_location"]), ) diff --git a/third_party/xla/xla/python/ifrt_proxy/client/client.cc b/third_party/xla/xla/python/ifrt_proxy/client/client.cc index 15c293548b5759..09b4ccb847d7ee 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/client.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/client.cc @@ -82,7 +82,6 @@ absl::StatusOr> Client::Create( for (const auto& d : init_response.devices()) { absl::flat_hash_map pjrt_device_attributes; - AttributeMap::Map attributes; if (rpc_helper->version().protocol_version() <= 3) { for (const auto& [key, attr] : d.deprecated_attributes()) { TF_ASSIGN_OR_RETURN(xla::PjRtDeviceAttribute value, diff --git a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc index 1c334a8d7346a5..ff116a759e31a6 100644 --- a/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc +++ b/third_party/xla/xla/python/ifrt_proxy/client/rpc_helper.cc @@ -23,19 +23,23 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#if defined(PLATFORM_GOOGLE) -#include "absl/types/source_location.h" -#endif #include "xla/python/ifrt/future.h" #include "xla/python/ifrt_proxy/client/client_session.h" #include "xla/python/ifrt_proxy/common/ifrt_service.pb.h" +#include "tsl/platform/random.h" #include "tsl/platform/status_to_from_proto.h" +#include "tsl/profiler/lib/traceme.h" +#include "tsl/profiler/lib/traceme_encode.h" +#include "tsl/profiler/utils/xplane_schema.h" namespace xla { namespace ifrt { namespace proxy { +using ::tsl::profiler::XFlow; + // DoRpc is a templated function that implements the logic of all RPC-wrapping // functions of `RpcHelper`, such as `RpcHelper::MakeArrayFromHostBuffer()`. template @@ -44,14 +48,28 @@ Future> DoRpc(ClientSession* session, void (IfrtRequest::*set_req)(Req*), Resp* (IfrtResponse::*get_resp)(), bool (IfrtResponse::*has_resp)() const, - std::unique_ptr req) { + std::unique_ptr req, + absl::string_view profiling_send_name, + absl::string_view profiling_recv_name) { auto ifrt_req = std::make_unique(); *ifrt_req->mutable_request_metadata() = metadata; (ifrt_req.get()->*set_req)(req.release()); + const uint64_t xflow_id = tsl::random::New64() >> 8; // XFlow IDs are 56 bits + tsl::profiler::TraceMe traceme([xflow_id, profiling_send_name]() { + const XFlow flow(xflow_id, XFlow::FlowDirection::kFlowOut); + return tsl::profiler::TraceMeEncode(profiling_send_name, + {{"flow", flow.ToStatValue()}}); + }); + auto promise = Future>::CreatePromise(); - auto on_ready = [promise, has_resp, get_resp]( + auto on_ready = [promise, has_resp, get_resp, xflow_id, profiling_recv_name]( absl::StatusOr> r) mutable { + tsl::profiler::TraceMe traceme([xflow_id, profiling_recv_name]() { + const XFlow flow(xflow_id, XFlow::FlowDirection::kFlowIn); + return tsl::profiler::TraceMeEncode(profiling_recv_name, + {{"flow", flow.ToStatValue()}}); + }); if (!r.ok()) { LOG_EVERY_N_SEC(ERROR, 10) << "Connection to IFRT proxy server was terminated: " << r.status(); @@ -127,13 +145,14 @@ void RpcHelper::Disconnect() { // TODO(b/266635130): Remove this preprocessor macro. Preprocessor macros // go against the style guide, but are convenient as we are introducing more // RPCs and are making changes to the exact signature of the DoRpc function. -#define RPC(METHOD, PROPERTY) \ - RpcHelper::ResponseFuture RpcHelper::METHOD( \ - std::unique_ptr req) { \ - return DoRpc(session_.get(), ManufactureRequestMetadata(), \ - &IfrtRequest::set_allocated_##PROPERTY##_request, \ - &IfrtResponse::mutable_##PROPERTY##_response, \ - &IfrtResponse::has_##PROPERTY##_response, std::move(req)); \ +#define RPC(METHOD, PROPERTY) \ + RpcHelper::ResponseFuture RpcHelper::METHOD( \ + std::unique_ptr req) { \ + return DoRpc(session_.get(), ManufactureRequestMetadata(), \ + &IfrtRequest::set_allocated_##PROPERTY##_request, \ + &IfrtResponse::mutable_##PROPERTY##_response, \ + &IfrtResponse::has_##PROPERTY##_response, std::move(req), \ + "" #PROPERTY "_send", "" #PROPERTY "_recv"); \ } RPC(Init, init); diff --git a/third_party/xla/xla/python/ifrt_proxy/common/types.cc b/third_party/xla/xla/python/ifrt_proxy/common/types.cc index 9d222a453c58ee..db981531c24c27 100644 --- a/third_party/xla/xla/python/ifrt_proxy/common/types.cc +++ b/third_party/xla/xla/python/ifrt_proxy/common/types.cc @@ -83,7 +83,6 @@ proto::ArrayCopySemantics ToArrayCopySemanticsProto(ArrayCopySemantics s) { absl::StatusOr FromArrayCopySemanticsProto( proto::ArrayCopySemantics s) { - MakeArrayFromHostBufferRequest req; switch (s) { case proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY: return ArrayCopySemantics::kAlwaysCopy; diff --git a/third_party/xla/xla/python/ifrt_proxy/server/BUILD b/third_party/xla/xla/python/ifrt_proxy/server/BUILD index ee8d2519c32268..ed842aef61fa3c 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/BUILD +++ b/third_party/xla/xla/python/ifrt_proxy/server/BUILD @@ -183,6 +183,7 @@ ifrt_proxy_cc_test( "//xla/python/pjrt_ifrt:xla_ifrt", "//xla/service:computation_placer_hdr", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", @@ -195,7 +196,6 @@ ifrt_proxy_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", diff --git a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc index 79da57ce2c3bf3..df3e72d53da438 100644 --- a/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc +++ b/third_party/xla/xla/python/ifrt_proxy/server/ifrt_backend_test.cc @@ -67,8 +67,8 @@ #include "xla/status_macros.h" #include "xla/test.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep diff --git a/third_party/xla/xla/python/pjit.cc b/third_party/xla/xla/python/pjit.cc index 51a7bb7ff976ad..6bbe898eba48fd 100644 --- a/third_party/xla/xla/python/pjit.cc +++ b/third_party/xla/xla/python/pjit.cc @@ -46,8 +46,10 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "nanobind/stl/vector.h" // IWYU pragma: keep +#include "xla/layout.h" #include "xla/pjrt/exceptions.h" #include "xla/pjrt/lru_cache.h" +#include "xla/pjrt/pjrt_layout.h" #include "xla/python/ifrt/array.h" #include "xla/python/ifrt/device.h" #include "xla/python/ifrt/memory.h" @@ -90,6 +92,7 @@ struct PjitCacheEntry { // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args` // in PjitFunction::Call before calling into compiled computation. std::vector kept_var_bitvec; + std::vector in_device_local_layouts; // Ensures a single thread performs the compilation for a given executable. // @@ -351,11 +354,12 @@ PjitFunction::PjitFunction( PjitFunction::~PjitFunction() { GetGlobalPjitFunctionStore().Erase(this); } void CallShardArgFallback( - nb::handle arg, nb::handle sharding, const nb::callable& fallback, + nb::handle arg, nb::handle sharding, nb::handle layout, + const nb::callable& fallback, std::vector>& num_args_arrays, std::vector& keep_alive_objects) { tsl::profiler::TraceMe traceme("cpp_pjit_shard_arg_fallback"); - auto py_array_or_bufs = fallback(arg, sharding); + auto py_array_or_bufs = fallback(arg, sharding, layout); auto py_array = nb::cast(py_array_or_bufs); num_args_arrays.push_back(tsl::FormRef(py_array.ifrt_array())); keep_alive_objects.push_back(std::move(py_array_or_bufs)); @@ -368,6 +372,7 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, absl::Span flat_dynamic_args, bool enable_x64, const std::vector& kept_args, const std::vector& in_shardings, + const std::vector& in_device_local_layouts, const nb::callable& shard_arg_fallback, std::vector& keep_alive_objects) { const auto& addressable_devices = @@ -401,11 +406,13 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, ++dce_i; const nb::object& arg = flat_dynamic_args[i]; + const nb::object& in_device_local_layout = + in_device_local_layouts[dce_index]; auto transfer_guard_formatter = [] { return std::string(""); }; if (arg.type().ptr() != xla::PyArray::type().ptr()) { - if (data_device != nullptr) { + if (data_device != nullptr && in_device_local_layout.is_none()) { TF_RETURN_IF_ERROR( jax::ApplyTransferGuardToHostToDevice(transfer_guard_formatter)); TF_ASSIGN_OR_RETURN( @@ -426,8 +433,8 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, continue; } else { CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, - keep_alive_objects); + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); continue; } } @@ -442,17 +449,31 @@ PrepareIfrtInputs(const xla::PyLoadedExecutable& executable, DCHECK(py_array.committed() || (!py_array.committed() && sharding_num_devices == 1)); + if (!in_device_local_layout.is_none()) { + TF_ASSIGN_OR_RETURN(auto arr_layout, py_array.ifrt_array()->layout()); + xla::Layout in_xc_layout = nb::cast( + in_device_local_layout.attr("_to_xla_layout")(py_array.dtype())); + if (in_xc_layout != GetXlaLayoutUnsafe(arr_layout)) { + CallShardArgFallback(arg.ptr(), in_shardings[dce_index], + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); + continue; + } + } + if (sharding.type().ptr() == jax::PmapSharding::type().ptr()) { + CHECK(in_device_local_layout.is_none()); CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, - keep_alive_objects); + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); continue; } if (py_array.num_shards() != addressable_devices.size()) { + CHECK(in_device_local_layout.is_none()); CallShardArgFallback(arg.ptr(), in_shardings[dce_index], - shard_arg_fallback, num_args_arrays, - keep_alive_objects); + in_device_local_layout, shard_arg_fallback, + num_args_arrays, keep_alive_objects); continue; } @@ -659,7 +680,8 @@ absl::StatusOr PjitFunction::Call(nb::handle callable, auto num_args_arrays = PrepareIfrtInputs( *cache_entry->executable, flat_dynamic_args, call_signature.jax_enable_x64, cache_entry->kept_var_bitvec, - cache_entry->in_shardings, shard_arg_fallback_, keep_alive_objects); + cache_entry->in_shardings, cache_entry->in_device_local_layouts, + shard_arg_fallback_, keep_alive_objects); if (!num_args_arrays.ok()) { VLOG(2) << "Failed to prepare IFRT inputs: " << num_args_arrays.status(); @@ -821,6 +843,13 @@ void PjitFunction::PopulateCacheEntry(PjitCacheEntry& cache_entry, for (nb::handle k : kept_var_bitvec) { cache_entry.kept_var_bitvec.push_back(nb::cast(k)); } + + nb::sequence in_device_local_layouts = + fastpath_data.attr("in_device_local_layouts"); + cache_entry.in_device_local_layouts.reserve(nb::len(in_device_local_layouts)); + for (nb::handle dll : in_device_local_layouts) { + cache_entry.in_device_local_layouts.push_back(nb::borrow(dll)); + } } // Helper function used by the tp_clear GC method. diff --git a/third_party/xla/xla/python/pjrt_ifrt/BUILD b/third_party/xla/xla/python/pjrt_ifrt/BUILD index 158bff8d695684..e958a11683b930 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/BUILD +++ b/third_party/xla/xla/python/pjrt_ifrt/BUILD @@ -140,11 +140,11 @@ cc_library( "//xla/python/ifrt", "//xla/python/ifrt:test_util", "//xla/python/ifrt/hlo:hlo_program", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -337,6 +337,7 @@ xla_cc_test( "//xla/python/ifrt", "//xla/python/ifrt:test_util", "//xla/tsl/concurrency:ref_count", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -345,7 +346,6 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc index a0d21a4cf11307..108e5a9f982760 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/basic_string_array_test.cc @@ -44,7 +44,7 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/test_util.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc index 73324e6b1c8c91..751b00c9b37620 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_array.cc @@ -92,9 +92,8 @@ absl::Status ValidateArrayCreationInput( if (canonicalized_sharding_memory_kind != buffer_memory_kind) { return InvalidArgument( "PjRtBuffer's memory kind does not match sharding's memory kind. Got " - "PjRtBuffer's memory kind: %s vs shardings's memory kind: %s", - buffer_memory_kind.DebugString(), - canonicalized_sharding_memory_kind.DebugString()); + "PjRtBuffer's memory kind: %v vs shardings's memory kind: %v", + buffer_memory_kind, canonicalized_sharding_memory_kind); } } return absl::OkStatus(); @@ -116,8 +115,8 @@ absl::StatusOr GetMemoryKindFromPjRtBuffers( pjrt_buffer->device())) { return InvalidArgument( "Memory kind mismatch between PjRtBuffers. Got one buffer with " - "memory kind: %s and another with memory_kind: %s", - first_memory_kind.DebugString(), memory_kind.DebugString()); + "memory kind: %v and another with memory_kind: %v", + first_memory_kind, memory_kind); } } return first_memory_kind; @@ -440,11 +439,10 @@ absl::StatusOr GetMemorySpaceFromMemoryKind( } if (memory == nullptr) { return InvalidArgument( - "Invalid memory kind: %s; available memory kinds: %s", - memory_kind.DebugString(), + "Invalid memory kind: %v; available memory kinds: %s", memory_kind, absl::StrJoin(device->Memories(), ", ", [](std::string* out, Memory* m) { - absl::StrAppend(out, m->Kind().DebugString()); + absl::StrAppend(out, m->Kind()); })); } return memory; diff --git a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc index d77d1c0bf69650..42ffc9aca0353d 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/pjrt_executable.cc @@ -651,8 +651,8 @@ PjRtLoadedExecutable::Execute(absl::Span> args, memory_kind, pjrt_outputs[j][i]->device())) { return FailedPrecondition( "Memory kind mismatch between PjRtBuffers. Got one buffer with " - "memory kind '%s' and another with memory_kind '%s'", - first_memory_kind.DebugString(), memory_kind.DebugString()); + "memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, memory_kind); } } buffers.push_back( diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc index 04da5007591f4f..4a3bc4197c766c 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_executable_impl_test_lib.cc @@ -40,7 +40,7 @@ limitations under the License. #include "xla/python/ifrt/sharding.h" #include "xla/python/ifrt/test_util.h" #include "xla/python/pjrt_ifrt/xla_compiler.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc index 6f79e56502eb77..62a07724d1cc42 100644 --- a/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc +++ b/third_party/xla/xla/python/pjrt_ifrt/xla_sharding.cc @@ -97,11 +97,20 @@ std::vector IndexDomainsSlowPath( return result; } +// Returns a canonicalized memory kind for the given devices. +// REQUIRES: !devices.empty() +MemoryKind CanonicalizeMemoryKindWithDevices(const MemoryKind& memory_kind, + const DeviceList& devices) { + CHECK(!devices.empty()); + return CanonicalizeMemoryKind(memory_kind, devices.front()); +} + } // namespace std::unique_ptr HloSharding::Create( DeviceList devices, MemoryKind memory_kind, xla::HloSharding xla_hlo_sharding) { + memory_kind = CanonicalizeMemoryKindWithDevices(memory_kind, devices); return std::unique_ptr(new HloSharding( std::move(devices), memory_kind, std::move(xla_hlo_sharding))); } @@ -340,9 +349,8 @@ absl::StatusOr> HloSharding::IndexDomains( } std::string HloSharding::DebugString() const { - return absl::StrFormat("HloSharding(memory_kind: %s, hlo_sharding: %s)", - memory_kind_.DebugString(), - xla_hlo_sharding_.ToString()); + return absl::StrFormat("HloSharding(memory_kind: %v, hlo_sharding: %s)", + memory_kind_, xla_hlo_sharding_.ToString()); } std::vector TEST_HloShardingIndexDomainsSlowPath( diff --git a/third_party/xla/xla/python/profiler/internal/python_hooks.h b/third_party/xla/xla/python/profiler/internal/python_hooks.h index a9b502ef3b2e46..29e6b83dac1962 100644 --- a/third_party/xla/xla/python/profiler/internal/python_hooks.h +++ b/third_party/xla/xla/python/profiler/internal/python_hooks.h @@ -77,7 +77,7 @@ struct PythonTraceEntry { Py_XDECREF(m_module); } - PythonTraceEntry(PythonTraceEntry&& other) { + PythonTraceEntry(PythonTraceEntry&& other) noexcept { start_time_ns = other.start_time_ns; end_time_ns = other.end_time_ns; co_firstlineno = other.co_firstlineno; diff --git a/third_party/xla/xla/python/py_array.cc b/third_party/xla/xla/python/py_array.cc index 8d00206e38ee30..b350116d53b043 100644 --- a/third_party/xla/xla/python/py_array.cc +++ b/third_party/xla/xla/python/py_array.cc @@ -74,7 +74,6 @@ limitations under the License. #include "xla/python/pjrt_ifrt/pjrt_array.h" #include "xla/python/pjrt_ifrt/pjrt_client.h" #include "xla/python/pjrt_ifrt/pjrt_device.h" -#include "xla/python/pjrt_ifrt/xla_sharding.h" #include "xla/python/py_client.h" #include "xla/python/py_device.h" #include "xla/python/py_values.h" @@ -88,11 +87,6 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" -// TODO(b/324133505): remove this GOOGLE_CUDA block after JAX OSS migrates -// to cuda plugin. -#if GOOGLE_CUDA -#include "xla/stream_executor/cuda/cuda_driver.h" -#endif #include "xla/tsl/concurrency/ref_count.h" #include "xla/util.h" #include "tsl/platform/errors.h" @@ -184,9 +178,8 @@ tsl::RCReference CreateIfRtArrayFromSingleDeviceShardedPyArrays( throw nb::value_error( absl::StrFormat( "Memory kind mismatch between PjRtBuffers. Got one buffer with " - "memory kind '%s' and another with memory_kind '%s'", - first_memory_kind.DebugString(), - ifrt_arrays.back()->sharding().memory_kind().DebugString()) + "memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_arrays.back()->sharding().memory_kind()) .c_str()); } } @@ -639,10 +632,8 @@ absl::Status PyArray::set_arrays(nb::object obj) { throw nb::value_error( absl::StrFormat( "Memory kind mismatch between single-device arrays. Got one " - "array " - "with memory kind '%s' and another with memory_kind '%s'", - first_memory_kind.DebugString(), - ifrt_array->sharding().memory_kind().DebugString()) + "array with memory kind '%v' and another with memory_kind '%v'", + first_memory_kind, ifrt_array->sharding().memory_kind()) .c_str()); } } @@ -866,19 +857,6 @@ absl::StatusOr CudaArrayInterfaceToBuffer( PrimitiveType element_type, DtypeToPrimitiveType(nb_dtype::from_args(cai["typestr"]))); - // TODO(b/324133505): remove this GOOGLE_CUDA block after JAX OSS migrates - // to cuda plugin. -#ifdef GOOGLE_CUDA - if (!device_id.has_value()) { - // cannot determine device_id/stream when device pointer is NULL. - device_id.emplace( - (data_value == 0 - ? 0 - : stream_executor::gpu::CreatedContexts::GetDeviceOrdinal( - data_ptr))); - } -#endif // GOOGLE_CUDA - if (!device_id.has_value()) { throw XlaRuntimeError( "This operation requires CUDA support from jaxlib or jax cuda plugin."); diff --git a/third_party/xla/xla/python/py_array.h b/third_party/xla/xla/python/py_array.h index cb76a5fcb90272..015c61c391146a 100644 --- a/third_party/xla/xla/python/py_array.h +++ b/third_party/xla/xla/python/py_array.h @@ -27,6 +27,7 @@ limitations under the License. #include // placeholder for index annotation headers +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" diff --git a/third_party/xla/xla/python/py_client.cc b/third_party/xla/xla/python/py_client.cc index 3b4ebcd9901d09..0e36a346f67e39 100644 --- a/third_party/xla/xla/python/py_client.cc +++ b/third_party/xla/xla/python/py_client.cc @@ -34,12 +34,13 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/Pass/PassManager.h" #include "nanobind/nanobind.h" #include "nanobind/stl/optional.h" // IWYU pragma: keep #include "nanobind/stl/pair.h" // IWYU pragma: keep @@ -67,6 +68,7 @@ limitations under the License. #include "xla/python/ifrt/hlo/hlo_program.h" #include "xla/python/ifrt/host_callback.h" #include "xla/python/ifrt/memory.h" +#include "xla/python/ifrt/program.h" #include "xla/python/nb_absl_span.h" // IWYU pragma: keep #include "xla/python/nb_class_ptr.h" #include "xla/python/nb_numpy.h" @@ -86,9 +88,13 @@ limitations under the License. #include "xla/python/types.h" #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" // IWYU pragma: keep +#include "xla/service/spmd/shardy/constants.h" +#include "xla/service/spmd/shardy/sdy_round_trip/pipelines.h" +#include "xla/service/spmd/shardy/utils.h" #include "xla/shape.h" #include "xla/status_macros.h" #include "xla/tsl/concurrency/ref_count.h" +#include "xla/tsl/framework/mlir/status_scoped_diagnostic_handler.h" #include "xla/util.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" @@ -437,6 +443,22 @@ PyClient::CompileIfrtProgram( mlir::MLIRContext context; TF_ASSIGN_OR_RETURN(mlir::OwningOpRef module, ParseMlirModuleString(mlir_module, context)); + if (options.executable_build_options.use_shardy_partitioner()) { + mlir::PassManager pm(&context); + // Since Shardy is inside the middle of the XLA pipeline, after converting + // down to HLO, we need to run the Shardy export pipeline to preserve the + // SDY ops and sharding attributes for when we come back from HLO to MLIR + // when Shardy propagation is run. + xla::sdy::addSdyRoundTripExportPipeline(pm); + // TODO(bartchr): remove setting `kPythonIntegrationComplete` in follow-up + // now that both JAX and PartIR are integrated with Shardy. + xla::sdy::addFrontendAttribute(*module, + xla::sdy::kPythonIntegrationComplete, + mlir::StringAttr::get(&context, "t")); + TF_RETURN_IF_ERROR( + tsl::StatusScopedDiagnosticHandler(&context).consumeStatus( + pm.run(*module))); + } return CompileIfrtProgram( client, std::make_unique(module.get()), MakeIfrtCompileOptions(std::move(options), std::move(host_callbacks))); diff --git a/third_party/xla/xla/python/py_compile_only_client.cc b/third_party/xla/xla/python/py_compile_only_client.cc index 9d9db9afccffec..6f5aff61938b88 100644 --- a/third_party/xla/xla/python/py_compile_only_client.cc +++ b/third_party/xla/xla/python/py_compile_only_client.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include #include +#include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/Support/Casting.h" @@ -79,6 +81,40 @@ namespace xla { namespace { +class CompileOnlyMemory + : public llvm::RTTIExtends { + public: + explicit CompileOnlyMemory( + int id, const PjRtMemorySpaceDescription* memory_description, + ifrt::Device* device) + : id_(id), + kind_(memory_description->kind()), + debug_string_(absl::StrFormat("CompileOnlyMemory(id=%d, kind=%s)", id, + memory_description->kind())), + device_(device) {} + + ifrt::MemoryId Id() const override { return ifrt::MemoryId(id_); } + + const ifrt::MemoryKind& Kind() const override { return kind_; } + + absl::string_view ToString() const override { return debug_string_; } + absl::string_view DebugString() const override { return debug_string_; } + + absl::Span Devices() const override { + return absl::Span{&device_, 1}; + } + + static char ID; // NOLINT + + private: + int id_; + ifrt::MemoryKind kind_; + std::string debug_string_; + ifrt::Device* device_; +}; + +[[maybe_unused]] char CompileOnlyMemory::ID = 0; + class CompileOnlyDevice : public llvm::RTTIExtends { public: @@ -108,16 +144,31 @@ class CompileOnlyDevice return description_->DebugString(); } - absl::Span Memories() const override { return {}; } + absl::Span Memories() const override { + return unowned_memories_; + } absl::StatusOr DefaultMemory() const override { + if (default_memory_) { + return default_memory_; + } return Unimplemented("DefaultMemory is not supported"); } const ifrt::AttributeMap& Attributes() const override { return attributes_; } + void AttachMemory(std::unique_ptr memory) { + unowned_memories_.push_back(memory.get()); + owned_memories_.push_back(std::move(memory)); + } + + void SetDefaultMemory(ifrt::Memory* memory) { default_memory_ = memory; } + private: const PjRtDeviceDescription* description_; ifrt::AttributeMap attributes_; + ifrt::Memory* default_memory_ = nullptr; + std::vector unowned_memories_; + std::vector> owned_memories_; }; class InvalidIfrtCompiler final @@ -153,10 +204,24 @@ class CompileOnlyIfRtClient final : topology_(std::move(topology)), descriptions_(topology_->DeviceDescriptions()), attributes_(ifrt::AttributeMap::Map()) { + int offset = 0; for (auto& description : descriptions_) { owned_devices_.push_back( std::make_unique(description.get())); - devices_.push_back(owned_devices_.back().get()); + auto* device = owned_devices_.back().get(); + devices_.push_back(device); + if (description->process_index() == process_index()) { + auto default_memory = description->default_memory_space(); + for (auto* memory_description : description->memory_spaces()) { + auto memory = std::make_unique( + offset, memory_description, device); + if (default_memory.ok() && memory_description == *default_memory) { + device->SetDefaultMemory(memory.get()); + } + device->AttachMemory(std::move(memory)); + ++offset; + } + } } } diff --git a/third_party/xla/xla/python/py_values.h b/third_party/xla/xla/python/py_values.h index 9733a42c3e2ec1..51bfdb919cf487 100644 --- a/third_party/xla/xla/python/py_values.h +++ b/third_party/xla/xla/python/py_values.h @@ -49,8 +49,8 @@ struct DevicePutResult { // dangerous due to `owning_pybuffer`. DevicePutResult(const DevicePutResult&) = delete; DevicePutResult& operator=(const DevicePutResult&) = delete; - DevicePutResult(DevicePutResult&&) = default; - DevicePutResult& operator=(DevicePutResult&&) = default; + DevicePutResult(DevicePutResult&&) noexcept = default; + DevicePutResult& operator=(DevicePutResult&&) noexcept = default; // Points to the on-device array. Not owned. tsl::RCReference ifrt_array; diff --git a/third_party/xla/xla/python/python_ref_manager.h b/third_party/xla/xla/python/python_ref_manager.h index 815e80e03f2455..4f1d8212fe6ea4 100644 --- a/third_party/xla/xla/python/python_ref_manager.h +++ b/third_party/xla/xla/python/python_ref_manager.h @@ -57,7 +57,7 @@ class PythonRefManager { ManagedPyObjects(const ManagedPyObjects& other) = delete; ManagedPyObjects(ManagedPyObjects&& other) = default; ManagedPyObjects& operator=(const ManagedPyObjects& other) = delete; - ManagedPyObjects& operator=(ManagedPyObjects&& other) = default; + ManagedPyObjects& operator=(ManagedPyObjects&& other) noexcept = default; private: PythonRefManager* manager_ = nullptr; diff --git a/third_party/xla/xla/python/pytree.cc b/third_party/xla/xla/python/pytree.cc index 68a483cd51e97f..65bfb3fe5305e4 100644 --- a/third_party/xla/xla/python/pytree.cc +++ b/third_party/xla/xla/python/pytree.cc @@ -1249,7 +1249,7 @@ nb_class_ptr PyTreeDef::MakeFromNodeDataAndChildren( nb::cast(nb::repr(node_data->first)))); } node.kind = registration->kind; - if (node.kind == PyTreeKind::kCustom) { + if (node.kind == PyTreeKind::kCustom || node.kind == PyTreeKind::kDataclass) { node.custom = registration; node.node_data = node_data->second; } else if (node.kind == PyTreeKind::kNamedTuple) { diff --git a/third_party/xla/xla/python/pytree_test.py b/third_party/xla/xla/python/pytree_test.py index 4125d7a28257a3..922a4d78fd6b56 100644 --- a/third_party/xla/xla/python/pytree_test.py +++ b/third_party/xla/xla/python/pytree_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import collections +import dataclasses from absl.testing import absltest @@ -44,6 +45,15 @@ def from_iterable(state, values): registry.register_node(ExampleType2, ExampleType2.to_iterable, from_iterable) +@dataclasses.dataclass +class Custom: + a: int + b: str + + +registry.register_dataclass_node(Custom, ["a"], ["b"]) + + class PyTreeTest(absltest.TestCase): def roundtrip(self, example): @@ -92,6 +102,15 @@ def testCompose(self): y = registry.flatten((0, 0))[1] self.assertEqual((x.compose(y)).num_leaves, 2) + def testDataclassMakeFromNodeData(self): + c = Custom(1, "a") + c_leafs, c_tree = registry.flatten(c) + c_tree2 = c_tree.make_from_node_data_and_children( + registry, c_tree.node_data(), c_tree.children() + ) + self.assertEqual(c_tree2.unflatten(c_leafs), c) + self.assertEqual(str(c_tree2), str(c_tree)) + if __name__ == "__main__": absltest.main() diff --git a/third_party/xla/xla/python/refine_polymorphic_shapes.cc b/third_party/xla/xla/python/refine_polymorphic_shapes.cc index cbd42e928ef576..e358b2e549a8e2 100644 --- a/third_party/xla/xla/python/refine_polymorphic_shapes.cc +++ b/third_party/xla/xla/python/refine_polymorphic_shapes.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/Regex.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Bytecode/BytecodeWriter.h" diff --git a/third_party/xla/xla/python/sharding.cc b/third_party/xla/xla/python/sharding.cc index acbe324ac75f47..e995df9285d8b9 100644 --- a/third_party/xla/xla/python/sharding.cc +++ b/third_party/xla/xla/python/sharding.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -30,6 +31,7 @@ limitations under the License. #include "nanobind/stl/string.h" // IWYU pragma: keep #include "nanobind/stl/string_view.h" // IWYU pragma: keep #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/pjrt/status_casters.h" #include "xla/python/ifrt/device.h" #include "xla/python/nb_class_ptr.h" #include "xla/python/nb_helpers.h" @@ -176,8 +178,7 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, nb::object memory_kind, nb::object parsed_pspec, nb::object manual_axes) : Sharding(/*num_devices=*/[&mesh]() { - xla::nb_numpy_ndarray devices = mesh.attr("devices"); - return devices.size(); + return nb::cast(mesh.attr("size")); }()), mesh_(std::move(mesh)), spec_(std::move(spec)), @@ -185,10 +186,18 @@ NamedSharding::NamedSharding(nb::object mesh, nb::object spec, parsed_pspec_(std::move(parsed_pspec)), manual_axes_(std::move(manual_axes)) { nb::object idl = nb::object(mesh_.attr("_internal_device_list")); - internal_device_list_ = nb::cast>( - nb::object(mesh_.attr("_internal_device_list"))); - memory_kind_ = - CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_); + if (idl.is_none()) { + internal_device_list_ = std::nullopt; + } else { + internal_device_list_ = nb::cast>( + nb::object(mesh_.attr("_internal_device_list"))); + } + if (internal_device_list_) { + memory_kind_ = + CheckAndCanonicalizeMemoryKind(memory_kind_, *internal_device_list_); + } else { + memory_kind_ = nb::none(); + } nb::module_ si = nb::module_::import_("jax._src.sharding_impls"); parsed_pspec_ = @@ -265,8 +274,9 @@ void RegisterSharding(nb::module_& m) { .def_prop_ro("_manual_axes", &NamedSharding::manual_axes) .def_prop_rw("_parsed_pspec", &NamedSharding::parsed_pspec, &NamedSharding::set_parsed_pspec) - .def_prop_ro("_internal_device_list", - &NamedSharding::internal_device_list); + .def_prop_ro("_internal_device_list", [](const NamedSharding& s) { + return xla::ValueOrThrow(s.internal_device_list()); + }); nb::class_(m, "SingleDeviceSharding", nb::dynamic_attr()) diff --git a/third_party/xla/xla/python/sharding.h b/third_party/xla/xla/python/sharding.h index d3b1211619cd7d..1e28b7aecff6b8 100644 --- a/third_party/xla/xla/python/sharding.h +++ b/third_party/xla/xla/python/sharding.h @@ -22,6 +22,7 @@ limitations under the License. // placeholder for index annotation headers #include "absl/hash/hash.h" +#include "absl/status/statusor.h" #include "nanobind/nanobind.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/pjrt/status_casters.h" @@ -86,8 +87,13 @@ class NamedSharding : public Sharding { return type; } - xla::nb_class_ptr internal_device_list() const { - return internal_device_list_; + absl::StatusOr> internal_device_list() const { + if (internal_device_list_) { + return *internal_device_list_; + } + return xla::InvalidArgument( + "internal_device_list is not implemented for " + "`jax.sharding.AbstractMesh`"); } private: @@ -96,7 +102,7 @@ class NamedSharding : public Sharding { nanobind::object memory_kind_; nanobind::object parsed_pspec_; nanobind::object manual_axes_; - xla::nb_class_ptr internal_device_list_; + std::optional> internal_device_list_; }; class SingleDeviceSharding : public Sharding { diff --git a/third_party/xla/xla/python/tools/BUILD b/third_party/xla/xla/python/tools/BUILD index 6d57e560d70cf6..cc0c5e0c189713 100644 --- a/third_party/xla/xla/python/tools/BUILD +++ b/third_party/xla/xla/python/tools/BUILD @@ -86,7 +86,7 @@ py_strict_test( ":types", "@absl_py//absl/testing:absltest", "@absl_py//absl/testing:parameterized", - #internal proto upb dep + # copybara:uncomment "//third_party/py/google/protobuf:use_fast_cpp_protos", "//third_party/py/numpy", "//xla:xla_data_proto_py", ], diff --git a/third_party/xla/xla/python/traceback.cc b/third_party/xla/xla/python/traceback.cc index 5a86924002f2b0..19e4f94d4f8d9b 100644 --- a/third_party/xla/xla/python/traceback.cc +++ b/third_party/xla/xla/python/traceback.cc @@ -99,7 +99,8 @@ Traceback::~Traceback() { } } -Traceback::Traceback(Traceback&& other) : frames_(std::move(other.frames_)) { +Traceback::Traceback(Traceback&& other) noexcept + : frames_(std::move(other.frames_)) { // absl::InlinedVector does not always clear itself if moved. Since we rely on // its empty() method to destroy Traceback differently, we explicitly clear // here. @@ -222,6 +223,7 @@ PyType_Slot traceback_slots_[] = { void BuildTracebackSubmodule(nb::module_& m) { nb::class_(m, "Frame") + .def(nb::init()) .def_ro("file_name", &Traceback::Frame::file_name) .def_ro("function_name", &Traceback::Frame::function_name) .def_ro("function_start_line", &Traceback::Frame::function_start_line) @@ -271,6 +273,33 @@ void BuildTracebackSubmodule(nb::module_& m) { traceback.def("__str__", &Traceback::ToString); traceback.def("as_python_traceback", &Traceback::AsPythonTraceback); + traceback.def_static( + "traceback_from_frames", + [](std::vector frames) { + nb::object traceback = nb::none(); + nb::dict globals; + nb::handle traceback_type( + reinterpret_cast(&PyTraceBack_Type)); + for (const Traceback::Frame& frame : frames) { + PyCodeObject* py_code = + PyCode_NewEmpty(frame.file_name.c_str(), + frame.function_name.c_str(), frame.line_num); + PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), py_code, + globals.ptr(), /*locals=*/ + nullptr); + Py_DECREF(py_code); + traceback = traceback_type( + /*tb_next=*/std::move(traceback), + /*tb_frame=*/ + nb::steal(reinterpret_cast(py_frame)), + /*tb_lasti=*/0, + /*tb_lineno=*/ + frame.line_num); + } + return traceback; + }, + "Creates a traceback from a list of frames."); + traceback.def_static( "code_addr2line", [](nb::handle code, int lasti) { diff --git a/third_party/xla/xla/python/traceback.h b/third_party/xla/xla/python/traceback.h index c93860b7c3d9ce..da8036272ee3b4 100644 --- a/third_party/xla/xla/python/traceback.h +++ b/third_party/xla/xla/python/traceback.h @@ -48,7 +48,7 @@ class Traceback { ~Traceback(); Traceback(const Traceback&) = delete; - Traceback(Traceback&& other); + Traceback(Traceback&& other) noexcept; Traceback& operator=(const Traceback&) = delete; Traceback& operator=(Traceback&&) = delete; diff --git a/third_party/xla/xla/python/xla.cc b/third_party/xla/xla/python/xla.cc index 19a9d94e1d1b7b..2136e981507f10 100644 --- a/third_party/xla/xla/python/xla.cc +++ b/third_party/xla/xla/python/xla.cc @@ -61,17 +61,18 @@ limitations under the License. #include "xla/python/py_client.h" #include "xla/python/py_program.h" #include "xla/service/cpu/collectives_interface.h" -#include "xla/tsl/python/lib/core/numpy.h" //NOLINT -#ifdef XLA_PYTHON_ENABLE_GPU -#include "xla/python/gpu_support.h" -#endif // XLA_PYTHON_ENABLE_GPU +#include "xla/tsl/python/lib/core/numpy.h" // NOLINT -#ifdef __linux__ +#if defined(__linux__) #include "gloo/transport/tcp/attr.h" #include "gloo/transport/tcp/device.h" #include "xla/pjrt/cpu/gloo_collectives.h" #include "xla/pjrt/cpu/gloo_kv_store.h" -#endif // __linux__ +#elif defined(__APPLE__) +#include "gloo/transport/uv/device.h" +#include "xla/pjrt/cpu/gloo_collectives.h" // NOLINT +#include "xla/pjrt/cpu/gloo_kv_store.h" // NOLINT +#endif // defined(__linux__) #if !defined(_WIN32) && !defined(PLATFORM_GOOGLE) #include "xla/pjrt/cpu/mpi_collectives.h" @@ -257,7 +258,7 @@ NB_MODULE(xla_extension, m_nb) { std::optional hostname, std::optional interface) -> std::shared_ptr { -#ifdef __linux__ +#if defined(__linux__) std::shared_ptr kv_store = nullptr; if (distributed_client != nullptr) { kv_store = GetDistributedKeyValueStore(distributed_client, @@ -274,10 +275,27 @@ NB_MODULE(xla_extension, m_nb) { auto tcp_device = gloo::transport::tcp::CreateDevice(tcp_attrs); return std::make_shared(std::move(gloo_kv_store), std::move(tcp_device)); -#else // __linux__ +#elif defined(__APPLE__) + std::shared_ptr kv_store = nullptr; + if (distributed_client != nullptr) { + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"cpu:"); + } + auto gloo_kv_store = std::make_unique(kv_store); + auto uv_attrs = gloo::transport::uv::attr(); + if (hostname) { + uv_attrs.hostname = *hostname; + } + if (interface) { + uv_attrs.iface = *interface; + } + auto uv_device = gloo::transport::uv::CreateDevice(uv_attrs); + return std::make_shared(std::move(gloo_kv_store), + std::move(uv_device)); +#else // defined(__linux__) throw xla::XlaRuntimeError( - "make_gloo_tcp_collectives only implemented for linux"); -#endif // __linux__ + "make_gloo_tcp_collectives only implemented for linux and macos"); +#endif // defined(__linux__) }, nb::arg("distributed_client"), nb::arg("hostname").none() = std::nullopt, nb::arg("interface").none() = std::nullopt); @@ -366,10 +384,6 @@ NB_MODULE(xla_extension, m_nb) { return xla::ThrowIfError(pjrt::InitializePjrtPlugin(platform_name)); }); -#ifdef XLA_PYTHON_ENABLE_GPU - RegisterGpuClientAndDefineGpuAllocatorConfig(m_nb); -#endif // XLA_PYTHON_ENABLE_GPU - m_nb.def( "get_c_api_client", [](std::string platform_name, diff --git a/third_party/xla/xla/python/xla_client.py b/third_party/xla/xla/python/xla_client.py index 97f01bec9bb0d5..294f109a8d8f7e 100644 --- a/third_party/xla/xla/python/xla_client.py +++ b/third_party/xla/xla/python/xla_client.py @@ -50,7 +50,7 @@ # Just an internal arbitrary increasing number to help with backward-compatible # changes. In JAX, reference this via jax._src.lib.xla_extension_version. -_version = 279 +_version = 282 # Version number for MLIR:Python components. mlir_api_version = 57 diff --git a/third_party/xla/xla/python/xla_client.pyi b/third_party/xla/xla/python/xla_client.pyi index bf5a6d9cc1a890..8731080c99b52a 100644 --- a/third_party/xla/xla/python/xla_client.pyi +++ b/third_party/xla/xla/python/xla_client.pyi @@ -222,12 +222,16 @@ class GatherDimensionNumbers: collapsed_slice_dims: list[int] start_index_map: list[int] index_vector_dim: int + operand_batching_dims: list[int] + start_indices_batching_dims: list[int] class ScatterDimensionNumbers: update_window_dims: list[int] inserted_window_dims: list[int] scatter_dims_to_operand_dims: list[int] index_vector_dim: int + input_batching_dims: list[int] + scatter_indices_batching_dims: list[int] class ReplicaGroup: replica_ids: list[int] diff --git a/third_party/xla/xla/python/xla_client_test.py b/third_party/xla/xla/python/xla_client_test.py index b84e094b1d841b..37484ccffec93b 100644 --- a/third_party/xla/xla/python/xla_client_test.py +++ b/third_party/xla/xla/python/xla_client_test.py @@ -65,8 +65,12 @@ # pylint: disable=invalid-name -def jax_array_convert_to_array(self): - return self._single_device_array_to_np_array() +def jax_array_convert_to_array(self, dtype=None, copy=None): + del copy + out = self._single_device_array_to_np_array() + if dtype is not None: + out = out.astype(dtype) + return out def jax_array_device(self): @@ -586,7 +590,10 @@ class ParametersTest(ComputationTest): def testScalarTimesVector(self, dtype): c = self._NewComputation() arg0 = np.array(3, dtype=dtype) - arg1 = np.array([10, 15, -2, 7], dtype=dtype) + if np.issubdtype(dtype, np.unsignedinteger): + arg1 = np.array([10, 15, 2, 7], dtype=dtype) + else: + arg1 = np.array([10, 15, -2, 7], dtype=dtype) p0 = ops.Parameter(c, 0, xla_client.shape_from_pyval(arg0)) p1 = ops.Parameter(c, 1, xla_client.shape_from_pyval(arg1)) ops.Mul(p0, p1) @@ -2990,6 +2997,49 @@ def testAccessingLocalsDoesNotCrash(self): for frame, _ in traceback.walk_tb(python_tb): _ = frame.f_locals # should not crash + def testTracebackFromFrames(self): + def FooFn(x): + return x + 1 + + def BarFn(y): + y = y + 1 + y = y + 2 + return y * 2 + + frame_foo = xla_client.Frame( + __file__, + FooFn.__code__.co_name, + FooFn.__code__.co_firstlineno, + FooFn.__code__.co_firstlineno + 1, + ) + frame_bar = xla_client.Frame( + __file__, + BarFn.__code__.co_name, + BarFn.__code__.co_firstlineno, + BarFn.__code__.co_firstlineno + 2, + ) + frames = [frame_foo, frame_bar] + tb = xla_client.Traceback.traceback_from_frames(frames) + + with self.subTest("WalkDoesNotError"): + for frame, _ in traceback.walk_tb(tb): + _ = frame.f_locals # should not crash + + with self.subTest("TracebackCorrectness"): + tb_string = traceback.format_tb(tb) + # The traceback should have the format: + # File , line N in BarFn + # y = y + 2 + # File , line N in FooFn + # return x + 1 + self.assertLen(tb_string, len(frames)) + bar_frame = tb_string[0].split("\n") + self.assertEndsWith(bar_frame[0], "BarFn") + self.assertEqual(bar_frame[1].strip(), "y = y + 2") + foo_frame = tb_string[1].split("\n") + self.assertEndsWith(foo_frame[0], "FooFn") + self.assertEqual(foo_frame[1].strip(), "return x + 1") + tests.append(TracebackTest) class ClientTest(ComputationTest): diff --git a/third_party/xla/xla/python/xla_compiler.cc b/third_party/xla/xla/python/xla_compiler.cc index 2259083a2da478..f58a59a3cc715c 100644 --- a/third_party/xla/xla/python/xla_compiler.cc +++ b/third_party/xla/xla/python/xla_compiler.cc @@ -80,10 +80,10 @@ limitations under the License. #include "xla/service/tuple_simplifier.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" @@ -1199,10 +1199,7 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { &DebugOptions::xla_gpu_dump_autotune_logs_to, [](DebugOptions* self, std::string value) { self->set_xla_gpu_dump_autotune_logs_to(value); - }) - // TODO(b/352486192): Move this to `ExecutableBuildOptions`. - .def_prop_rw("xla_use_shardy", &DebugOptions::xla_use_shardy, - &DebugOptions::set_xla_use_shardy); + }); nb::class_(m, "ExecutableBuildOptions") .def(nb::init<>()) @@ -1276,7 +1273,10 @@ void BuildXlaCompilerSubmodule(nb::module_& m) { [](ExecutableBuildOptions& options, std::vector values) { absl::InlinedVector v(values.begin(), values.end()); options.set_allow_spmd_sharding_propagation_to_output(v); - }); + }) + .def_prop_rw("use_shardy_partitioner", + &ExecutableBuildOptions::use_shardy_partitioner, + &ExecutableBuildOptions::set_use_shardy_partitioner); nb::enum_ op_sharding_type(m, "OpSharding_Type", nb::is_arithmetic()); diff --git a/third_party/xla/xla/python/xla_extension/__init__.pyi b/third_party/xla/xla/python/xla_extension/__init__.pyi index e19bf8546491ab..5e2982184a3fca 100644 --- a/third_party/xla/xla/python/xla_extension/__init__.pyi +++ b/third_party/xla/xla/python/xla_extension/__init__.pyi @@ -318,8 +318,6 @@ class DebugOptions: xla_gpu_dump_autotune_results_to: str xla_gpu_load_autotune_results_from: str xla_gpu_dump_autotune_logs_to: str - # TODO(b/352486192): Move this to `ExecutableBuildOptions`. - xla_use_shardy: bool class CompiledMemoryStats: generated_code_size_in_bytes: int @@ -348,6 +346,7 @@ class ExecutableBuildOptions: use_auto_spmd_partitioning: bool auto_spmd_partitioning_mesh_shape: List[int] auto_spmd_partitioning_mesh_ids: List[int] + use_shardy_partitioner: bool class PrecisionConfig_Precision(enum.IntEnum): DEFAULT: int @@ -752,12 +751,19 @@ class Frame: function_name: str function_line_start: int line_num: int + def __init__(self, + file_name: str, + function_name: str, + function_line_start: int, + line_num: int): ... def __repr__(self) -> str: ... class Traceback: enabled: ClassVar[bool] @staticmethod def get_traceback() -> Traceback: ... + @staticmethod + def traceback_from_frames(frames: Sequence[Frame]) -> Any: ... frames: Sequence[Frame] def __str__(self) -> str: ... def as_python_traceback(self) -> Any: ... diff --git a/third_party/xla/xla/python_api/xla_literal.py b/third_party/xla/xla/python_api/xla_literal.py index 3471f3a99cc2db..4ad7bf0a36c587 100644 --- a/third_party/xla/xla/python_api/xla_literal.py +++ b/third_party/xla/xla/python_api/xla_literal.py @@ -50,9 +50,8 @@ def ConvertLiteralToNumpyArray(literal): numpy_reshaper = lambda arr: arr.reshape(numpy_shape, order='C') else: raise NotImplementedError('Unsupported layout: {0}'.format(layout_order)) - ndarray = _np.array( + ndarray = _np.asarray( getattr(literal, type_record.literal_field_name), - copy=False, dtype=type_record.numpy_dtype) return numpy_reshaper(ndarray) diff --git a/third_party/xla/xla/reference_util.cc b/third_party/xla/xla/reference_util.cc index 33fb500b2b4d24..d7461cadcfe4a4 100644 --- a/third_party/xla/xla/reference_util.cc +++ b/third_party/xla/xla/reference_util.cc @@ -25,11 +25,19 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" +#include "absl/types/span.h" +#include "xla/array2d.h" +#include "xla/array3d.h" +#include "xla/array4d.h" +#include "xla/client/padding.h" #include "xla/client/xla_builder.h" #include "xla/hlo/evaluator/hlo_evaluator.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/literal.h" #include "xla/literal_util.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/shape_inference.h" +#include "xla/shape.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" #include "tsl/lib/math/math_util.h" diff --git a/third_party/xla/xla/reference_util.h b/third_party/xla/xla/reference_util.h index 9a124d6f577f5c..a086fdbf8cef76 100644 --- a/third_party/xla/xla/reference_util.h +++ b/third_party/xla/xla/reference_util.h @@ -24,6 +24,7 @@ limitations under the License. #include #include "absl/functional/function_ref.h" +#include "absl/log/check.h" #include "absl/types/span.h" #include "xla/array2d.h" #include "xla/array3d.h" diff --git a/third_party/xla/xla/reference_util_test.cc b/third_party/xla/xla/reference_util_test.cc index 320d1cac5e63f7..c27e5414525553 100644 --- a/third_party/xla/xla/reference_util_test.cc +++ b/third_party/xla/xla/reference_util_test.cc @@ -22,7 +22,9 @@ limitations under the License. #include "xla/array3d.h" #include "xla/array4d.h" #include "xla/client/padding.h" +#include "xla/error_spec.h" #include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/test.h" #include "xla/tests/literal_test_util.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 729c234ba6897f..9b55155ffab786 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -142,7 +142,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -256,53 +256,6 @@ xla_cc_test( ], ) -cc_library( - name = "all_reduce_splitter", - srcs = ["all_reduce_splitter.cc"], - hdrs = ["all_reduce_splitter.h"], - deps = [ - ":collective_opt_utils", - ":hlo_module_config", - ":hlo_pass", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "all_reduce_splitter_test", - srcs = ["all_reduce_splitter_test.cc"], - deps = [ - ":all_reduce_splitter", - ":hlo_module_config", - ":hlo_pass_pipeline", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_reduce_scatter_creator", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "float_support", srcs = ["float_support.cc"], @@ -501,6 +454,7 @@ xla_cc_test( ":hlo_parser", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", + "//xla/hlo/utils:hlo_query", "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", "@com_google_googletest//:gtest", @@ -578,7 +532,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -602,7 +556,9 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_instruction_utils", "//xla/hlo/utils:hlo_query", + "//xla/service:call_graph", "//xla/service:hlo_parser", + "//xla/service:tuple_points_to_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -667,10 +623,10 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -686,6 +642,7 @@ cc_library( "//xla:util", "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_set", @@ -702,7 +659,6 @@ cc_library( "@llvm-project//mlir:Transforms", "@local_tsl//tsl/lib/io:zlib_compression_options", "@local_tsl//tsl/lib/io:zlib_outputbuffer", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -721,8 +677,8 @@ xla_cc_test( ":hlo_parser", "//xla:xla_proto_cc", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", @@ -773,7 +729,6 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", # fixdeps: keep "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", @@ -919,8 +874,8 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -955,7 +910,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -964,12 +919,17 @@ cc_library( hdrs = ["pattern_matcher.h"], deps = [ ":hlo_parser", + "//xla:comparison_util", + "//xla:literal", "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:ptrvec", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_absl//absl/utility", @@ -982,11 +942,16 @@ xla_cc_test( deps = [ ":hlo_parser", ":pattern_matcher", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", "//xla:test", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -997,7 +962,9 @@ cc_library( hdrs = ["pattern_matcher_gmock.h"], deps = [ ":pattern_matcher", + "//xla:shape_util", "//xla:test", + "//xla/hlo/ir:hlo", "@local_tsl//tsl/platform:test", ], ) @@ -1010,7 +977,31 @@ xla_cc_test( ":pattern_matcher_gmock", "//xla:shape_util", "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "fuzzy_matcher", + hdrs = ["fuzzy_matcher.h"], + deps = [ + ":pattern_matcher", + "//xla/hlo/ir:hlo", + ], +) + +xla_cc_test( + name = "fuzzy_matcher_test", + srcs = ["fuzzy_matcher_test.cc"], + deps = [ + ":fuzzy_matcher", + ":pattern_matcher", + "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) @@ -1052,7 +1043,8 @@ xla_cc_test( deps = [ ":pattern_matcher", ":pattern_matcher_gmock", - "//xla:literal", + "//xla:comparison_util", + "//xla:literal_util", "//xla:protobuf_util", "//xla:shape_util", "//xla:test", @@ -1064,8 +1056,11 @@ xla_cc_test( "//xla/service/gpu:backend_configs_cc", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1121,11 +1116,11 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -1180,17 +1175,20 @@ xla_cc_test( srcs = ["call_inliner_test.cc"], deps = [ ":call_inliner", - ":hlo_pass", + ":hlo_parser", "//xla:literal", + "//xla:literal_util", "//xla:shape_util", "//xla:test", - "//xla:types", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", ], ) @@ -1230,8 +1228,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -1252,7 +1250,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -1543,11 +1541,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", @@ -1700,6 +1698,7 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor:device_memory_allocator", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", @@ -1707,7 +1706,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", @@ -1761,8 +1759,8 @@ xla_test( "//xla/stream_executor/gpu:gpu_init", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -1965,9 +1963,9 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -2011,9 +2009,9 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -2031,7 +2029,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2108,10 +2106,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -2129,9 +2127,9 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -2187,11 +2185,11 @@ xla_cc_test( "//xla/service/heap_simulator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", @@ -2467,7 +2465,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2509,7 +2507,7 @@ xla_cc_test( "//xla:types", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2587,7 +2585,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2784,13 +2782,13 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -2991,7 +2989,7 @@ xla_cc_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -3189,7 +3187,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -3327,7 +3325,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", ], ) @@ -3376,8 +3374,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", ], ) @@ -3703,10 +3701,10 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -3894,8 +3892,8 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -3984,6 +3982,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) @@ -3999,8 +3998,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -4160,8 +4159,8 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -4236,13 +4235,13 @@ xla_test( "//xla/tests:llvm_irgen_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -4268,7 +4267,7 @@ xla_cc_test( "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test_benchmark", ], @@ -4287,7 +4286,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -4417,9 +4416,9 @@ xla_cc_test( "//xla/stream_executor/host:host_platform_id", "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -4511,7 +4510,9 @@ xla_cc_test( ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", - "//xla:literal", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_tree", "//xla:shape_util", "//xla:test", "//xla:test_helpers", @@ -4521,7 +4522,9 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) @@ -4542,11 +4545,11 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -4680,11 +4683,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -4783,7 +4786,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -4893,7 +4896,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], @@ -5149,7 +5152,7 @@ xla_cc_test( ":memory_space_propagation", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -5243,6 +5246,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:log_severity", "@com_google_absl//absl/log:scoped_mock_log", "@com_google_absl//absl/status", @@ -5250,7 +5254,6 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -5280,8 +5283,8 @@ xla_cc_test( ":hlo_verifier", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -5354,10 +5357,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -5379,9 +5382,9 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -5397,7 +5400,7 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -5423,11 +5426,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -5492,7 +5495,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -5506,6 +5509,7 @@ cc_library( "//xla:literal", "//xla:shape_util", "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:errors", @@ -5531,6 +5535,7 @@ xla_cc_test( "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", ], ) @@ -5652,7 +5657,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -5709,7 +5714,7 @@ xla_cc_test( "//xla/tests:literal_test_util", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -6172,7 +6177,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -6202,7 +6207,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -6337,11 +6342,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -6386,11 +6391,63 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "host_offload_utils", + srcs = ["host_offload_utils.cc"], + hdrs = ["host_offload_utils.h"], + deps = [ + ":call_graph", + ":hlo_buffer", + ":host_memory_offload_annotations_hdr", + ":pattern_matcher", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "host_offload_utils_test", + srcs = ["host_offload_utils_test.cc"], + deps = [ + ":hlo_verifier", + ":host_memory_offload_annotations_hdr", + ":host_offload_utils", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], ) @@ -6407,6 +6464,7 @@ cc_library( ":hlo_pass", ":hlo_value", ":host_memory_offload_annotations_hdr", + ":host_offload_utils", ":pattern_matcher", "//xla:literal_util", "//xla:shape_util", @@ -6443,12 +6501,12 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -6482,10 +6540,10 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -6567,9 +6625,9 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -6610,9 +6668,9 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -6658,8 +6716,8 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -6759,7 +6817,7 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -6879,6 +6937,7 @@ cc_library( ":hlo_proto_cc", ":name_uniquer", ":shape_inference", + "//xla:array", "//xla:comparison_util", "//xla:literal", "//xla:literal_util", @@ -6888,6 +6947,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:tile_assignment", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", @@ -6903,6 +6963,7 @@ cc_library( "@local_tsl//tsl/lib/gtl:map_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", ], ) @@ -6911,18 +6972,26 @@ xla_cc_test( size = "small", srcs = ["hlo_parser_test.cc"], deps = [ + ":hlo_lexer", + ":hlo_module_config", ":hlo_parser", ":pattern_matcher", ":pattern_matcher_gmock", + "//xla:array", "//xla:shape_util", "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -7226,6 +7295,7 @@ cc_library( ":__subpackages__", "//tensorflow/compiler/tf2xla:__pkg__", "//xla/pjrt:__subpackages__", + "//xla/backends/cpu/runtime:__subpackages__", ]), deps = [ ":custom_call_status", @@ -7275,7 +7345,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -7378,10 +7448,10 @@ xla_cc_test( "//xla/hlo/utils:hlo_matchers", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -7398,10 +7468,10 @@ xla_cc_test( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -7440,7 +7510,7 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -7636,9 +7706,9 @@ xla_cc_test( ":mapped_ptr_container_sorter", "//xla:test", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/log", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -7753,8 +7823,8 @@ xla_cc_test( "//xla:test", "//xla:xla_proto_cc", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:casts", "@local_tsl//tsl/platform:protobuf", ], @@ -7886,8 +7956,7 @@ cc_library( deps = [ ":hlo_creation_utils", ":hlo_pass", - "//xla/service/cpu:onednn_convolution_rewriter", - "//xla/service/cpu:onednn_matmul_rewriter", + "//xla/service/cpu:onednn_contraction_rewriter", ], ) @@ -7921,6 +7990,25 @@ cc_library( ], ) +cc_library( + name = "batched_gather_scatter_normalizer", + srcs = ["batched_gather_scatter_normalizer.cc"], + hdrs = ["batched_gather_scatter_normalizer.h"], + deps = [ + ":op_expander_pass", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "reduce_window_rewriter", srcs = ["reduce_window_rewriter.cc"], @@ -8047,6 +8135,18 @@ xla_cc_test( ], ) +xla_cc_test( + name = "batched_gather_scatter_normalizer_test", + srcs = ["batched_gather_scatter_normalizer_test.cc"], + deps = [ + ":batched_gather_scatter_normalizer", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + xla_cc_test( name = "change_op_data_type_test", srcs = ["change_op_data_type_test.cc"], @@ -8166,9 +8266,9 @@ xla_cc_test( "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/service/cpu:cpu_compiler", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", @@ -8194,8 +8294,8 @@ xla_cc_test( "//xla/client:local_client", "//xla/service/cpu:cpu_compiler", "//xla/tests:literal_test_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", @@ -8235,7 +8335,7 @@ xla_cc_test( "//xla/client:executable_build_options", "//xla:literal", "//xla:shape_util", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", @@ -8288,8 +8388,8 @@ xla_cc_test( ":gpu_compilation_environment", "//xla:parse_flags_from_env", "//xla:xla_proto_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", @@ -8345,4 +8445,42 @@ cc_library( ], ) +cc_library( + name = "add_original_value", + srcs = ["add_original_value.cc"], + hdrs = ["add_original_value.h"], + deps = [ + ":hlo_pass", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "add_original_value_test", + srcs = ["add_original_value_test.cc"], + deps = [ + ":add_original_value", + ":pattern_matcher", + ":pattern_matcher_gmock", + "//xla:shape_util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + exports_files(["xla_aot_compile_test_gpu_target_config.prototxt"]) diff --git a/third_party/xla/xla/service/add_original_value.cc b/third_party/xla/xla/service/add_original_value.cc new file mode 100644 index 00000000000000..37cab3c7cad81a --- /dev/null +++ b/third_party/xla/xla/service/add_original_value.cc @@ -0,0 +1,68 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/add_original_value.h" + +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_original_value.h" +#include "xla/shape_util.h" + +namespace xla { + +absl::StatusOr AddOriginalValue::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + + for (const auto computation : module->computations()) { + for (const auto instruction : computation->instructions()) { + auto original_value = + std::make_shared(instruction->shape()); + + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + const auto* tuple = instruction->operand(0); + original_value->CopySubtreeFrom(*tuple->original_value(), + {instruction->tuple_index()}, {}); + } else if (instruction->opcode() == HloOpcode::kTuple) { + for (int64_t operand_number = 0; + operand_number < instruction->operand_count(); ++operand_number) { + original_value->CopySubtreeFrom( + *instruction->operand(operand_number)->original_value(), {}, + {operand_number}); + } + } else { + for (auto& leaf : original_value->leaves()) { + leaf.second = {std::string(instruction->name()), leaf.first}; + } + } + instruction->set_original_value(original_value); + changed = true; + } + } + + return changed; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/add_original_value.h b/third_party/xla/xla/service/add_original_value.h new file mode 100644 index 00000000000000..b4fb093fca2d9c --- /dev/null +++ b/third_party/xla/xla/service/add_original_value.h @@ -0,0 +1,38 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ +#define XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ + +#include "absl/status/statusor.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass adds to each op in the HLO graph the original_value attribute, +// which is used for HLO value tracking. See go/hlo-value-tracking for more +// details. +class AddOriginalValue : public HloModulePass { + public: + absl::string_view name() const override { return "add-original-value"; } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_ADD_ORIGINAL_VALUE_H_ diff --git a/third_party/xla/xla/service/add_original_value_test.cc b/third_party/xla/xla/service/add_original_value_test.cc new file mode 100644 index 00000000000000..f69ba94cba440e --- /dev/null +++ b/third_party/xla/xla/service/add_original_value_test.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/add_original_value.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +using AddOriginalValueTest = HloTestBase; + +using ::absl::string_view; + +TEST_F(AddOriginalValueTest, Basic) { + constexpr absl::string_view hlo_string = R"( +HloModule test, entry_computation_layout={(s32[]{:T(256)})->u32[2]{0:T(256)}} + +ENTRY test { + Arg_0.1 = s32[] parameter(0) + constant.2 = s32[] constant(32) + shift-right-logical.3 = s32[] shift-right-logical(Arg_0.1, constant.2) + convert.4 = u32[] convert(shift-right-logical.3) + reshape.5 = u32[1]{0} reshape(convert.4) + convert.6 = u32[] convert(Arg_0.1) + reshape.7 = u32[1]{0} reshape(convert.6) + ROOT concatenate.8 = u32[2]{0} concatenate(reshape.5, reshape.7), dimensions={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AddOriginalValue pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + EXPECT_TRUE(changed); +} + +TEST_F(AddOriginalValueTest, Tuple) { + constexpr absl::string_view hlo_string = R"( +HloModule test, entry_computation_layout={(f32[], f32[3]{0}, f32[2,3]{1,0})->((f32[], f32[3]{0}), f32[2,3]{1,0})} + +ENTRY test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]{0}), f32[2,3]{1,0}) { + v1 = f32[] parameter(0) + v2 = f32[3]{0} parameter(1) + v3 = f32[2,3]{1,0} parameter(2) + t1 = (f32[], f32[3]{0}) tuple(f32[] v1, f32[3]{0} v2) + ROOT t2 = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) t1, f32[2,3]{1,0} v3) +} + +)"; + + RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( +CHECK: %[[V1:.*]] = f32[] parameter(0), original_value={{[{]}}{"[[V1]]"} +CHECK: %[[V2:.*]] = f32[3]{0} parameter(1), original_value={{[{]}}{"[[V2]]"} +CHECK: %[[TUPLE:.*]] = (f32[], f32[3]{0}) tuple(%[[V1]], %[[V2]]), original_value={({"[[V1]]"}, {"[[V2]]"})} +CHECK: %[[V3:.*]] = f32[2,3]{1,0} parameter(2), original_value={{[{]}}{"[[V3]]"} +CHECK: ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple(%[[TUPLE]], %[[V3]]), original_value={(({"v1"}, {"v2"}), {"v3"})} + )"); +} + +TEST_F(AddOriginalValueTest, GetTupleElement) { + constexpr absl::string_view hlo_string = R"( +HloModule test, entry_computation_layout={()->s32[2,3]{1,0}} + +ENTRY test { + constant = f32[3]{0} constant({1, 2, 3}) + constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }) + tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} constant, s32[2,3]{1,0} constant.1) + ROOT get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) tuple), index=1 +} + +)"; + + RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( +CHECK: %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), original_value={{[{]}}{"[[CONSTANT1]]"} +CHECK: %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), original_value={{[{]}}{"[[CONSTANT2]]"} +CHECK: %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), original_value={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})} +CHECK: s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, original_value={{[{]}}{"[[CONSTANT2]]"} + )"); +} + +TEST_F(AddOriginalValueTest, GetTupleElementNonSymbolic) { + constexpr absl::string_view hlo_string = R"( +HloModule test, entry_computation_layout={((f32[], s32[]))->s32[]} + +ENTRY test { + p = (f32[], s32[]) parameter(0) + ROOT get-tuple-element = s32[] get-tuple-element(p), index=1 +} + +)"; + + RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"( +CHECK: %[[PARAM:.*]] = (f32[], s32[]) parameter(0), original_value={({"p" {0}{{[}]}}, {"p" {1}})} +CHECK: s32[] get-tuple-element(%[[PARAM]]), index=1, original_value={{[{]}}{"[[PARAM]]" {1} + )"); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/algebraic_simplifier.cc b/third_party/xla/xla/service/algebraic_simplifier.cc index fad9dcacaa4ab9..f54864220e2e69 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.cc +++ b/third_party/xla/xla/service/algebraic_simplifier.cc @@ -179,6 +179,11 @@ std::optional GetConstantValue(const HloInstruction* inst) { using NativeT = NativeTypeOf; return static_cast( inst->literal().GetFirstElement()); + } else if constexpr (primitive_util::IsIntegralType( + primitive_type_constant)) { + using NativeT = NativeTypeOf; + return static_cast( + inst->literal().GetFirstElement()); } return std::nullopt; }, @@ -608,6 +613,11 @@ std::unique_ptr MakeScalarInstruction(HloInstruction* target, using NativeT = NativeTypeOf; return HloInstruction::CreateConstant( LiteralUtil::CreateR0(static_cast(multiplier))); + } else if constexpr (primitive_util::IsIntegralType( + primitive_type_constant)) { + using NativeT = NativeTypeOf; + return HloInstruction::CreateConstant( + LiteralUtil::CreateR0(static_cast(multiplier))); } LOG(FATAL) << "Unsupported data type: " << target->shape().element_type(); @@ -3331,12 +3341,11 @@ AlgebraicSimplifierVisitor::AssociativeReorderDotOperator( } else if (opcode == HloOpcode::kBroadcast) { // Broadcasts of dot contracting dimensions can be reordered to reduces // of the corresponding contracting dimensions in the other dot operand - DimensionVector reduce_dims, broadcast_dim_sizes; + DimensionVector reduce_dims; const int64_t pre_broadcast_rank = reorder_from->mutable_operand(0)->shape().rank(); int64_t post_broadcast_rank = reorder_from->shape().rank(); Shape new_broadcast_shape = reorder_from->shape(); - DimensionVector contracting_reordered; // Construct map from broadcasted shape to its original shape. Broadcast // dimensions are mapped to -1 since they were not present @@ -3554,48 +3563,28 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { other_index = outer_dnums.lhs_batch_dimensions(i); } - // Once we have the inner_index, we determine whether this index - // corresponds to a dimension coming from the lhs or rhs of inner - bool from_inner_lhs = map_inner_rhs[inner_index] == -1; - - // The map we use depends on which operand of inner this dim comes from - std::vector map; - if (from_inner_lhs) { - map = map_inner_lhs; - } else { - map = map_inner_rhs; - } - - // Whether the mapped value goes into the lhs or rhs of the new dnums - // depends on whether inner was the lhs or rhs operand of outer - int64_t lhs_index, rhs_index; - if (outer_lhs_dot) { - lhs_index = map[inner_index]; - rhs_index = other_index; - } else { - lhs_index = other_index; - rhs_index = map[inner_index]; - } - - // Finally, we have to determine which dnums to add to - DotDimensionNumbers* dnums; - if (outer_lhs_dot) { - if (from_inner_lhs) { - dnums = &ac_dnums; - } else { - dnums = &bc_dnums; - } - } else { - if (from_inner_lhs) { - dnums = &ab_dnums; - } else { - dnums = &ac_dnums; + auto add_batch_dims = [](DotDimensionNumbers& dnums, int64_t lhs_ix, + int64_t rhs_ix) { + dnums.add_lhs_batch_dimensions(lhs_ix); + dnums.add_rhs_batch_dimensions(rhs_ix); + }; + + for (auto& map : {map_inner_lhs, map_inner_rhs}) { + int64_t mapped_index = map[inner_index]; + if (mapped_index != -1) { + // Whether the mapped value is the lhs or rhs of the new dnums + // depends on whether inner is the lhs or rhs operand of outer. The + // dnums itself depends on this and also on which map we are + // iterating through + if (outer_lhs_dot) { + add_batch_dims(map == map_inner_lhs ? ac_dnums : bc_dnums, + mapped_index, other_index); + } else { + add_batch_dims(map == map_inner_lhs ? ab_dnums : ac_dnums, + other_index, mapped_index); + } } } - - // Add the batch dimensions - dnums->add_lhs_batch_dimensions(lhs_index); - dnums->add_rhs_batch_dimensions(rhs_index); } // We now do the same thing for the contracting dimensions of outer @@ -3614,7 +3603,14 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // Once we have the inner_index, we determine whether this index // corresponds to a dimension coming from the lhs or rhs of inner - bool from_inner_lhs = map_inner_rhs[inner_index] == -1; + bool from_inner_lhs = map_inner_lhs[inner_index] != -1; + bool from_inner_rhs = map_inner_rhs[inner_index] != -1; + + // If a dimension of inner is the result of batching and it is + // contracted in outer, we stop trying to reorder + if (from_inner_lhs && from_inner_rhs) { + return absl::OkStatus(); + } // The map we use depends on which operand of inner this dim comes from std::vector map; @@ -3714,8 +3710,11 @@ absl::Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { rhs_index = other_index; } - new_outer_dnums.add_lhs_batch_dimensions(lhs_index); - new_outer_dnums.add_rhs_batch_dimensions(rhs_index); + if (!absl::c_linear_search(new_outer_dnums.lhs_batch_dimensions(), + lhs_index)) { + new_outer_dnums.add_lhs_batch_dimensions(lhs_index); + new_outer_dnums.add_rhs_batch_dimensions(rhs_index); + } } for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { int64_t new_inner_index, other_index; @@ -4280,6 +4279,19 @@ absl::StatusOr> MinMaxToClamp( } } // namespace +bool AlgebraicSimplifierVisitor::IsNondecreasingSublinear( + const HloInstruction* hlo) { + switch (hlo->opcode()) { + case HloOpcode::kCbrt: + case HloOpcode::kErf: + case HloOpcode::kLogistic: + case HloOpcode::kTanh: + return true; + default: + return false; + } +} + absl::Status AlgebraicSimplifierVisitor::HandleMaximum( HloInstruction* maximum) { HloInstruction *lhs, *rhs; @@ -4350,6 +4362,33 @@ absl::Status AlgebraicSimplifierVisitor::HandleMaximum( } } + // If the operands of the max are the same non-decreasing function, then we + // can sink it; i.e. max(tanh(x), tanh(y)) to tanh(max(x, y)) + // We only do this if the function asymptotically satisfies |f(x)| <= |x| to + // guarantee that no overflow occurs. Proof of correctness: + /* https://cvc5.github.io/app/ + (set-logic ALL) + (declare-fun f (Float32) Float32) + (assert (forall ((x Float32) (y Float32)) + (=> (fp.lt x y) (fp.leq (f x) (f y))))) ; NonDecreasing + (assert (forall ((x Float32)) + (fp.leq (fp.abs (f x)) (fp.abs x)))) ; Sublinear + (assert (not (forall ((x Float32) (y Float32)) + (fp.eq (fp.max (f x) (f y)) + (f (fp.max x y)))))) ; Expect unsat + (check-sat) + */ + if (lhs->opcode() == rhs->opcode() && IsNondecreasingSublinear(lhs)) { + TF_ASSIGN_OR_RETURN( + auto new_maximum, + MakeBinaryHlo(HloOpcode::kMaximum, lhs->mutable_operand(0), + rhs->mutable_operand(0))); + VLOG(10) << "Sinking nondecreasing op through max"; + return ReplaceWithNewInstruction( + maximum, HloInstruction::CreateUnary(maximum->shape(), lhs->opcode(), + new_maximum)); + } + return absl::OkStatus(); } @@ -6919,7 +6958,8 @@ absl::Status AlgebraicSimplifierVisitor::HandleDynamicSlice( // Convert a dynamic slice into a slice if all offsets are constant, the // operand is not constant, and the input and output memory spaces are the // same. - if (operand->opcode() != HloOpcode::kConstant && + if (!options_.disable_dynamic_slice_to_slice_conversion() && + operand->opcode() != HloOpcode::kConstant && absl::c_all_of(absl::MakeSpan(dynamic_slice->operands().begin() + 1, dynamic_slice->operands().end()), [](HloInstruction* operand) { @@ -7359,6 +7399,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { if (multi_output_reduce) { std::vector broadcast_inits; int64_t inputs = reduce->input_count(); + broadcast_inits.reserve(inputs); for (int64_t i = 0; i < inputs; ++i) { broadcast_inits.push_back(reduce->init_values()[i]->AddInstruction( HloInstruction::CreateBroadcast(reduce->shape().tuple_shapes(i), @@ -7404,6 +7445,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { if (multi_output_reduce) { std::vector reshaped_args; int64_t inputs = reduce->input_count(); + reshaped_args.reserve(inputs); for (int64_t i = 0; i < inputs; ++i) { reshaped_args.push_back( reduce->AddInstruction(HloInstruction::CreateReshape( @@ -7837,27 +7879,6 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } - // Replace Reduce(Broadcast(x), dims, Sum()) with Broadcast(x * prod(dims)). - if (HloInstruction * broadcast_arg; - Match(arg, m::Broadcast(m::ConstantScalar(&broadcast_arg))) && - Match(function->root_instruction(), - m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) { - if (auto broadcast_value = GetConstantValue(broadcast_arg); - broadcast_value.has_value() && - // Skip float64, where product is too accurate compared to repeated-sum. - broadcast_arg->shape().element_type() != PrimitiveType::F64) { - auto result_value = broadcast_value.value() * - ShapeUtil::ElementsIn(arg->shape()) / - ShapeUtil::ElementsIn(reduce_result_shape); - return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateBroadcast( - reduce_result_shape, - reduce->AddInstruction( - MakeScalarInstruction(reduce, result_value)), - {})); - } - } - // For Computation equal to Min, Max, And or Or, replace Reduce(Broadcast(x), // a, Computation()) with Computation(x, a) when x is a scalar and the // broadcast is reduced to a scalar. @@ -7881,26 +7902,52 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* hlo) { } } - // Replace Reduce(Broadcast(Scalar)) with Broadcast(Multiply(Scalar)) when the - // reduction operation is addition + // Replace Reduce(Broadcast(x), +, init_value) with Broadcast(Add(Multiply(x), + // init_value))) if all reduction dimensions were introduced by Broadcast if (arg->opcode() == HloOpcode::kBroadcast && - ShapeUtil::IsScalar(arg->operand(0)->shape())) { - if (Match(reduce->to_apply()->root_instruction(), - m::AddAnyOrder(m::Parameter(0), m::Parameter(1))) && - IsScalarConstantZero(init_value)) { - int64_t reduction_dims_prod = 1; - for (auto i : reduce->dimensions()) { - reduction_dims_prod *= arg->shape().dimensions(i); + Match(reduce->to_apply()->root_instruction(), + m::AddAnyOrder(m::Parameter(0), m::Parameter(1)))) { + bool only_reduce_dims_from_broadcast = true; + int64_t common_dims_prod = 1; + int64_t num_common_dims = 0; + Shape new_broadcast_shape = arg->shape(); + std::vector new_broadcast_dims; + + // Now we build up the new broadcast shape and dims vector + for (int64_t i = 0; i < arg->shape().rank(); ++i) { + bool added_by_broadcast = !absl::c_linear_search(arg->dimensions(), i); + bool removed_by_reduce = absl::c_linear_search(reduce->dimensions(), i); + + if (removed_by_reduce && !added_by_broadcast) { + only_reduce_dims_from_broadcast = false; + break; + } else if (removed_by_reduce && added_by_broadcast) { + new_broadcast_shape.DeleteDimension(i - num_common_dims); + common_dims_prod *= arg->shape().dimensions(i); + num_common_dims++; + } else if (!removed_by_reduce && !added_by_broadcast) { + new_broadcast_dims.push_back(i - num_common_dims); } + } + if (only_reduce_dims_from_broadcast) { + // HloConstantFolding will later remove any unnecessary multiply and add + // instructions. HloInstruction* multiplier = - MakeScalarLike(arg->mutable_operand(0), reduction_dims_prod); + MakeScalarLike(arg->mutable_operand(0), common_dims_prod); TF_ASSIGN_OR_RETURN(HloInstruction * multiplied_scalar, MakeBinaryHlo(HloOpcode::kMultiply, arg->mutable_operand(0), multiplier)); + TF_ASSIGN_OR_RETURN( + HloInstruction * add, + MakeBinaryHlo( + HloOpcode::kAdd, + MakeBroadcastHlo(init_value, {}, multiplied_scalar->shape()), + multiplied_scalar)); + VLOG(10) << "Converting common reduce(broadcast) dimensions to multiply"; return ReplaceWithNewInstruction( - reduce, HloInstruction::CreateBroadcast(reduce->shape(), - multiplied_scalar, {})); + reduce, HloInstruction::CreateBroadcast(new_broadcast_shape, add, + new_broadcast_dims)); } } @@ -7930,6 +7977,7 @@ absl::Status AlgebraicSimplifierVisitor::HandleReduceWindow( if (ShapeUtil::IsZeroElementArray(*input_shapes[0]) || ShapeUtil::IsZeroElementArray(*output_shapes[0])) { std::vector broadcast_inits; + broadcast_inits.reserve(input_count); for (int64_t i = 0; i < input_count; ++i) { broadcast_inits.push_back( hlo->AddInstruction(HloInstruction::CreateBroadcast( diff --git a/third_party/xla/xla/service/algebraic_simplifier.h b/third_party/xla/xla/service/algebraic_simplifier.h index bdd4f915a54ec9..185792b336cfa8 100644 --- a/third_party/xla/xla/service/algebraic_simplifier.h +++ b/third_party/xla/xla/service/algebraic_simplifier.h @@ -287,6 +287,14 @@ class AlgebraicSimplifierOptions { executing_on_cpu_ = executing_on_cpu; } + // Option to disable conversion of dynamic-slice to slice. + void set_disable_dynamic_slice_to_slice_conversion(bool disable) { + disable_dynamic_slice_to_slice_conversion_ = disable; + } + bool disable_dynamic_slice_to_slice_conversion() const { + return disable_dynamic_slice_to_slice_conversion_; + } + private: // Metadata struct can be used to store any metadata information encapsulated // with the AlgebraicSimplifierOptions that can be later used in an @@ -325,6 +333,7 @@ class AlgebraicSimplifierOptions { bool raise_slice_and_reduce_through_dot_{false}; double raise_slice_and_reduce_through_dot_threshold_{2.0}; bool use_convert_constant_folding_{false}; + bool disable_dynamic_slice_to_slice_conversion_{false}; Metadata metadata_; }; @@ -487,6 +496,10 @@ class AlgebraicSimplifierVisitor : public DfsHloRewriteVisitor { static bool IsNonNegative(const HloInstruction* hlo, const AlgebraicSimplifierOptions& options); + // Check if the opcode of a given instruction is a non-decreasing function + // asymptotically satisfying |f(x)| <= |x| + static bool IsNondecreasingSublinear(const HloInstruction* hlo); + // Modify the layout dimensions of result_shape, so that it becomes the // re-shaped result of applying bitcast to the original_shape, by using // dim_map to re-shape layout dimensions of original_shape. Returns the diff --git a/third_party/xla/xla/service/algebraic_simplifier_test.cc b/third_party/xla/xla/service/algebraic_simplifier_test.cc index 5b8e4db491c13b..b36c9ca5b5cf79 100644 --- a/third_party/xla/xla/service/algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/algebraic_simplifier_test.cc @@ -58,10 +58,10 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -6411,6 +6411,94 @@ TEST_F(AlgebraicSimplifierTest, DotAssociativeReorder) { m::Dot(m::Parameter(1), m::Parameter(2))))); } +TEST_F(AlgebraicSimplifierTest, DotLeftDotSharedBatchReorder) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + a = f32[5,150,5] parameter(0) + b = f32[5,5,5] parameter(1) + c = f32[5,5,5] parameter(2) + + inner = f32[5,150,5] dot(a,b), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} + ROOT outer = f32[5,150,5] dot(inner,c), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={2} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + options.set_associative_reordering_threshold(1.5); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).value()); + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Dot(m::Parameter(0), + m::Dot(m::Parameter(1), m::Parameter(2))))); +} + +TEST_F(AlgebraicSimplifierTest, DotRightDotSharedBatchReorder) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + a = f32[2,3,3] parameter(0) + b = f32[2,3,3] parameter(1) + c = f32[2,3,16] parameter(2) + + inner = f32[2,3,16] dot(b,c), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} + ROOT outer = f32[2,3,16] dot(a,inner), + lhs_batch_dims={0}, lhs_contracting_dims={2}, + rhs_batch_dims={0}, rhs_contracting_dims={1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + options.set_associative_reordering_threshold(1.5); + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).value()); + ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Dot(m::Dot(m::Parameter(0), m::Parameter(1)), + m::Parameter(2)))); +} + +TEST_F(AlgebraicSimplifierTest, DotRightDotContractBatchReorder) { + const char* hlo_string = R"( + HloModule module + + ENTRY test { + a = f32[80,38,1536] parameter(0) + b = f32[80,38,4] parameter(1) + c = f32[80,4,1536] parameter(2) + inner = f32[80,38,1536] dot(b, c), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + ROOT outer = f32[1536,1536] dot(a, inner), + lhs_contracting_dims={0,1}, + rhs_contracting_dims={0,1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + options.set_use_associative_reordering(true); + options.set_associative_reordering_threshold(1.5); + AlgebraicSimplifier simplifier(options); + EXPECT_FALSE(simplifier.Run(module.get()).value()); +} + TEST_F(AlgebraicSimplifierTest, DotReverseLeftReorder) { const char* hlo_string = R"( HloModule module @@ -10075,24 +10163,25 @@ TEST_F(AlgebraicSimplifierTest, UnaryVariadicReduce) { TEST_F(AlgebraicSimplifierTest, ReplaceReduceSumOfConstantBroadcast) { const char* kModuleStr = R"( -HloModule ReplaceReduceSumOfConstantBroadcast + HloModule ReplaceReduceSumOfConstantBroadcast -add_f32 { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT r = f32[] add(p0, p1) -} + add_f32 { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT r = f32[] add(p0, p1) + } -ENTRY main { - init_value = f32[] constant(0) - const_value = f32[] constant(1) - const_bcast = f32[8, 128] broadcast(f32[] const_value), dimensions={} - ROOT reduce = f32[8] reduce(f32[8, 128] const_bcast, f32[] init_value), dimensions={1}, to_apply=add_f32 -} -)"; + ENTRY main { + init_value = f32[] constant(0) + const_value = f32[] constant(1) + const_bcast = f32[8, 128] broadcast(f32[] const_value), dimensions={} + ROOT reduce = f32[8] reduce(f32[8, 128] const_bcast, f32[] init_value), dimensions={1}, to_apply=add_f32 + } + )"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); - ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(m.get()).value()); int64_t reduce_count = absl::c_count_if(m->entry_computation()->instructions(), HloPredicateIsOp); @@ -11714,6 +11803,88 @@ ENTRY main.1 { HloOpcode::kParameter); } +TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastS32) { + const std::string hlo_string = R"( + HloModule test + add_s32 { + p0 = s32[] parameter(0) + p1 = s32[] parameter(1) + ROOT r = s32[] add(p0, p1) + } + ENTRY test.1 { + one = s32[] constant(2) + init = s32[] constant(10) + bcast = s32[1,7,7,1] broadcast(one), dimensions={} + ROOT out = s32[1,7,1] reduce(bcast, init), dimensions={1}, to_apply=add_s32 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto clone = m->Clone(); + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(m.get()).value()); + std::cout << m->ToString() << std::endl; + int64_t reduce_count = + absl::c_count_if(m->entry_computation()->instructions(), + HloPredicateIsOp); + // Expect no Reduce operation after simplification. + EXPECT_EQ(0, reduce_count); +} + +TEST_F(AlgebraicSimplifierTest, ReduceOfConstantBroadcastBF16) { + const std::string hlo_string = R"( + HloModule test + add_bf16 { + p0 = bf16[] parameter(0) + p1 = bf16[] parameter(1) + ROOT r = bf16[] add(p0, p1) + } + ENTRY test.1 { + one = bf16[] constant(2.12) + init = bf16[] constant(10.34) + bcast = bf16[1,7,7,1] broadcast(one), dimensions={} + ROOT out = bf16[1,7,1] reduce(bcast, init), dimensions={1}, to_apply=add_bf16 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + auto clone = m->Clone(); + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(m.get()).value()); + int64_t reduce_count = + absl::c_count_if(m->entry_computation()->instructions(), + HloPredicateIsOp); + // Expect no Reduce operation after simplification. + EXPECT_EQ(0, reduce_count); +} + +TEST_F(AlgebraicSimplifierTest, ReduceOfNonScalarBroadcast) { + const std::string hlo_string = R"( + HloModule module + add { + a = f32[] parameter(0) + b = f32[] parameter(1) + ROOT sum = f32[] add(a, b) + } + + ENTRY test { + a = f32[64,1001] parameter(0) + broadcast = f32[64,7,7,1001] broadcast(a), dimensions={0,3} + zero = f32[] constant(0) + ROOT reduce = f32[64,7,1001] reduce(broadcast, zero), dimensions={2}, + to_apply=add + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(m.get()).value()); + HloInstruction* root = m->entry_computation()->root_instruction(); + int64_t reduce_count = + absl::c_count_if(m->entry_computation()->instructions(), + HloPredicateIsOp); + // Expect no Reduce operation after simplification. + EXPECT_EQ(0, reduce_count); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Multiply()))); +} + TEST_F(AlgebraicSimplifierTest, RemoveConvertConstant) { const std::string hlo_string = R"( HloModule module @@ -11773,11 +11944,86 @@ TEST_F(AlgebraicSimplifierTest, ReduceBroadcastScalarToBroadcastMultiply) { } )"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); - EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + HloPassFix simplifier(default_options_); + EXPECT_TRUE(simplifier.Run(m.get()).value()); HloInstruction* root = m->entry_computation()->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast); EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMultiply); } +TEST_F(AlgebraicSimplifierTest, SinkCbrtThroughMax) { + absl::string_view hlo_string = R"( + HloModule module + + ENTRY test { + a = bf16[17,96,120] parameter(0) + b = bf16[17,96,120] parameter(1) + cbrt_a = bf16[17,96,120] cbrt(a) + cbrt_b = bf16[17,96,120] cbrt(b) + ROOT max = bf16[17,96,120] maximum(cbrt_a, cbrt_b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_string)); + EXPECT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).value()); + HloInstruction* root = m->entry_computation()->root_instruction(); + EXPECT_THAT( + root, GmockMatch(m::Cbrt(m::Maximum(m::Parameter(0), m::Parameter(1))))); +} + +TEST_F(AlgebraicSimplifierTest, + DynamicSlicePreservedWithTrivialConstantIndices) { + const char* hlo_string = R"( + HloModule module + + ENTRY f { + %operand = s32[2,2] parameter(0) + %constant = u32[] constant(0) + ROOT %dynamic-slice = s32[2,1] dynamic-slice(%operand, %constant, %constant), + dynamic_slice_sizes={2,1} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Disable dynamic-slice to slice conversion + default_options_.set_disable_dynamic_slice_to_slice_conversion(true); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_FALSE(simplifier.Run(module.get()).value()); + + // Expect the dynamic-slice to be preserved + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::DynamicSlice(m::Parameter(0), m::Constant(), + m::Constant()))); +} + +TEST_F(AlgebraicSimplifierTest, + DynamicSliceConvertedToConstantSliceWithConstantIndices) { + const char* hlo_string = R"( + HloModule module + + ENTRY f { + %operand = s32[2,2] parameter(0) + %constant = u32[] constant(0) + ROOT %dynamic-slice = s32[2,1] dynamic-slice(%operand, %constant, %constant), + dynamic_slice_sizes={2,1} + } + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Enable dynamic-slice to slice conversion (default behavior) + ASSERT_FALSE(default_options_.disable_dynamic_slice_to_slice_conversion()); + + AlgebraicSimplifier simplifier(default_options_); + ASSERT_TRUE(simplifier.Run(module.get()).value()); + + // Expect the dynamic-slice to be converted to a constant slice + EXPECT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(m::Slice(m::Parameter(0)))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/all_reduce_simplifier_test.cc b/third_party/xla/xla/service/all_reduce_simplifier_test.cc index e78881a0c19292..35f5955076ad7e 100644 --- a/third_party/xla/xla/service/all_reduce_simplifier_test.cc +++ b/third_party/xla/xla/service/all_reduce_simplifier_test.cc @@ -28,9 +28,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/window_util.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/async_collective_creator_test.cc b/third_party/xla/xla/service/async_collective_creator_test.cc index 8c9b574003da9a..75556260cf2e14 100644 --- a/third_party/xla/xla/service/async_collective_creator_test.cc +++ b/third_party/xla/xla/service/async_collective_creator_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc new file mode 100644 index 00000000000000..441c3b69f3da28 --- /dev/null +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.cc @@ -0,0 +1,190 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/batched_gather_scatter_normalizer.h" + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/types/span.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" + +namespace xla { + +namespace { +bool IsBatchGather(const HloGatherInstruction* gather) { + const auto& dims = gather->gather_dimension_numbers(); + return !dims.operand_batching_dims().empty(); +} + +bool IsBatchScatter(const HloScatterInstruction* scatter) { + const auto& dims = scatter->scatter_dimension_numbers(); + return !dims.input_batching_dims().empty(); +} + +// Update gather/scater indices by adding fake batching iota dimensions. +HloInstruction* CreateConcatIndices( + HloInstruction* inst, HloInstruction* indices, int64_t index_vector_dim, + absl::Span indices_batching_dims, + BatchedGatherScatterNormalizer* normalizer) { + const bool index_vector_dim_on_last_dim = + index_vector_dim == indices->shape().rank(); + + Shape iota_shape = indices->shape(); + if (index_vector_dim_on_last_dim) { + std::vector dimensions(iota_shape.dimensions().begin(), + iota_shape.dimensions().end()); + dimensions.push_back(1); + iota_shape = ShapeUtil::MakeShape(iota_shape.element_type(), dimensions); + } + iota_shape.set_dimensions(index_vector_dim, 1); + normalizer->UpdateLayout(&iota_shape); + + std::vector indices_to_concat; + for (int64_t indices_batching_dim : indices_batching_dims) { + indices_to_concat.push_back(inst->parent()->AddInstruction( + HloInstruction::CreateIota(iota_shape, indices_batching_dim))); + } + if (index_vector_dim_on_last_dim) { + std::vector dimensions(indices->shape().dimensions().begin(), + indices->shape().dimensions().end()); + dimensions.push_back(1); + Shape reshape_shape = + ShapeUtil::MakeShape(indices->shape().element_type(), dimensions); + normalizer->UpdateLayout(&reshape_shape); + HloInstruction* reshaped_indices = inst->AddInstruction( + HloInstruction::CreateReshape(reshape_shape, indices)); + indices_to_concat.push_back(reshaped_indices); + } else { + indices_to_concat.push_back(indices); + } + Shape concat_shape = iota_shape; + concat_shape.set_dimensions( + index_vector_dim, + indices_batching_dims.size() + + (index_vector_dim_on_last_dim + ? 1 + : indices->shape().dimensions(index_vector_dim))); + normalizer->UpdateLayout(&concat_shape); + return inst->AddInstruction(HloInstruction::CreateConcatenate( + concat_shape, indices_to_concat, index_vector_dim)); +} + +absl::StatusOr NormalizeBatchGather( + HloGatherInstruction* gather, BatchedGatherScatterNormalizer* normalizer) { + HloInstruction* gather_operand = gather->mutable_operand(0); + HloInstruction* gather_indices = gather->mutable_operand(1); + const auto& dims = gather->gather_dimension_numbers(); + CHECK_EQ(dims.operand_batching_dims_size(), + dims.start_indices_batching_dims_size()); + // Update start_index_map. + std::vector start_index_map(dims.operand_batching_dims().begin(), + dims.operand_batching_dims().end()); + absl::c_copy(dims.start_index_map(), std::back_inserter(start_index_map)); + gather_indices = + CreateConcatIndices(gather, gather_indices, dims.index_vector_dim(), + dims.start_indices_batching_dims(), normalizer); + // Update collapsed_slice_dims. + std::vector collapsed_slice_dims(dims.collapsed_slice_dims().begin(), + dims.collapsed_slice_dims().end()); + absl::c_copy(dims.operand_batching_dims(), + std::back_inserter(collapsed_slice_dims)); + absl::c_sort(collapsed_slice_dims); + + GatherDimensionNumbers updated_dims = + HloGatherInstruction::MakeGatherDimNumbers( + dims.offset_dims(), collapsed_slice_dims, start_index_map, + dims.index_vector_dim()); + return gather->AddInstruction(HloInstruction::CreateGather( + gather->shape(), gather_operand, gather_indices, updated_dims, + gather->gather_slice_sizes(), gather->indices_are_sorted())); +} + +absl::StatusOr NormalizeBatchScatter( + HloScatterInstruction* scatter, + BatchedGatherScatterNormalizer* normalizer) { + auto scatter_operands = scatter->scatter_operands(); + HloInstruction* scatter_indices = scatter->scatter_indices(); + auto scatter_updates = scatter->scatter_updates(); + const auto& dims = scatter->scatter_dimension_numbers(); + CHECK_EQ(dims.input_batching_dims_size(), + dims.scatter_indices_batching_dims_size()); + // Update scatter_dims_to_operand_dims. + std::vector scatter_dims_to_operand_dims( + dims.input_batching_dims().begin(), dims.input_batching_dims().end()); + absl::c_copy(dims.scatter_dims_to_operand_dims(), + std::back_inserter(scatter_dims_to_operand_dims)); + scatter_indices = + CreateConcatIndices(scatter, scatter_indices, dims.index_vector_dim(), + dims.scatter_indices_batching_dims(), normalizer); + // Update inserted_window_dims. + std::vector inserted_window_dims(dims.inserted_window_dims().begin(), + dims.inserted_window_dims().end()); + absl::c_copy(dims.input_batching_dims(), + std::back_inserter(inserted_window_dims)); + absl::c_sort(inserted_window_dims); + + ScatterDimensionNumbers updated_dims = + HloScatterInstruction::MakeScatterDimNumbers( + dims.update_window_dims(), inserted_window_dims, + scatter_dims_to_operand_dims, dims.index_vector_dim()); + return scatter->AddInstruction(HloInstruction::CreateScatter( + scatter->shape(), scatter_operands, scatter_indices, scatter_updates, + scatter->to_apply(), updated_dims, scatter->indices_are_sorted(), + scatter->unique_indices())); +} + +} // namespace + +absl::StatusOr +BatchedGatherScatterNormalizer::ExpandInstruction(HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kGather) { + auto* gather = DynCast(inst); + return NormalizeBatchGather(gather, this); + } + if (inst->opcode() == HloOpcode::kScatter) { + auto* scatter = DynCast(inst); + return NormalizeBatchScatter(scatter, this); + } + return absl::InvalidArgumentError(absl::StrFormat( + "Instruction: %s is not a batch gather or scatter.", inst->ToString())); +} + +bool BatchedGatherScatterNormalizer::InstructionMatchesPattern( + HloInstruction* inst) { + if (inst->opcode() == HloOpcode::kGather) { + auto* gather = DynCast(inst); + return IsBatchGather(gather); + } + if (inst->opcode() == HloOpcode::kScatter) { + auto* scatter = DynCast(inst); + return IsBatchScatter(scatter); + } + return false; +} + +} // namespace xla diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer.h b/third_party/xla/xla/service/batched_gather_scatter_normalizer.h new file mode 100644 index 00000000000000..4b5560d38dceec --- /dev/null +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer.h @@ -0,0 +1,42 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_BATCHED_GATHER_SCATTER_NORMALIZER_H_ +#define XLA_SERVICE_BATCHED_GATHER_SCATTER_NORMALIZER_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/op_expander_pass.h" + +namespace xla { + +// This pass rewrites normalize batch gather and scatter operations into a +// non-batch version. +class BatchedGatherScatterNormalizer : public OpExpanderPass { + public: + absl::string_view name() const override { + return "gather_scatter_normalizer"; + } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + absl::StatusOr ExpandInstruction( + HloInstruction* instruction) override; +}; + +} // namespace xla + +#endif // XLA_SERVICE_BATCHED_GATHER_SCATTER_NORMALIZER_H_ diff --git a/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc b/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc new file mode 100644 index 00000000000000..81f0882c977ca2 --- /dev/null +++ b/third_party/xla/xla/service/batched_gather_scatter_normalizer_test.cc @@ -0,0 +1,186 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/batched_gather_scatter_normalizer.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +class BatchedGatherScatterNormalizerTest : public HloTestBase {}; + +TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchGather) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0}} + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512], start_indices: s64[10,9,8,7,5,512]) -> f32[10,9,8,7,30,29,28,27,26,512] { + %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0) + %start_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} + gather(f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,5,512]{5,4,3,2,1,0} %start_indices), + offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, operand_batching_dims={5}, + start_indices_batching_dims={5}, index_vector_dim=4, slice_sizes={30,29,28,27,26,1} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA:.*]] = s64[10,9,8,7,1,512]{{.*}} iota() + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,6,512]{{.*}} concatenate(%[[IOTA]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[10,9,8,7,30,29,28,27,26,512]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={4,5,6,7,8}, + CHECK-SAME: collapsed_slice_dims={5}, + CHECK-SAME: start_index_map={5,0,1,2,3,4}, + CHECK-SAME: index_vector_dim=4, + CHECK-SAME: slice_sizes={30,29,28,27,26,1} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchGather2) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0}, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0}} + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512,1024,100], start_indices: s64[10,9,8,7,6,512,1024]) -> f32[10,9,8,7,30,29,28,27,26,512,1024] { + %input_tensor = f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} parameter(0) + %start_indices = s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0} + gather(f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} %start_indices), + offset_dims={4,5,6,7,8}, collapsed_slice_dims={7}, start_index_map={0,1,2,3,4,7}, operand_batching_dims={5,6}, + start_indices_batching_dims={5,6}, index_vector_dim=4, slice_sizes={30,29,28,27,26,1,1,1} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[10,9,8,7,1,512,1024]{{.*}} iota() + CHECK: %[[IOTA2:.*]] = s64[10,9,8,7,1,512,1024]{{.*}} iota() + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,8,512,1024]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %start_indices) + CHECK: ROOT %[[GATHER:.*]] = f32[10,9,8,7,30,29,28,27,26,512,1024]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={4,5,6,7,8}, + CHECK-SAME: collapsed_slice_dims={5,6,7}, + CHECK-SAME: start_index_map={5,6,0,1,2,3,4,7}, + CHECK-SAME: index_vector_dim=4, + CHECK-SAME: slice_sizes={30,29,28,27,26,1,1,1} + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchScatter) { + constexpr absl::string_view kModuleStr = R"( + +HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46,512]{5,4,3,2,1,0}} + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512], scatter_indices: s64[10,9,8,7,5,512], updates: f32[10,9,8,7,30,29,28,27,26,512]) -> f32[50,49,48,47,46,512] { + %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46,512]{5,4,3,2,1,0} scatter( + f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, + s64[10,9,8,7,5,512]{5,4,3,2,1,0} %scatter_indices, + f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} %updates), + update_window_dims={4,5,6,7,8}, inserted_window_dims={}, + scatter_dims_to_operand_dims={0,1,2,3,4}, input_batching_dims={5}, + scatter_indices_batching_dims={5}, index_vector_dim=4, to_apply=%add_F32.v3 +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA:.*]] = s64[10,9,8,7,1,512]{{.*}} iota() + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,6,512]{{.*}} concatenate(%[[IOTA]], %scatter_indices) + CHECK: ROOT %[[SCATTER:.*]] = f32[50,49,48,47,46,512]{{.*}} scatter( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]], %updates), + CHECK-SAME: update_window_dims={4,5,6,7,8}, + CHECK-SAME: inserted_window_dims={5}, + CHECK-SAME: scatter_dims_to_operand_dims={5,0,1,2,3,4}, + CHECK-SAME: index_vector_dim=4, + CHECK-SAME: to_apply=%add_F32.v3 + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, NormalizeBatchScatter2) { + constexpr absl::string_view kModuleStr = R"( + +HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0}, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0}} + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512,1024,100], scatter_indices: s64[10,9,8,7,6,512,1024], updates: f32[10,9,8,7,30,29,28,27,26,512,1024]) -> f32[50,49,48,47,46,512,1024,100] { + %input_tensor = f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} scatter( + f32[50,49,48,47,46,512,1024,100]{7,6,5,4,3,2,1,0} %input_tensor, + s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} %scatter_indices, + f32[10,9,8,7,30,29,28,27,26,512,1024]{10,9,8,7,6,5,4,3,2,1,0} %updates), + update_window_dims={4,5,6,7,8}, inserted_window_dims={7}, + scatter_dims_to_operand_dims={0,1,2,3,4,7}, input_batching_dims={5,6}, + scatter_indices_batching_dims={5,6}, index_vector_dim=4, to_apply=%add_F32.v3 +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[10,9,8,7,1,512,1024]{{.*}} iota() + CHECK: %[[IOTA2:.*]] = s64[10,9,8,7,1,512,1024]{{.*}} iota() + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,8,512,1024]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %scatter_indices) + CHECK: ROOT %[[SCATTER:.*]] = f32[50,49,48,47,46,512,1024,100]{{.*}} scatter( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]], %updates), + CHECK-SAME: update_window_dims={4,5,6,7,8}, + CHECK-SAME: inserted_window_dims={5,6,7}, + CHECK-SAME: scatter_dims_to_operand_dims={5,6,0,1,2,3,4,7}, + CHECK-SAME: index_vector_dim=4, + CHECK-SAME: to_apply=%add_F32.v3 + )"); +} + +TEST_F(BatchedGatherScatterNormalizerTest, IndexVectorDimOnLastDim) { + constexpr absl::string_view kModuleStr = R"( +HloModule StringifyGather, entry_computation_layout={(f32[50,512,1024]{2,1,0}, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0})->f32[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0}} + +ENTRY %Gather (input_tensor: f32[50,512,1024], start_indices: s64[10,9,8,7,6,512,1024]) -> f32[10,9,8,7,6,512,1024] { + %input_tensor = f32[50,512,1024]{2,1,0} parameter(0) + %start_indices = s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} + gather(f32[50,512,1024]{2,1,0} %input_tensor, s64[10,9,8,7,6,512,1024]{6,5,4,3,2,1,0} %start_indices), + offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, operand_batching_dims={1,2}, + start_indices_batching_dims={5,6}, index_vector_dim=7, slice_sizes={1,1,1} +})"; + + RunAndFilecheckHloRewrite(kModuleStr, BatchedGatherScatterNormalizer(), R"( + CHECK: %[[IOTA1:.*]] = s64[10,9,8,7,6,512,1024,1]{{.*}} iota() + CHECK: %[[IOTA2:.*]] = s64[10,9,8,7,6,512,1024,1]{{.*}} iota() + CHECK: %[[RESHAPE:.*]] = s64[10,9,8,7,6,512,1024,1]{{.*}} reshape(%start_indices) + CHECK: %[[INDICES_CONCAT:.*]] = s64[10,9,8,7,6,512,1024,3]{{.*}} concatenate(%[[IOTA1]], %[[IOTA2]], %[[RESHAPE]]) + CHECK: ROOT %[[GATHER:.*]] = f32[10,9,8,7,6,512,1024]{{.*}} gather( + CHECK-SAME: %input_tensor, %[[INDICES_CONCAT]]), + CHECK-SAME: offset_dims={}, + CHECK-SAME: collapsed_slice_dims={0,1,2}, + CHECK-SAME: start_index_map={1,2,0}, + CHECK-SAME: index_vector_dim=7, + CHECK-SAME: slice_sizes={1,1,1} + )"); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc b/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc index b145e8ceb7b5fe..a5dc3b882446cc 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc +++ b/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/buffer_assignment_test.cc b/third_party/xla/xla/service/buffer_assignment_test.cc index a11b86ca357043..04238c4fd39f5a 100644 --- a/third_party/xla/xla/service/buffer_assignment_test.cc +++ b/third_party/xla/xla/service/buffer_assignment_test.cc @@ -46,9 +46,9 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/call_graph.cc b/third_party/xla/xla/service/call_graph.cc index ea16ca0c57f7e7..80515e13ea7515 100644 --- a/third_party/xla/xla/service/call_graph.cc +++ b/third_party/xla/xla/service/call_graph.cc @@ -214,8 +214,8 @@ CallContext UnionContexts(CallContext a, CallContext b) { } else if (a == b) { return a; } else { - // Contexts are different and neither is kNone, ie one is kSequential and - // the other is kParallel. + // Contexts are different and neither is kNone, i.e. one is kControlFlow and + // the other is kEmbedded. return CallContext::kBoth; } } diff --git a/third_party/xla/xla/service/call_graph.h b/third_party/xla/xla/service/call_graph.h index c6f933ef1a250d..0d15a64cafd144 100644 --- a/third_party/xla/xla/service/call_graph.h +++ b/third_party/xla/xla/service/call_graph.h @@ -141,7 +141,7 @@ class CallGraphNode { CallGraphNode(const CallGraphNode&) = delete; CallGraphNode& operator=(const CallGraphNode&) = delete; CallGraphNode(CallGraphNode&&) = default; - CallGraphNode& operator=(CallGraphNode&&) = default; + CallGraphNode& operator=(CallGraphNode&&) noexcept = default; private: // Only CallGraph can modify CallGraphNode. diff --git a/third_party/xla/xla/service/call_graph_test.cc b/third_party/xla/xla/service/call_graph_test.cc index dfa7d28f06ab1d..a619cd5ffe6e28 100644 --- a/third_party/xla/xla/service/call_graph_test.cc +++ b/third_party/xla/xla/service/call_graph_test.cc @@ -36,9 +36,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/call_inliner.cc b/third_party/xla/xla/service/call_inliner.cc index 579de41179ff76..0605fbd6457ff7 100644 --- a/third_party/xla/xla/service/call_inliner.cc +++ b/third_party/xla/xla/service/call_inliner.cc @@ -139,6 +139,16 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { CallInliner::Inline(HloInstruction* call) { TF_RET_CHECK(call->opcode() == HloOpcode::kCall) << "Instruction was not a call op: " << call->opcode(); + if (call->is_composite()) { + // Remove composite FE attrs before inlining, else they will appear on the + // inlined instructions. + FrontendAttributes frontend_attributes = call->frontend_attributes(); + frontend_attributes.mutable_map()->erase("composite.name"); + frontend_attributes.mutable_map()->erase("composite.attributes"); + frontend_attributes.mutable_map()->erase("composite.version"); + call->set_frontend_attributes(frontend_attributes); + } + const auto& callees = call->called_computations(); TF_RET_CHECK(callees.size() == 1); HloComputation* callee = callees[0]; diff --git a/third_party/xla/xla/service/call_inliner_test.cc b/third_party/xla/xla/service/call_inliner_test.cc index 4248c012444803..ad6ee73eb14e8a 100644 --- a/third_party/xla/xla/service/call_inliner_test.cc +++ b/third_party/xla/xla/service/call_inliner_test.cc @@ -15,25 +15,25 @@ limitations under the License. #include "xla/service/call_inliner.h" +#include #include -#include #include -#include -#include +#include "absl/log/log.h" +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/layout_util.h" #include "xla/literal.h" -#include "xla/service/hlo_pass_fix.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "xla/types.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace op = xla::testing::opcode_matchers; @@ -346,5 +346,36 @@ ENTRY %main_outer (p0: u32[]) -> u32[] { op::Constant(LiteralUtil::CreateR0(2)))); } } + +TEST_F(CallInlinerTest, InlineCompositeCall) { + const absl::string_view hlo_string = R"( + HloModule composite + + %add (lhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] constant(2) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) + } + + ENTRY %main () -> f32[] { + %lhs = f32[] constant(42) + ROOT %call = f32[] call(f32[] %lhs), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + })"; + + auto module = ParseAndReturnVerifiedModule(hlo_string).value(); + CallInliner call_inliner(/*single_call_site=*/true); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + ASSERT_TRUE(mutated); + + ASSERT_EQ(module->entry_computation()->instruction_count(), 3); + auto inst = module->entry_computation()->instructions().begin(); + EXPECT_THAT(*inst, op::Constant()); + ++inst; + EXPECT_THAT(*inst, op::Constant()); + ++inst; + EXPECT_THAT(*inst, op::Add()); + EXPECT_TRUE((*inst)->frontend_attributes().map().empty()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/change_op_data_type.cc b/third_party/xla/xla/service/change_op_data_type.cc index 7c907d80c84e25..3c7875a2836ceb 100644 --- a/third_party/xla/xla/service/change_op_data_type.cc +++ b/third_party/xla/xla/service/change_op_data_type.cc @@ -19,8 +19,7 @@ limitations under the License. #include "xla/service/hlo_creation_utils.h" #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) -#include "xla/service/cpu/onednn_convolution_rewriter.h" -#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_contraction_rewriter.h" #endif // INTEL_MKL && ENABLE_ONEDNN_V3 namespace xla { @@ -65,11 +64,11 @@ absl::StatusOr ChangeOpDataType::Run( } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) if (instr->opcode() == HloOpcode::kDot && - cpu::OneDnnMatMulRewriter::ShouldRewrite(instr)) { + cpu::OneDnnContractionRewriter::ShouldRewriteDot(instr, true)) { continue; } if (instr->opcode() == HloOpcode::kConvolution && - cpu::OneDnnConvolutionRewriter::ShouldRewrite(instr)) { + cpu::OneDnnContractionRewriter::ShouldRewriteConv(instr)) { continue; } #endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/service/collective_ops_utils.cc b/third_party/xla/xla/service/collective_ops_utils.cc index 11095f87bee210..e3949569386bf3 100644 --- a/third_party/xla/xla/service/collective_ops_utils.cc +++ b/third_party/xla/xla/service/collective_ops_utils.cc @@ -582,7 +582,7 @@ bool ReplicaGroupsEqual(absl::Span first, return true; } -bool IsCollective(const HloInstruction* instruction) { +bool IsNonFusionCollective(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kAllReduce: case HloOpcode::kAllReduceStart: @@ -597,24 +597,30 @@ bool IsCollective(const HloInstruction* instruction) { case HloOpcode::kCollectivePermuteDone: case HloOpcode::kReduceScatter: return true; - case HloOpcode::kFusion: - if (instruction->IsCustomFusion()) { - for (const auto* inner_inst : instruction->fused_instructions()) { - if (IsCollective(inner_inst)) { - return true; - } - } - } - return false; case HloOpcode::kAsyncStart: case HloOpcode::kAsyncUpdate: case HloOpcode::kAsyncDone: - return IsCollective(instruction->async_wrapped_instruction()); + return IsNonFusionCollective(instruction->async_wrapped_instruction()); default: return false; } } +bool IsCollective(const HloInstruction* instruction) { + if (IsNonFusionCollective(instruction)) { + return true; + } + if (instruction->opcode() == HloOpcode::kFusion && + instruction->IsCustomFusion()) { + for (const auto* inner_inst : instruction->fused_instructions()) { + if (IsCollective(inner_inst)) { + return true; + } + } + } + return false; +} + HloInstruction* IsOrHasCollectiveWithChannelId(HloInstruction* instruction) { if (instruction->opcode() == HloOpcode::kFusion) { for (auto* inner_inst : instruction->fused_instructions()) { diff --git a/third_party/xla/xla/service/collective_ops_utils.h b/third_party/xla/xla/service/collective_ops_utils.h index c611d57a6e6264..3c2ebd3d523da0 100644 --- a/third_party/xla/xla/service/collective_ops_utils.h +++ b/third_party/xla/xla/service/collective_ops_utils.h @@ -196,6 +196,10 @@ inline constexpr absl::string_view kNopCustomCallTarget = "AllocateBuffer"; inline constexpr absl::string_view kNopReturnTokenCustomCallTarget = "NopReturnToken"; +// Returns true if instruction is a collective op that is not a collective +// fusion. +bool IsNonFusionCollective(const HloInstruction* instruction); + // Returns true if instruction is a collective op or a collective fusion. bool IsCollective(const HloInstruction* instruction); diff --git a/third_party/xla/xla/service/collective_ops_utils_test.cc b/third_party/xla/xla/service/collective_ops_utils_test.cc index 64ec33866d2b32..f1a7ab1f4561f0 100644 --- a/third_party/xla/xla/service/collective_ops_utils_test.cc +++ b/third_party/xla/xla/service/collective_ops_utils_test.cc @@ -32,8 +32,8 @@ limitations under the License. #include "xla/service/global_device_id.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/collective_permute_decomposer_test.cc b/third_party/xla/xla/service/collective_permute_decomposer_test.cc index b80a52b51e9f1a..eac5ab0707418a 100644 --- a/third_party/xla/xla/service/collective_permute_decomposer_test.cc +++ b/third_party/xla/xla/service/collective_permute_decomposer_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/utils/hlo_matchers.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/hlo_parser.h" @@ -105,7 +106,7 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) { EXPECT_THAT( recv->ToString(), HasSubstr( - "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); check_metadata(recv); check_not_pipelined(recv); HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); @@ -117,7 +118,7 @@ TEST_F(CollectivePermuteDecomposerTest, TransformedExplicitChannelId) { EXPECT_THAT( send->ToString(), HasSubstr( - "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); check_metadata(send); check_not_pipelined(send); HloInstruction* send_done = FindInstruction(module.get(), "send-done"); @@ -211,7 +212,7 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { EXPECT_THAT( recv->ToString(), HasSubstr( - "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); EXPECT_THAT(recv->ToString(), HasSubstr("_xla_other_attribute=\"xyz\"")); HloInstruction* recv_done = FindInstruction(module.get(), "recv-done"); @@ -223,7 +224,7 @@ TEST_F(CollectivePermuteDecomposerTest, Pipeline1) { EXPECT_THAT( send->ToString(), HasSubstr( - "_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3},{3,4}}\"")); + "_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}")); EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); EXPECT_THAT(send->ToString(), HasSubstr("_xla_other_attribute=\"xyz\"")); HloInstruction* send_done = FindInstruction(module.get(), "send-done"); @@ -289,18 +290,18 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { HloInstruction* recv = FindInstruction(module.get(), "recv"); EXPECT_EQ(recv->channel_id().value(), 1); EXPECT_THAT(recv->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs=\"{{3,0}}\"")); + HasSubstr("_xla_send_recv_source_target_pairs={{3,0}}")); EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); HloInstruction* send = FindInstruction(module.get(), "send"); EXPECT_THAT(send->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs=\"{{3,0}}\"")); + HasSubstr("_xla_send_recv_source_target_pairs={{3,0}}")); EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); HloInstruction* recv1 = FindInstruction(module.get(), "recv.1"); EXPECT_EQ(recv1->channel_id().value(), 2); EXPECT_THAT( recv1->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3}}\"")); + HasSubstr("_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}")); EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); HloInstruction* recv_done1 = FindInstruction(module.get(), "recv-done.1"); EXPECT_THAT(recv_done1->ToString(), @@ -308,13 +309,130 @@ TEST_F(CollectivePermuteDecomposerTest, ForwardPipeline2) { HloInstruction* send1 = FindInstruction(module.get(), "send.1"); EXPECT_THAT( send1->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,1},{1,2},{2,3}}\"")); + HasSubstr("_xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3}}")); EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); HloInstruction* send_done1 = FindInstruction(module.get(), "send-done.1"); EXPECT_THAT(send_done1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); } +TEST_F(CollectivePermuteDecomposerTest, ForwardPipelineWithMatmul) { + // The HLO module below is generated by passing the HLO in + // CollectiveOpsTest.CollectivePermute_CircularPipelinePreOptimization through + // the collective_permute_cycle_decomposer.transformation. + const char* const kModuleStr = R"( + HloModule test + + while_body { + inputs = (u32[], f32[2,2], f32[2,2]) parameter(0) + iter = u32[] get-tuple-element(inputs), index=0 + iter_increment = u32[] constant(1) + next_iter = u32[] add(iter, iter_increment) + partition-id = u32[] partition-id() + zero = u32[] constant(0) + compare = pred[] compare(partition-id, zero), direction=EQ + broadcast = pred[2,2] broadcast(compare), dimensions={} + + weights = f32[2,2] get-tuple-element(inputs), index=2 + data = f32[2,2] get-tuple-element(inputs), index=1 + + cp_back = f32[2,2] collective-permute(data), channel_id=1, + source_target_pairs={{3,0}}, + frontend_attributes={_xla_send_recv_validation="{{3,10}}"} + cp_forward = f32[2,2] collective-permute(data), channel_id=2, + source_target_pairs={{0,1},{1,2},{2,3}}, + frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9}}"} + + select = f32[2,2] select(broadcast, cp_back, cp_forward) + + matmul = f32[2,2] dot(weights, select), lhs_contracting_dims={1}, rhs_contracting_dims={0} + + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights) + } + + while_cond { + inputs = (u32[], f32[2,2], f32[2,2]) parameter(0) + iter = u32[] get-tuple-element(inputs), index=0 + max_iter = u32[] constant(3) + ROOT compare = pred[] compare(iter, max_iter), direction=LT + } + + ENTRY test_computation { + start_iter = u32[] constant(0) + input_data = f32[2,2] parameter(0) + input_weights = f32[2,2] parameter(1) + input = (u32[], f32[2,2], f32[2,2]) tuple(start_iter, input_data, input_weights) + while_result = (u32[], f32[2,2], f32[2,2]) while(input), condition=while_cond, body=while_body + ROOT data_out = f32[2,2] get-tuple-element(while_result), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnUnverifiedModule((kModuleStr))); + CollectivePermuteDecomposer decomposer(/*threshold_in_bytes=*/0); + TF_ASSERT_OK_AND_ASSIGN(bool changed, decomposer.Run(module.get())); + EXPECT_TRUE(changed); + HloModule* transformed_module = module.get(); + // Check the annotations and ordering of the decomposed send-recv pairs. + // We expect the recv to come before the send in the while body, both for the + // forward edge ({0,1},{1,2},{2,3}}) and the backward edge ({3,0}). This is + // an XLA invariant that shouldn't be broken (see + // https://openxla.org/xla/operation_semantics#send for details of the + // semantics). + HloInstruction* recv_bwd = FindInstruction(transformed_module, "recv"); + EXPECT_EQ(recv_bwd->channel_id().value(), 1); + auto recv_bwd_frontend_attributes = recv_bwd->frontend_attributes().map(); + EXPECT_EQ(recv_bwd_frontend_attributes.size(), 3); + EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvValidationAttr), + "{{3,10}}"); + EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvPipelineAttr), "0"); + EXPECT_EQ(recv_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), + "{{3,0}}"); + + HloInstruction* send_bwd = FindInstruction(transformed_module, "send"); + auto send_bwd_frontend_attributes = send_bwd->frontend_attributes().map(); + EXPECT_THAT(send_bwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), + "{{3,0}}"); + + HloInstruction* recv_fwd = FindInstruction(transformed_module, "recv.1"); + EXPECT_EQ(recv_fwd->channel_id().value(), 2); + auto recv_fwd_frontend_attributes = recv_fwd->frontend_attributes().map(); + EXPECT_EQ(recv_fwd_frontend_attributes.size(), 3); + EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1"); + EXPECT_EQ(recv_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), + "{{0,1},{1,2},{2,3}}"); + + HloInstruction* send_fwd = FindInstruction(transformed_module, "send.1"); + auto send_fwd_frontend_attributes = send_fwd->frontend_attributes().map(); + EXPECT_EQ(send_fwd_frontend_attributes.size(), 3); + EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvPipelineAttr), "1"); + EXPECT_EQ(send_fwd_frontend_attributes.at(kSendRecvSourceTargetPairsAttr), + "{{0,1},{1,2},{2,3}}"); + + HloComputation* while_body = + FindComputation(transformed_module, "while_body"); + EXPECT_NE(while_body, nullptr); + EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv", "send")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "recv", "recv-done")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "send", "recv-done")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "send", "send-done")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "send-done", "send-done.1")); + EXPECT_TRUE( + hlo_query::IsBeforeInComputation(while_body, "recv-done", "send-done.1")); + EXPECT_TRUE(hlo_query::IsBeforeInComputation(while_body, "recv-done.1", + "send-done.1")); + auto recv_done_fwd = FindInstruction(transformed_module, "recv-done"); + auto recv_done_bwd = FindInstruction(transformed_module, "recv-done.1"); + + // TODO: b/356201477 - Investigate potential NCCL deadlock in + // collective_permute_decomposer + EXPECT_EQ(recv_done_fwd->control_predecessors()[0], send_bwd); + EXPECT_EQ(recv_done_bwd->control_predecessors()[0], send_fwd); +} + TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { const char* const kModuleStr = R"( HloModule module @@ -371,22 +489,22 @@ TEST_F(CollectivePermuteDecomposerTest, BackwardPipeline2) { EXPECT_EQ(recv->channel_id().value(), 1); EXPECT_THAT( recv->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs=\"{{1,0},{2,1},{3,2}}\"")); + HasSubstr("_xla_send_recv_source_target_pairs={{1,0},{2,1},{3,2}}")); EXPECT_THAT(recv->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); HloInstruction* send = FindInstruction(module.get(), "send"); EXPECT_THAT( send->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs=\"{{1,0},{2,1},{3,2}}\"")); + HasSubstr("_xla_send_recv_source_target_pairs={{1,0},{2,1},{3,2}}")); EXPECT_THAT(send->ToString(), HasSubstr("_xla_send_recv_pipeline=\"1\"")); HloInstruction* recv1 = FindInstruction(module.get(), "recv.1"); EXPECT_EQ(recv1->channel_id().value(), 2); EXPECT_THAT(recv1->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,3}}\"")); + HasSubstr("_xla_send_recv_source_target_pairs={{0,3}}")); EXPECT_THAT(recv1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); HloInstruction* send1 = FindInstruction(module.get(), "send.1"); EXPECT_THAT(send1->ToString(), - HasSubstr("_xla_send_recv_source_target_pairs=\"{{0,3}}\"")); + HasSubstr("_xla_send_recv_source_target_pairs={{0,3}}")); EXPECT_THAT(send1->ToString(), HasSubstr("_xla_send_recv_pipeline=\"0\"")); } diff --git a/third_party/xla/xla/service/collective_pipeliner.cc b/third_party/xla/xla/service/collective_pipeliner.cc index 859b6c9b2540c2..232dae4ec7718d 100644 --- a/third_party/xla/xla/service/collective_pipeliner.cc +++ b/third_party/xla/xla/service/collective_pipeliner.cc @@ -50,10 +50,12 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/map_util.h" #include "xla/primitive_util.h" +#include "xla/service/call_graph.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/constant_value.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_parser.h" +#include "xla/service/tuple_points_to_analysis.h" #include "xla/service/value_range.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -445,7 +447,6 @@ std::vector CollectDependenciesToPipeline( ops.end()); formatting_set.insert(source_ops.begin(), source_ops.end()); std::vector to_return; - absl::flat_hash_set already_inserted; for (const HloInstruction* op : ops) { for (HloInstruction* operand : op->operands()) { if (!formatting_set.count(operand)) { @@ -697,10 +698,13 @@ class WhileLoopAnalysis { explicit WhileLoopAnalysis( HloInstruction* while_instr, int64_t max_pipelining_per_loop, bool pipeline_use_tree, bool process_different_sized_options, + TuplePointsToAnalysis* tuple_points_to_analysis, CallGraph* call_graph, std::optional known_start = std::nullopt) : while_(while_instr), loop_start_(known_start), max_pipelining_per_loop_(max_pipelining_per_loop), + tuple_points_to_analysis_(tuple_points_to_analysis), + call_graph_(call_graph), pipeline_use_tree_(pipeline_use_tree), process_different_sized_options_(process_different_sized_options) {} std::optional GetLoopIterationCount() const; @@ -796,6 +800,14 @@ class WhileLoopAnalysis { absl::flat_hash_set invariant_loop_parameters_; absl::flat_hash_set invariant_loop_instructions_; int64_t max_pipelining_per_loop_; + + // Precomputed TuplePointsToAnalysis for the HLO module containing `while_`. + // May be null, in which case the analysis will be performed from scratch. + TuplePointsToAnalysis* tuple_points_to_analysis_; + // Precomputed CallGraph analysis for the HLO module containing `while_`. + // May be null, in which case the analysis will be performed from scratch. + CallGraph* call_graph_; + bool pipeline_use_tree_; bool process_different_sized_options_; }; @@ -834,8 +846,8 @@ bool WhileLoopAnalysis::ComputeLoopStatistics() { if (loop_iteration_count_) { return true; } - std::optional parsed_loop = - PatternMatchParseWhileLoop(while_); + std::optional parsed_loop = PatternMatchParseWhileLoop( + while_, {tuple_points_to_analysis_, call_graph_}); if (!parsed_loop || !parsed_loop->static_while_loop) { return false; } @@ -1380,7 +1392,6 @@ bool IsLoopInvariant( // to still visit before visiting the HLO itself. std::vector> stack( 1, std::make_pair(instr, 0)); - absl::flat_hash_set visited; while (!stack.empty()) { auto& current = stack.back(); invariant_cache[std::get<0>(current)] = true; @@ -1796,6 +1807,8 @@ absl::Status TransformLoopForward( WhileLoopAnalysis new_loop_analysis( new_while_loop, loop_analysis.GetMaxPipeliningPerLoop(), pipeline_use_tree, process_different_sized_ops, + /*tuple_points_to_analysis=*/nullptr, + /*call_graph=*/nullptr, loop_analysis.GetLoopStart()->add(*loop_analysis.GetLoopIncrement())); new_loop_analysis.ComputeLoopStatistics(); new_loop_analysis.CollectCollectivesToMove( @@ -2035,6 +2048,17 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, << "Expected only one parameter"; HloInstruction* loop_parameter = while_body->parameter_instructions()[0]; HloInstruction* loop_init = while_loop->mutable_operand(0); + + // Clean up the SunkByPreviousStep custom calls that were inserted before. + for (HloInstruction* inst : while_body->root_instruction()->operands()) { + if (inst->opcode() == HloOpcode::kDynamicUpdateSlice && + inst->operand(1)->IsCustomCall( + CollectivePipeliner::kSunkByPreviousStep)) { + HloInstruction* cc = inst->mutable_operand(1); + TF_RETURN_IF_ERROR(inst->ReplaceOperandWith(1, cc->mutable_operand(0))); + TF_RETURN_IF_ERROR(cc->parent()->RemoveInstruction(cc)); + } + } CHECK_EQ(while_body->root_instruction()->opcode(), HloOpcode::kTuple); for (int i = 0; i < while_body->root_instruction()->operand_count(); ++i) { is_output_instruction[while_body->root_instruction()->mutable_operand(i)] = @@ -2125,7 +2149,6 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, new_parameter_shapes.push_back(expanded_shape); new_init_operands.push_back(CreateZero(loop_computation, expanded_shape, expanded_shape.element_type())); - indices_to_insert.insert(new_root_operands.size()); Shape extra_trivial_dim_shape = ShapeUtil::PrependMajorDimension(1, pipelined->shape()); HloInstruction* reshaped = body_computation->AddInstruction( @@ -2255,8 +2278,7 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, TF_RETURN_IF_ERROR(output->ReplaceOperandWith(0, new_param)); TF_RETURN_IF_ERROR( old_operand_param->parent()->RemoveInstruction(old_operand_param)); - // TODO(sacer): Consider relaxing this to all inserted operands. - if (insert_non_alias_custom_call && original_to_move_indices.contains(i)) { + if (insert_non_alias_custom_call && indices_to_insert.contains(i)) { auto* old_operand = output->mutable_operand(1); auto* custom_call = cloned_body->AddInstruction(HloInstruction::CreateCustomCall( @@ -2491,17 +2513,6 @@ absl::Status TransformLoopForwardSink(const WhileLoopAnalysis& loop_analysis, pipelined_map[formatting_op] = expanded_transpose; continue; } - if (formatting_op->IsCustomCall( - CollectivePipeliner::kSunkByPreviousStep)) { - HloInstruction* expanded_custom_call = - loop_computation->AddInstruction(HloInstruction::CreateCustomCall( - ComputeFullOutputShape(to_move, formatting_op->shape()), - collect_operands(formatting_op), - /*custom_call_target=*/ - CollectivePipeliner::kSunkByPreviousStep)); - pipelined_map[formatting_op] = expanded_custom_call; - continue; - } CHECK(false) << "Unsupported instruction " << formatting_op->ToString(); } for (int64_t i = 0; i < to_move.output_indices.size(); ++i) { @@ -2775,8 +2786,6 @@ static absl::Status TransformLoopBackward( instruction, false, CollectivePipeliner::PipeliningDirection::kBackward, loop_analysis)); } - absl::flat_hash_map - loop_cond_replacements; auto cond_builder = HloComputation::Builder(while_loop->while_condition()->name()); HloInstruction* new_cond_param = @@ -2878,12 +2887,38 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; - std::vector while_loop_instructions; + + // Precompute module-scoped analyses. Because we are running a while-loop + // analysis over all while instructions in the module, computing them here and + // passing them in avoids recomputing them once for each while instruction. + TF_ASSIGN_OR_RETURN( + std::unique_ptr tuple_points_to_analysis, + TuplePointsToAnalysis::Run(module)); + std::unique_ptr call_graph = CallGraph::Build(module); + + std::vector>> + loop_analyses; for (HloComputation* computation : module->MakeComputationPostOrder()) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { - if (instruction->opcode() == HloOpcode::kWhile) { - while_loop_instructions.push_back(instruction); + if (instruction->opcode() != HloOpcode::kWhile) { + continue; + } + if (std::none_of(instruction->while_body()->instructions().begin(), + instruction->while_body()->instructions().end(), + config_.should_process)) { + continue; + } + VLOG(1) << "Pipelinable while: " << instruction->name(); + auto loop_analysis = std::make_unique( + instruction, config_.max_pipelining_per_loop, + config_.pipeline_use_tree, config_.process_different_sized_ops, + tuple_points_to_analysis.get(), call_graph.get()); + loop_analysis->ComputeLoopStatistics(); + if (loop_analysis->GetLoopIterationCount() && + loop_analysis->GetLoopIterationCount()->GetUnsignedValue() > 0) { + loop_analyses.push_back( + std::make_pair(instruction, std::move(loop_analysis))); } } } @@ -2892,32 +2927,23 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( int64_t next_channel_id = hlo_query::NextChannelId(*module); VLOG(1) << "Pipelining on direction: " << GetPipelineDirectionString(config_.pipelining_direction); - for (HloInstruction* instruction : while_loop_instructions) { - VLOG(1) << "While: " << instruction->name(); - WhileLoopAnalysis loop_analysis( - instruction, config_.max_pipelining_per_loop, config_.pipeline_use_tree, - config_.process_different_sized_ops); - loop_analysis.ComputeLoopStatistics(); - if (!loop_analysis.GetLoopIterationCount() || - loop_analysis.GetLoopIterationCount()->GetUnsignedValue() == 0) { - continue; - } + for (auto& [instruction, loop_analysis] : loop_analyses) { VLOG(1) << "While iterations: " - << loop_analysis.GetLoopIterationCount()->ToString(); - loop_analysis.CollectCollectivesToMove( + << loop_analysis->GetLoopIterationCount()->ToString(); + loop_analysis->CollectCollectivesToMove( config_.level_to_operate_on, config_.pipelining_direction, config_.should_process, config_.acceptable_formatting, config_.should_allow_loop_variant_parameter_in_chain, config_.should_allow_control_dependencies, config_.should_add_loop_invariant_op_in_chain); - if (loop_analysis.GetMoveInfos().empty()) { + if (loop_analysis->GetMoveInfos().empty()) { continue; } - transformed_instructions += loop_analysis.GetMoveInfos().size(); + transformed_instructions += loop_analysis->GetMoveInfos().size(); VLOG(1) << "Found Collectives to optimize"; if (VLOG_IS_ON(1)) { int64_t id = 0; - for (auto& to_move : loop_analysis.GetMoveInfos()) { + for (auto& to_move : loop_analysis->GetMoveInfos()) { VLOG(1) << "Move info id: " << id++ << " with " << to_move.collectives_to_move.size() << " collectives " << to_move.dynamic_update_slices.size() @@ -2937,20 +2963,20 @@ absl::StatusOr CollectivePipeliner::RunPipeliner( if (config_.pipelining_direction == PipeliningDirection::kForward) { CHECK(config_.reuse_pipelined_op_buffer); TF_RETURN_IF_ERROR(TransformLoopForward( - loop_analysis, !config_.last_run, config_.level_to_operate_on, + *loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.pipeline_use_tree, config_.process_different_sized_ops, config_.should_process, config_.acceptable_formatting, config_.reuse_pipelined_op_buffer, next_channel_id)); } else if (config_.pipelining_direction == PipeliningDirection::kForwardSink) { TF_RETURN_IF_ERROR(TransformLoopForwardSink( - loop_analysis, !config_.last_run, config_.level_to_operate_on, + *loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.pipeline_use_tree, config_.process_different_sized_ops, config_.should_process, next_channel_id)); } else { CHECK_EQ(config_.pipelining_direction, PipeliningDirection::kBackward); TF_RETURN_IF_ERROR(TransformLoopBackward( - loop_analysis, !config_.last_run, config_.level_to_operate_on, + *loop_analysis, !config_.last_run, config_.level_to_operate_on, config_.process_different_sized_ops, config_.should_process, config_.acceptable_formatting, config_.postprocess_backward_peeled_op, config_.postprocess_backward_rotated_op, next_channel_id)); diff --git a/third_party/xla/xla/service/collective_pipeliner_test.cc b/third_party/xla/xla/service/collective_pipeliner_test.cc index 5492cbc582458d..53529e822bf72f 100644 --- a/third_party/xla/xla/service/collective_pipeliner_test.cc +++ b/third_party/xla/xla/service/collective_pipeliner_test.cc @@ -239,14 +239,14 @@ ENTRY entry { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: HloModule // CHECK: %while_body - // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{0,5},{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}{{[}]}}" + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{0,5},{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12}{{[}]}} // CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]]) // CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}}) // CHECK: } // CHECK: ENTRY %entry - // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}" + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}} // CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]]) // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) @@ -315,14 +315,14 @@ ENTRY entry { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: HloModule // CHECK: %while_body - // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6},{0,5}{{[}]}}" + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6},{0,5}{{[}]}} // CHECK: %[[dus:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[dus]], {{.+}} %[[dus]]) // CHECK: %[[dus2:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus2]], {{.+}}) // CHECK: } // CHECK: ENTRY %entry - // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}" + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}{_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}} // CHECK: %[[ds:.+]] = {{.+}} dynamic-slice({{.+}} %[[cp]], {{.+}}) // CHECK: %[[mul:.+]] = {{.+}} multiply({{.+}} %[[ds]], {{.+}} %[[ds]]) // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[mul]], {{.+}}) @@ -1507,13 +1507,13 @@ ENTRY entry { XLA_VLOG_LINES(1, module->ToString()); EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %while_body - // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,12}{{[}]}}"} + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8},{3,9},{4,10},{5,11},{6,12},{7,12}{{[}]}}} // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}}) // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) // CHECK: ENTRY %entry // CHECK: %[[while:.+]] = {{.+}} while({{.+}}) // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}" + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}} // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}}) // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1 @@ -1586,13 +1586,13 @@ ENTRY entry { XLA_VLOG_LINES(1, module->ToString()); EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %while_body - // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation="{{[{]}}{7,12},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}{{[}]}}"} + // CHECK: %[[cp:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}_xla_send_recv_validation={{[{]}}{7,12},{6,12},{5,11},{4,10},{3,9},{2,8},{1,7},{0,6}{{[}]}}} // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp]], {{.+}}) // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) // CHECK: ENTRY %entry // CHECK: %[[while:.+]] = {{.+}} while({{.+}}) // CHECK: %[[gte:.+]] = {{.+}} get-tuple-element({{.+}} %[[while]]), index=1 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}" + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[gte]]), {{.+}}_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}} // CHECK: %[[dus:.+]] = {{.+}} dynamic-update-slice({{.+}} %[[cp2]], {{.+}}) // CHECK: %[[tuple:.+]] = {{.+}} tuple({{.+}} %[[dus]], {{.+}}) // CHECK: ROOT {{.+}} = {{.+}} get-tuple-element({{.+}} %[[tuple]]), index=1 @@ -2426,7 +2426,7 @@ TEST_F(CollectivePipelinerTest, EXPECT_EQ(recv1->channel_id(), send1->channel_id()); - const char* kSourceTarget = "_xla_send_recv_source_target_pairs=\"{{3,0}}\""; + const char* kSourceTarget = "_xla_send_recv_source_target_pairs={{3,0}}"; const char* kPeeledAttr = "_xla_other_attr=\"1\""; const char* kRotatedAttr = "_xla_other_attr=\"2\""; EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kSourceTarget)); @@ -3083,6 +3083,16 @@ ENTRY entry { const HloInstruction* all_reduce2 = find_all_reduce(all_reduce1); EXPECT_NE(all_reduce2, nullptr); EXPECT_THAT(all_reduce2, op::AllReduce(op::GetTupleElement(op::While()))); + // The root of while body should have a dynamic-update-slice operand which has + // a custom call at operand index 1. + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + CHECK_NE(while_instr, nullptr); + const HloInstruction* dynamic_update_slice = + while_instr->while_body()->root_instruction()->operands().back(); + CHECK_EQ(dynamic_update_slice->opcode(), HloOpcode::kDynamicUpdateSlice); + const HloInstruction* custom_call = dynamic_update_slice->operand(1); + CHECK(custom_call->IsCustomCall("SunkByPreviousStep")); } TEST_F(CollectivePipelinerTest, ForwardSinkFirstDimNotMatchingLoopCount) { @@ -3375,6 +3385,7 @@ ENTRY entry { XLA_VLOG_LINES(1, module->ToString()); const HloInstruction* while_instr = FindInstruction(module.get(), HloOpcode::kWhile); + CHECK_NE(while_instr, nullptr); EXPECT_TRUE( absl::c_any_of(while_instr->users(), [](const HloInstruction* user) { return absl::c_any_of( @@ -3394,6 +3405,13 @@ ENTRY entry { return operand->opcode() == HloOpcode::kReshape; }), 2); + // The root of while body should have a dynamic-update-slice operand which has + // a custom call at operand index 1. + const HloInstruction* dynamic_update_slice = + while_instr->while_body()->root_instruction()->operand(4); + CHECK_EQ(dynamic_update_slice->opcode(), HloOpcode::kDynamicUpdateSlice); + const HloInstruction* custom_call = dynamic_update_slice->operand(1); + CHECK(custom_call->IsCustomCall("SunkByPreviousStep")); } TEST_F(CollectivePipelinerTest, CollectiveWithMultipleDUSSameBuffer) { @@ -3670,6 +3688,22 @@ ENTRY entry { op::Reshape(op::Multiply()), op::Reshape(op::Divide()), op::Reshape(op::Abs()), op::GetTupleElement(op::While()), op::GetTupleElement(op::While())))); + // The root of while body should have two dynamic-update-slice operands each + // of which has a custom call at operand index 1. + std::function is_dus_with_custom_call = + [&](const HloInstruction* inst) -> bool { + if (inst->opcode() != HloOpcode::kDynamicUpdateSlice) { + return false; + } + return inst->operand(1)->IsCustomCall("SunkByPreviousStep"); + }; + const HloInstruction* while_instr = + FindInstruction(module.get(), HloOpcode::kWhile); + CHECK_NE(while_instr, nullptr); + CHECK(is_dus_with_custom_call( + while_instr->while_body()->root_instruction()->operand(7))); + CHECK(is_dus_with_custom_call( + while_instr->while_body()->root_instruction()->operand(8))); } } // namespace diff --git a/third_party/xla/xla/service/collective_quantizer_test.cc b/third_party/xla/xla/service/collective_quantizer_test.cc index a095e3ef4e19a1..fff673e4707b7a 100644 --- a/third_party/xla/xla/service/collective_quantizer_test.cc +++ b/third_party/xla/xla/service/collective_quantizer_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/collective_transformation_reorderer_test.cc b/third_party/xla/xla/service/collective_transformation_reorderer_test.cc index 3721406e64901a..73f185e1caf73f 100644 --- a/third_party/xla/xla/service/collective_transformation_reorderer_test.cc +++ b/third_party/xla/xla/service/collective_transformation_reorderer_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/compilation_environments_test.cc b/third_party/xla/xla/service/compilation_environments_test.cc index b3cd2946cf06f4..35058aefd45994 100644 --- a/third_party/xla/xla/service/compilation_environments_test.cc +++ b/third_party/xla/xla/service/compilation_environments_test.cc @@ -21,8 +21,8 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/service/test_compilation_environment.pb.h" #include "xla/test.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/casts.h" #include "tsl/platform/protobuf.h" diff --git a/third_party/xla/xla/service/compiler_test.cc b/third_party/xla/xla/service/compiler_test.cc index c2743c15aff889..951330e94d375e 100644 --- a/third_party/xla/xla/service/compiler_test.cc +++ b/third_party/xla/xla/service/compiler_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_init.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/conditional_canonicalizer_test.cc b/third_party/xla/xla/service/conditional_canonicalizer_test.cc index 3d5e1e976da0d1..beba61a5a67832 100644 --- a/third_party/xla/xla/service/conditional_canonicalizer_test.cc +++ b/third_party/xla/xla/service/conditional_canonicalizer_test.cc @@ -25,9 +25,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/conditional_code_motion.cc b/third_party/xla/xla/service/conditional_code_motion.cc index c7ef8e609df40b..cd22c9d4ca0c0a 100644 --- a/third_party/xla/xla/service/conditional_code_motion.cc +++ b/third_party/xla/xla/service/conditional_code_motion.cc @@ -1005,6 +1005,7 @@ class MoveOperandIntoBranch { CHECK_NE(new_tuple, nullptr); VLOG(5) << "Cloned new tuple:" << new_tuple->parent()->ToString() << "\n"; std::vector> gte_users; + gte_users.reserve(branch_param->shape().tuple_shapes_size()); for (int64_t j = 0; j < branch_param->shape().tuple_shapes_size(); ++j) { gte_users.push_back(std::vector()); } diff --git a/third_party/xla/xla/service/conditional_code_motion_test.cc b/third_party/xla/xla/service/conditional_code_motion_test.cc index fcfe91d7a21dfa..0a3d74327dd522 100644 --- a/third_party/xla/xla/service/conditional_code_motion_test.cc +++ b/third_party/xla/xla/service/conditional_code_motion_test.cc @@ -29,9 +29,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace xla { diff --git a/third_party/xla/xla/service/conditional_simplifier_test.cc b/third_party/xla/xla/service/conditional_simplifier_test.cc index 083ef03453d67f..24a7c0a68045b0 100644 --- a/third_party/xla/xla/service/conditional_simplifier_test.cc +++ b/third_party/xla/xla/service/conditional_simplifier_test.cc @@ -26,9 +26,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" namespace xla { diff --git a/third_party/xla/xla/service/convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/service/convert_async_collectives_to_sync_test.cc index c155c2ff21397f..a404f03e5301cf 100644 --- a/third_party/xla/xla/service/convert_async_collectives_to_sync_test.cc +++ b/third_party/xla/xla/service/convert_async_collectives_to_sync_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/utils/hlo_matchers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index 5f045b5fc75aeb..39b5b70d51de30 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -99,7 +99,7 @@ filegroup( "runtime_matmul_f64.cc", "runtime_matmul_s32.cc", "runtime_fork_join.cc", - "//xla/service/cpu/runtime:runtime_srcs", + "//xla/backends/cpu/runtime:runtime_srcs", #"runtime_handle_ffi_call.cc", # TODO(b/338344732): Add "runtime_handle_ffi_call.cc". ], visibility = internal_visibility([":friends"]), @@ -127,7 +127,7 @@ filegroup( "runtime_fork_join.h", "runtime_lightweight_check.h", "runtime_matmul.h", - "//xla/service/cpu/runtime:runtime_hdrs", + "//xla/backends/cpu/runtime:runtime_hdrs", #"runtime_handle_ffi_call.h", # TODO(b/338344732): Add "runtime_handle_ffi_call.h" ], visibility = internal_visibility([":friends"]), @@ -222,8 +222,7 @@ cc_library( ":ir_emission_utils", ":ir_emitter", ":ir_emitter2", - ":onednn_convolution_rewriter", - ":onednn_matmul_rewriter", + ":onednn_contraction_rewriter", ":onednn_ops_rewriter", ":parallel_task_assignment", ":simple_orc_jit", @@ -240,6 +239,7 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla:xla_proto_cc", + "//xla/backends/cpu/runtime:thunk", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", "//xla/mlir_hlo", @@ -327,8 +327,8 @@ cc_library( "//xla/service:while_loop_constant_sinking", "//xla/service:while_loop_invariant_code_motion", "//xla/service:while_loop_simplifier", + "//xla/service:while_loop_trip_count_annotator", "//xla/service:zero_sized_hlo_elimination", - "//xla/service/cpu/runtime:thunk", "//xla/service/llvm_ir:llvm_command_line_options", "//xla/service/llvm_ir:llvm_util", "//xla/service/spmd:stateful_rng_spmd_partitioner", @@ -557,6 +557,9 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/backends/cpu/runtime:buffer_allocations", + "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:thunk_executor", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:computation_layout", @@ -570,9 +573,6 @@ cc_library( "//xla/service:maybe_owning_device_memory", "//xla/service:shaped_buffer", "//xla/service:xla_debug_info_manager", - "//xla/service/cpu/runtime:buffer_allocations", - "//xla/service/cpu/runtime:thunk", - "//xla/service/cpu/runtime:thunk_executor", "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor/host:host_kernel_c_api", @@ -588,6 +588,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@llvm-project//llvm:Core", "@llvm-project//llvm:ExecutionEngine", "@llvm-project//llvm:OrcJIT", "@llvm-project//llvm:OrcShared", @@ -662,6 +663,28 @@ cc_library( ], ) +xla_cc_test( + name = "ir_emitter_test", + srcs = ["ir_emitter_test.cc"], + deps = [ + ":ir_emitter", + ":ir_function", + ":target_machine_features_fake", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:hlo_module_config", + "//xla/service:hlo_ordering", + "//xla/service:hlo_parser", + "//xla/service:logical_buffer", + "//xla/tests:hlo_test_base", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "ir_emitter2_test", srcs = ["ir_emitter2_test.cc"], @@ -681,12 +704,10 @@ xla_cc_test( "//xla/service/llvm_ir:llvm_util", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Core", - "@llvm-project//llvm:OrcJIT", - "@llvm-project//llvm:Support", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -828,33 +849,33 @@ cc_library( "//xla:shape_util", "//xla:status_macros", "//xla:util", + "//xla/backends/cpu/runtime:all_gather_thunk", + "//xla/backends/cpu/runtime:all_reduce_thunk", + "//xla/backends/cpu/runtime:all_to_all_thunk", + "//xla/backends/cpu/runtime:call_thunk", + "//xla/backends/cpu/runtime:collective_permute_thunk", + "//xla/backends/cpu/runtime:collective_thunk", + "//xla/backends/cpu/runtime:conditional_thunk", + "//xla/backends/cpu/runtime:convolution_thunk", + "//xla/backends/cpu/runtime:copy_thunk", + "//xla/backends/cpu/runtime:custom_call_thunk", + "//xla/backends/cpu/runtime:dot_thunk", + "//xla/backends/cpu/runtime:fft_thunk", + "//xla/backends/cpu/runtime:infeed_thunk", + "//xla/backends/cpu/runtime:kernel_thunk", + "//xla/backends/cpu/runtime:logical_id_thunk", + "//xla/backends/cpu/runtime:outfeed_thunk", + "//xla/backends/cpu/runtime:reduce_scatter_thunk", + "//xla/backends/cpu/runtime:resource_use", + "//xla/backends/cpu/runtime:rng_state_thunk", + "//xla/backends/cpu/runtime:sort_thunk", + "//xla/backends/cpu/runtime:thunk", + "//xla/backends/cpu/runtime:topk_thunk", + "//xla/backends/cpu/runtime:while_thunk", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/service:collective_ops_utils", "//xla/service:hlo_module_config", - "//xla/service/cpu/runtime:all_gather_thunk", - "//xla/service/cpu/runtime:all_reduce_thunk", - "//xla/service/cpu/runtime:all_to_all_thunk", - "//xla/service/cpu/runtime:call_thunk", - "//xla/service/cpu/runtime:collective_permute_thunk", - "//xla/service/cpu/runtime:collective_thunk", - "//xla/service/cpu/runtime:conditional_thunk", - "//xla/service/cpu/runtime:convolution_thunk", - "//xla/service/cpu/runtime:copy_thunk", - "//xla/service/cpu/runtime:custom_call_thunk", - "//xla/service/cpu/runtime:dot_thunk", - "//xla/service/cpu/runtime:fft_thunk", - "//xla/service/cpu/runtime:infeed_thunk", - "//xla/service/cpu/runtime:kernel_thunk", - "//xla/service/cpu/runtime:logical_id_thunk", - "//xla/service/cpu/runtime:outfeed_thunk", - "//xla/service/cpu/runtime:reduce_scatter_thunk", - "//xla/service/cpu/runtime:resource_use", - "//xla/service/cpu/runtime:rng_state_thunk", - "//xla/service/cpu/runtime:sort_thunk", - "//xla/service/cpu/runtime:thunk", - "//xla/service/cpu/runtime:topk_thunk", - "//xla/service/cpu/runtime:while_thunk", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", @@ -1043,7 +1064,7 @@ cc_library( deps = [ ":runtime_lightweight_check", "//xla:executable_run_options", - "//xla/service/cpu/runtime:conv_impl", + "//xla/backends/cpu/runtime:convolution_thunk_internal", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1061,7 +1082,7 @@ cc_library( deps = [ ":runtime_lightweight_check", "//xla:executable_run_options", - "//xla/service/cpu/runtime:conv_impl", + "//xla/backends/cpu/runtime:convolution_thunk_internal", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1189,7 +1210,7 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ - "//xla/service/cpu/runtime:conv_impl", + "//xla/backends/cpu/runtime:convolution_thunk_internal", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1205,7 +1226,7 @@ cc_library( copts = runtime_copts(), visibility = ["//visibility:public"], deps = [ - "//xla/service/cpu/runtime:conv_impl", + "//xla/backends/cpu/runtime:convolution_thunk_internal", "//xla/tsl/framework/contraction:eigen_contraction_kernel", "//xla/tsl/framework/convolution:eigen_helpers", "@com_google_absl//absl/base:dynamic_annotations", @@ -1359,6 +1380,7 @@ xla_cc_test( ":runtime_matmul_acl", ":runtime_single_threaded_matmul", "//xla:array2d", + "//xla:executable_run_options", "//xla:types", "//xla/client:local_client", "//xla/service:custom_call_status_internal", @@ -1398,7 +1420,7 @@ xla_cc_test( ":cpu_runtime", "//xla:shape_util", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", @@ -1584,8 +1606,8 @@ xla_cc_test( "//xla/service/cpu:target_machine_features", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -1848,17 +1870,19 @@ cc_library( ) cc_library( - name = "onednn_matmul_rewriter", - srcs = ["onednn_matmul_rewriter.cc"], + name = "onednn_contraction_rewriter", + srcs = ["onednn_contraction_rewriter.cc"], hdrs = [ + "onednn_contraction_rewriter.h", + "onednn_convolution.h", "onednn_matmul.h", - "onednn_matmul_rewriter.h", "//xla/tsl/util:onednn_util_hdrs", ], copts = tsl_copts(), deps = [ ":backend_config_proto_cc", ":onednn_config_proto_cc", + ":onednn_convolution", ":onednn_matmul", ":onednn_memory_util", ":onednn_pattern_utils", @@ -1874,6 +1898,7 @@ cc_library( "//xla/service:hlo_pass", "//xla/service:pattern_matcher", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", @@ -1906,44 +1931,13 @@ cc_library( ] + mkl_deps(), ) -cc_library( - name = "onednn_convolution_rewriter", - srcs = ["onednn_convolution_rewriter.cc"], - hdrs = ["onednn_convolution_rewriter.h"], - copts = tsl_copts(), - deps = [ - ":backend_config_proto_cc", - ":onednn_config_proto_cc", - ":onednn_convolution", - ":onednn_memory_util", - ":onednn_util", - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:status_macros", - "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", - ] + mkl_deps(), -) - cc_library( name = "cpu_float_support", srcs = ["cpu_float_support.cc"], hdrs = ["cpu_float_support.h"], copts = tsl_copts(), deps = [ - ":onednn_convolution_rewriter", - ":onednn_matmul_rewriter", + ":onednn_contraction_rewriter", "//xla/service:float_support", ], ) diff --git a/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc index 94418f2ab82aee..bf06650e96173f 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/elementwise_benchmark_test.cc @@ -52,13 +52,63 @@ static void BM_AddF32(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); } -BENCHMARK(BM_AddF32) - ->MeasureProcessCPUTime() - ->Arg(128) - ->Arg(256) - ->Arg(512) - ->Arg(1024) - ->Arg(8192) - ->Arg(16384); +static void BM_AddBF16(benchmark::State& state) { + int64_t d0 = state.range(0); + + std::string_view hlo = R"( + HloModule add_bf16_$d0 + + ENTRY e { + p0 = bf16[1,2,1,$d0,256] parameter(0) + p1 = bf16[1,2,1,$d0,256] parameter(1) + ROOT add = bf16[1,2,1,$d0,256] add(p0, p1) + } + )"; + + std::minstd_rand0 engine; + + auto shape = ShapeUtil::MakeShape(BF16, {1, 2, 1, d0, 256}); + auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + auto p1 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + + std::vector args = {&p0, &p1}; + CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); +} + +static void BM_ConvertF32ToBF16(benchmark::State& state) { + int64_t d0 = state.range(0); + + std::string_view hlo = R"( + HloModule convert_f32_to_bf16_$d0 + + ENTRY e { + p0 = f32[1,2,1,$d0,256] parameter(0) + ROOT convert = bf16[1,2,1,$d0,256] convert(p0) + } + )"; + + std::minstd_rand0 engine; + + auto shape = ShapeUtil::MakeShape(F32, {1, 2, 1, d0, 256}); + auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + + std::vector args = {&p0}; + CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); +} + +#define BENCHMARK_SIZES(NAME) \ + BENCHMARK(NAME) \ + ->MeasureProcessCPUTime() \ + ->Arg(128) \ + ->Arg(256) \ + ->Arg(512) \ + ->Arg(1024) \ + ->Arg(8192) \ + ->Arg(16384) \ + ->Arg(32768) + +BENCHMARK_SIZES(BM_AddF32); +BENCHMARK_SIZES(BM_AddBF16); +BENCHMARK_SIZES(BM_ConvertF32ToBF16); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc b/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc index 1ec04bc4cae4d8..c5399e93c8d7cd 100644 --- a/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc +++ b/third_party/xla/xla/service/cpu/benchmarks/reduction_benchmark_test.cc @@ -57,13 +57,45 @@ static void BM_ReduceAddF32(benchmark::State& state) { CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); } -BENCHMARK(BM_ReduceAddF32) - ->MeasureProcessCPUTime() - ->Arg(128) - ->Arg(256) - ->Arg(512) - ->Arg(1024) - ->Arg(8192) - ->Arg(16384); +static void BM_ReduceAddBF16(benchmark::State& state) { + int64_t d0 = state.range(0); + + std::string_view hlo = R"( + HloModule reduce_add_bf16_$d0 + + add { + p0 = bf16[] parameter(0) + p1 = bf16[] parameter(1) + ROOT add = bf16[] add(p0, p1) + } + + ENTRY e { + p0 = bf16[1,2,1,$d0,256] parameter(0) + c0 = bf16[] constant(0) + ROOT reduce = bf16[1,2] reduce(p0, c0), dimensions={2,3,4}, to_apply=add + } + )"; + + std::minstd_rand0 engine; + + auto shape = ShapeUtil::MakeShape(BF16, {1, 2, 1, d0, 256}); + auto p0 = *LiteralUtil::CreateRandomLiteral(shape, &engine, 1.0f, 0.1f); + + std::vector args = {&p0}; + CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}})); +} + +#define BENCHMARK_SIZES(NAME) \ + BENCHMARK(NAME) \ + ->MeasureProcessCPUTime() \ + ->Arg(128) \ + ->Arg(256) \ + ->Arg(512) \ + ->Arg(1024) \ + ->Arg(8192) \ + ->Arg(16384) + +BENCHMARK_SIZES(BM_ReduceAddF32); +BENCHMARK_SIZES(BM_ReduceAddBF16); } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 7254f2b1380f03..13a47eb2b36b61 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -68,6 +68,7 @@ limitations under the License. #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Transforms/DialectConversion.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -113,7 +114,6 @@ limitations under the License. #include "xla/service/cpu/ir_emitter.h" #include "xla/service/cpu/ir_emitter2.h" #include "xla/service/cpu/parallel_task_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/cpu/target_machine_features.h" #include "xla/service/cpu/thunk_emitter.h" @@ -180,6 +180,7 @@ limitations under the License. #include "xla/service/while_loop_constant_sinking.h" #include "xla/service/while_loop_invariant_code_motion.h" #include "xla/service/while_loop_simplifier.h" +#include "xla/service/while_loop_trip_count_annotator.h" #include "xla/service/zero_sized_hlo_elimination.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -204,8 +205,7 @@ limitations under the License. #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) #include "xla/service/cpu/cpu_float_support.h" -#include "xla/service/cpu/onednn_convolution_rewriter.h" -#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_contraction_rewriter.h" #include "xla/service/cpu/onednn_ops_rewriter.h" #include "xla/service/simplify_fp_conversions.h" #endif @@ -447,7 +447,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( spmd_pipeline.AddPass(); spmd_pipeline.AddPass(); spmd_pipeline.AddPass(); - if (module->config().debug_options().xla_use_shardy()) { + if (module->config().use_shardy_partitioner()) { spmd_pipeline.AddPass(); } else { spmd_pipeline.AddPass( @@ -520,7 +520,8 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( // Rewrite to custom calls with target as oneDNN library calls. #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) // AOT compiled code runs in single thread. - if (!is_aot_compile) { + bool is_thunk_runtime = debug_options.xla_cpu_use_thunk_runtime(); + if (!is_aot_compile && !is_thunk_runtime) { // Placing OneDnnOpsRewriter here to match the flax patterns // TODO: Decide where would be the appropriate place for this pass to make // it more generic @@ -540,7 +541,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( FloatSupport bf16_support(BF16); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) CpuFloatSupport onednn_bf16_support(BF16); - if (!is_aot_compile) { + if (!is_aot_compile && !is_thunk_runtime) { pipeline.AddPass(&onednn_bf16_support); } else { pipeline.AddPass(&bf16_support); @@ -685,6 +686,10 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); + // Annotate while loops with statically known trip counts, so that at run time + // we can avoid running the loop condition computations. + pipeline.AddPass(); + // Layout assignment uses alias analysis, which requires the call graph to be // flattened. pipeline.AddPass(); @@ -741,8 +746,11 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( : tsl::port::NumSchedulableCPUs(); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + auto& debug_options = module->config().debug_options(); + bool is_thunk_runtime = debug_options.xla_cpu_use_thunk_runtime(); + // AOT compiled code runs in single thread. - if (!is_aot_compile) { + if (!is_aot_compile && !is_thunk_runtime) { auto debug_options = module->config().debug_options(); // Run SimplifyFPConversions pass to simplify the BF16 pattern and make it // easier to match. @@ -750,11 +758,10 @@ absl::Status CpuCompiler::RunHloPassesAfterLayoutAssn( if (debug_options.xla_allow_excess_precision()) { pipeline.AddPass(); } - pipeline.AddPass(); - pipeline.AddPass(max_parallelism, - compile_options.thread_pool); + pipeline.AddPass(max_parallelism, + compile_options.thread_pool); // Run SimplifyFPConversions pass again to remove redundant Convert ops - // that may exist as a result of running OneDnnMatMulRewriter pass. + // that may exist as a result of running OneDnnContractionRewriter pass. if (debug_options.xla_allow_excess_precision()) { pipeline.AddPass(); } @@ -1264,11 +1271,17 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { cantFail((*jit)->AddModule(llvm::orc::ThreadSafeModule( std::move(llvm_module), std::move(llvm_context)))); + auto mangle = [&](std::string_view name) { + llvm::SmallVector mangled; + llvm::Mangler::getNameWithPrefix(mangled, name, (*jit)->data_layout()); + return std::string(mangled.begin(), mangled.end()); + }; + // TODO(ezhulenev): We should be able to make it lazy on-demand, but today // we capture obj_files by reference and it leads to asan errors. Figure out // lifetime issues and move compilation to Thunk initialization stage. for (const auto& kernel : ir_emitter2.kernels()) { - if (auto sym = (*jit)->FindCompiledSymbol(kernel.name); !sym) { + if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) { return Internal("Failed to find compiled symbol for kernel %s", kernel.name); } @@ -1276,7 +1289,7 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr module) { // Compile auxiliary comparator functions used by sort thunks. for (const auto& comparator : ir_emitter2.comparators()) { - if (auto sym = (*jit)->FindCompiledSymbol(comparator.name); !sym) { + if (auto s = (*jit)->FindCompiledSymbol(mangle(comparator.name)); !s) { return Internal("Failed to find compiled symbol for comparator %s", comparator.name); } @@ -1776,16 +1789,22 @@ CpuExecutableAotCompilationResult::LoadExecutable( TF_ASSIGN_OR_RETURN(ThunkSequence thunks, thunk_emitter.EmitEntryComputation(*module)); + auto mangle = [&](std::string_view name) { + llvm::SmallVector mangled; + llvm::Mangler::getNameWithPrefix(mangled, name, (*jit)->data_layout()); + return std::string(mangled.begin(), mangled.end()); + }; + // Lookup all kernel functions by name in the loaded object file. for (const auto& kernel : ir_emitter2.kernels()) { - if (auto sym = (*jit)->FindCompiledSymbol(kernel.name); !sym) { + if (auto s = (*jit)->FindCompiledSymbol(mangle(kernel.name)); !s) { return Internal("Failed to find compiled symbol for kernel %s", kernel.name); } } for (const auto& comparator : ir_emitter2.comparators()) { - if (auto sym = (*jit)->FindCompiledSymbol(comparator.name); !sym) { + if (auto s = (*jit)->FindCompiledSymbol(mangle(comparator.name)); !s) { return Internal("Failed to find compiled symbol for comparator %s", comparator.name); } diff --git a/third_party/xla/xla/service/cpu/cpu_executable.cc b/third_party/xla/xla/service/cpu/cpu_executable.cc index a37a40e9d19acd..e1f4b213170651 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.cc +++ b/third_party/xla/xla/service/cpu/cpu_executable.cc @@ -37,8 +37,13 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "unsupported/Eigen/CXX11/Tensor" +#include "llvm/ADT/SmallVector.h" #include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" +#include "llvm/IR/Mangler.h" #include "llvm/Support/Error.h" +#include "xla/backends/cpu/runtime/buffer_allocations.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" @@ -46,9 +51,6 @@ limitations under the License. #include "xla/literal.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/cpu_runtime.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" @@ -80,12 +82,18 @@ using FunctionRegistry = CpuExecutable::FunctionRegistry; FunctionRegistry::FunctionRegistry(SimpleOrcJIT* jit) : jit_(jit) {} +std::string FunctionRegistry::Mangle(std::string_view name) { + llvm::SmallVector mangled; + llvm::Mangler::getNameWithPrefix(mangled, name, jit_->data_layout()); + return std::string(mangled.begin(), mangled.end()); +} + absl::StatusOr FunctionRegistry::FindKernel( std::string_view name) { VLOG(3) << "Find host kernel with a name " << name; llvm::Expected sym = - jit_->FindCompiledSymbol(std::string(name)); + jit_->FindCompiledSymbol(Mangle(name)); if (!sym) { return absl::InvalidArgumentError( absl::StrCat("Can't resolve host kernel with a name ", name, @@ -99,7 +107,7 @@ absl::StatusOr FunctionRegistry::FindComparator( VLOG(3) << "Find comparator with a name " << name; llvm::Expected sym = - jit_->FindCompiledSymbol(std::string(name)); + jit_->FindCompiledSymbol(Mangle(name)); if (!sym) { return absl::InvalidArgumentError( absl::StrCat("Can't resolve comparator with a name ", name, @@ -175,6 +183,7 @@ absl::StatusOr> CpuExecutable::Create( std::move(hlo_profile_index_map), std::move(assignment))); executable->jit_ = std::move(jit); + executable->jit_->DoneCompiling(); executable->function_registry_ = FunctionRegistry(executable->jit_.get()); TF_ASSIGN_OR_RETURN(executable->thunks_, @@ -387,8 +396,7 @@ absl::Status CpuExecutable::ExecuteThunks( Thunk::ExecuteParams execute_params = { &*function_registry_, &allocations, - runtime::GetXfeedManager( - run_options->stream()->parent()->device_ordinal()), + runtime::GetXfeedManager(runtime::GetDeviceOrdinal(run_options)), run_options->intra_op_thread_pool(), &task_runner, &collective_execute_params, diff --git a/third_party/xla/xla/service/cpu/cpu_executable.h b/third_party/xla/xla/service/cpu/cpu_executable.h index b129674add7c72..2c2aa248bcbe5d 100644 --- a/third_party/xla/xla/service/cpu/cpu_executable.h +++ b/third_party/xla/xla/service/cpu/cpu_executable.h @@ -28,13 +28,13 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/thunk_executor.h" #include "xla/executable_run_options.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/thunk_executor.h" #include "xla/service/cpu/simple_orc_jit.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_status_internal.h" @@ -151,6 +151,8 @@ class CpuExecutable : public Executable { absl::StatusOr FindComparator(std::string_view name) final; private: + std::string Mangle(std::string_view name); + SimpleOrcJIT* jit_; }; diff --git a/third_party/xla/xla/service/cpu/cpu_float_support.cc b/third_party/xla/xla/service/cpu/cpu_float_support.cc index 6914e656b900a6..c5907168ebf1c1 100644 --- a/third_party/xla/xla/service/cpu/cpu_float_support.cc +++ b/third_party/xla/xla/service/cpu/cpu_float_support.cc @@ -17,8 +17,7 @@ limitations under the License. #include "xla/service/cpu/cpu_float_support.h" -#include "xla/service/cpu/onednn_convolution_rewriter.h" -#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_contraction_rewriter.h" namespace xla { namespace cpu { @@ -28,10 +27,10 @@ bool CpuFloatSupport::IsSupported(const HloInstruction& hlo) const { // oneDNN rewritable ops case HloOpcode::kDot: return LowPrecisionType() == BF16 && - OneDnnMatMulRewriter::ShouldRewrite(&hlo); + OneDnnContractionRewriter::ShouldRewriteDot(&hlo, true); case HloOpcode::kConvolution: return LowPrecisionType() == BF16 && - OneDnnConvolutionRewriter::ShouldRewrite(&hlo); + OneDnnContractionRewriter::ShouldRewriteConv(&hlo); // Collective ops. case HloOpcode::kAllGather: case HloOpcode::kAllReduce: diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.cc b/third_party/xla/xla/service/cpu/cpu_runtime.cc index f3ac32c04904bc..4e209e61f283c6 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime.cc @@ -73,6 +73,17 @@ XfeedManager* GetXfeedManager(int device_ordinal) { return it->second; } +// TODO(zhangqiaorjc): Prefer to make callers set and use device_ordinal +// directly since callers may not have a Stream*. +int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options) { + if (!run_options) { + return 0; + } else if (run_options->device_ordinal() != -1) { + return run_options->device_ordinal(); + } + return run_options->stream()->parent()->device_ordinal(); +} + extern const char* const kEigenMatMulF16SymbolName = "__xla_cpu_runtime_EigenMatMulF16"; extern const char* const kEigenMatMulF32SymbolName = @@ -198,17 +209,6 @@ std::string ShapeString(const void* shape_ptr, int32_t shape_length) { return ""; } -// TODO(zhangqiaorjc): Prefer to make callers set and use device_ordinal -// directly since callers may not have a Stream*. -int GetDeviceOrdinal(const ExecutableRunOptions* run_options) { - if (!run_options) { - return 0; - } else if (run_options->device_ordinal() != -1) { - return run_options->device_ordinal(); - } - return run_options->stream()->parent()->device_ordinal(); -} - ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void* AcquireInfeedBufferForDequeueImpl(const ExecutableRunOptions* run_options, int32_t buffer_length, diff --git a/third_party/xla/xla/service/cpu/cpu_runtime.h b/third_party/xla/xla/service/cpu/cpu_runtime.h index c40a84caf8aced..92beff43a3c0ea 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime.h +++ b/third_party/xla/xla/service/cpu/cpu_runtime.h @@ -103,6 +103,8 @@ extern const char* const kXlaCpuRuntimeSymbolNamePrefix; // `device_ordinal`. Note the device ordinal does not name a CPU XfeedManager* GetXfeedManager(int device_ordinal); +int GetDeviceOrdinal(const xla::ExecutableRunOptions* run_options); + } // namespace runtime } // namespace cpu } // namespace xla diff --git a/third_party/xla/xla/service/cpu/cpu_runtime_test.cc b/third_party/xla/xla/service/cpu/cpu_runtime_test.cc index 78bbc8f661e311..4e4d6aa5a909b2 100644 --- a/third_party/xla/xla/service/cpu/cpu_runtime_test.cc +++ b/third_party/xla/xla/service/cpu/cpu_runtime_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "unsupported/Eigen/CXX11/Tensor" #include "xla/array2d.h" #include "xla/client/local_client.h" +#include "xla/executable_run_options.h" #include "xla/service/cpu/runtime_custom_call_status.h" #include "xla/service/cpu/runtime_matmul.h" #include "xla/service/cpu/runtime_matmul_acl.h" @@ -180,5 +181,29 @@ TEST_F(CpuRuntimeTest, FailureStatus) { ASSERT_FALSE(__xla_cpu_runtime_StatusIsSuccess(&success_status)); } +// When run_options is null, the process should not crash and the device ordinal +// should be 0. +TEST_F(CpuRuntimeTest, GetDeviceOrdinalWhenRunOptionsEmpty) { + EXPECT_EQ(cpu::runtime::GetDeviceOrdinal(/*run_options=*/nullptr), 0); +} + +// When the device ordinal is set directly in run options, it should be returned +// (and NOT the value from stream). +TEST_F(CpuRuntimeTest, GetDeviceOrdinalWhenSetInRunOptions) { + // GetDeviceOrdinal implementation bases on the fact that device ordinal is + // -1 by default. So we need to assert for that here to avoid crash in case + // the default value changes in the future. + ExecutableRunOptions run_options; + ASSERT_EQ(run_options.device_ordinal(), -1); + + // Actual test - set device ordinal in run options and check that it is + // returned. + run_options.set_device_ordinal(3); + EXPECT_EQ(cpu::runtime::GetDeviceOrdinal(&run_options), 3); +} + +// TODO(abanas): Add test case for the device ordinal with stream case. It +// requires mocking the stream and stream executor. + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/cpu/executable.proto b/third_party/xla/xla/service/cpu/executable.proto index bca8a2cc2c64e4..d222660d0f0c35 100644 --- a/third_party/xla/xla/service/cpu/executable.proto +++ b/third_party/xla/xla/service/cpu/executable.proto @@ -17,7 +17,6 @@ syntax = "proto3"; package xla.cpu; -import "xla/service/cpu/xla_framework.proto"; import "xla/service/hlo.proto"; import "xla/xla.proto"; diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc index 13ede90a9192ab..e043b5c2e13bec 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter.cc @@ -113,7 +113,9 @@ class IrEmitter::CpuElementalIrEmitter : public ElementalIrEmitter { public: CpuElementalIrEmitter(const HloModuleConfig& module_config, IrEmitter* ir_emitter, llvm::Module* module) - : ElementalIrEmitter(module, ir_emitter->b()), + : ElementalIrEmitter( + module, ir_emitter->b(), + Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/true}), hlo_module_config_(module_config), ir_emitter_(ir_emitter) {} @@ -451,6 +453,18 @@ void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, } } +void IrEmitter::AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load) const { + AttachInvariantLoadMetadataForLoad(load, hlo_module_config_); +} + +/*static*/ void IrEmitter::AttachInvariantLoadMetadataForLoad( + llvm::LoadInst* load, const HloModuleConfig& config) { + if (config.debug_options().xla_llvm_enable_invariant_load_metadata()) { + load->setMetadata(llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(load->getContext(), /*MDs=*/{})); + } +} + absl::Status IrEmitter::HandleGetTupleElement( HloInstruction* get_tuple_element) { // A tuple is an array of pointers, one for each operand. Each pointer points @@ -2643,6 +2657,22 @@ absl::Status IrEmitter::HandleTopK(HloInstruction* hlo) { } #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + +// Emits operands alloca vector for oneDNN custom calls. +std::vector IrEmitter::EmitOneDnnOperandsAlloca( + HloInstruction* custom_call, llvm::Value*& args_val, int& arg_indx) { + std::vector operands_stack_alloca; + const int num_operands = custom_call->operand_count(); + operands_stack_alloca.reserve(num_operands); + for (int i = 0; i < num_operands; ++i) { + llvm_ir::IrArray ir_array(GetIrArrayFor(custom_call->operand(i))); + StackAlloca stack_alloca = GetAllocaAndEmitMemrefInfo(*b(), ir_array); + args_val = b()->CreateInsertValue(args_val, stack_alloca.value, arg_indx++); + operands_stack_alloca.push_back(std::move(stack_alloca)); + } + return operands_stack_alloca; +} + absl::Status IrEmitter::HandleOneDnnMatMulCalls( HloInstruction* custom_call, std::string runtime_symbol_name) { // We would like to emit LLVM IR for the following function call @@ -2684,7 +2714,6 @@ absl::Status IrEmitter::HandleOneDnnMatMulCalls( args_val = b()->CreateInsertValue(args_val, run_opts_val, arg_indx++); // Insert OneDnnMatMulConfig. - auto typed_custom_call = Cast(custom_call); auto backend_config = typed_custom_call->backend_config(); OneDnnMatMulConfig matmul_config; @@ -2696,17 +2725,8 @@ absl::Status IrEmitter::HandleOneDnnMatMulCalls( args_val = b()->CreateInsertValue(args_val, matmul_config_val, arg_indx++); // Insert operands. - std::vector operands_stack_alloca; - operands_stack_alloca.reserve(num_operands); - absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), - [this](HloInstruction* instr) { - llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); - return GetAllocaAndEmitMemrefInfo(*b(), ir_array); - }); - for (int i = 0; i < num_operands; ++i) { - args_val = b()->CreateInsertValue(args_val, operands_stack_alloca[i].value, - arg_indx++); - } + auto operands_stack_alloca = + EmitOneDnnOperandsAlloca(custom_call, args_val, arg_indx); TF_RET_CHECK(nargs == arg_indx) << "Number of arguments don't equal the last argument index."; @@ -2812,17 +2832,8 @@ absl::Status IrEmitter::HandleOneDnnConvolution(HloInstruction* custom_call) { b()->CreateGlobalStringPtr(llvm_ir::AsStringRef(str_config)); args_val = b()->CreateInsertValue(args_val, conv_config_val, arg_indx++); - std::vector operands_stack_alloca; - operands_stack_alloca.reserve(num_operands); - absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), - [this](HloInstruction* instr) { - llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); - return GetAllocaAndEmitMemrefInfo(*b(), ir_array); - }); - for (int i = 0; i < num_operands; ++i) { - args_val = b()->CreateInsertValue(args_val, operands_stack_alloca[i].value, - arg_indx++); - } + auto operands_stack_alloca = + EmitOneDnnOperandsAlloca(custom_call, args_val, arg_indx); TF_RET_CHECK(nargs == arg_indx) << "Number of arguments don't equal the last argument index."; @@ -2891,17 +2902,10 @@ absl::Status IrEmitter::HandleOneDnnLayerNorm(HloInstruction* custom_call) { args_val = b()->CreateInsertValue(args_val, ln_config_val, arg_indx++); // Insert operands. - std::vector operands_stack_alloca; - operands_stack_alloca.reserve(num_operands); - absl::c_transform(custom_call->operands(), operands_stack_alloca.begin(), - [this](HloInstruction* instr) { - llvm_ir::IrArray ir_array(GetIrArrayFor(instr)); - return GetAllocaAndEmitMemrefInfo(*b(), ir_array); - }); - for (int i = 0; i < num_operands; ++i) { - args_val = b()->CreateInsertValue(args_val, operands_stack_alloca[i].value, - arg_indx++); - } + auto operands_stack_alloca = + EmitOneDnnOperandsAlloca(custom_call, args_val, arg_indx); + TF_RET_CHECK(nargs == arg_indx) + << "Number of arguments don't equal the last argument index."; llvm::Value* args_ptr = llvm_ir::EmitAllocaAtFunctionEntry(ptr_array_type, "layernorm.args", b()); @@ -4083,12 +4087,8 @@ llvm::Value* IrEmitter::EmitGlobalBufferPointer( GetBufferTableArgument(), b()->getPtrTy(), slice.index(), b()); llvm::LoadInst* tempbuf_address_base = Load(b()->getPtrTy(), tempbuf_address_ptr); - if (hlo_module_config_.debug_options() - .xla_llvm_enable_invariant_load_metadata()) { - tempbuf_address_base->setMetadata( - llvm::LLVMContext::MD_invariant_load, - llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); - } + + AttachInvariantLoadMetadataForLoad(tempbuf_address_base); AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size()); diff --git a/third_party/xla/xla/service/cpu/ir_emitter.h b/third_party/xla/xla/service/cpu/ir_emitter.h index 4ed7854f48a610..d2c94a913c0c42 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter.h +++ b/third_party/xla/xla/service/cpu/ir_emitter.h @@ -59,6 +59,10 @@ limitations under the License. #include "xla/service/name_uniquer.h" #include "xla/xla_data.pb.h" +#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) +#include "xla/service/cpu/onednn_memory_util.h" +#endif + namespace xla { namespace cpu { @@ -320,6 +324,9 @@ class IrEmitter : public DfsHloVisitorWithDefault, absl::Status HandleAllReduceSingleReplica(HloInstruction* crs); absl::Status HandleAllReduceMultipleReplica(HloInstruction* crs); #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) + std::vector EmitOneDnnOperandsAlloca(HloInstruction* custom_call, + llvm::Value*& args_val, + int& arg_indx); absl::Status HandleOneDnnMatMulCalls(HloInstruction* hlo, std::string runtime_symbol_name); absl::Status HandleOneDnnSoftmax(HloInstruction* hlo); @@ -753,8 +760,14 @@ class IrEmitter : public DfsHloVisitorWithDefault, // result with the dereferenceable bytes required by the shape / buffer size. void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, const Shape& shape); - void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, - int64_t buffer_size); + static void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, + int64_t buffer_size); + + // Given a load instruction, annotate the load's result with the invariant + // load metadata. + void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load) const; + static void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* load, + const HloModuleConfig& config); // Calculate the alignment of a buffer allocated for a given shape. int MinimumAlignmentForShape(const Shape& shape); diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.cc b/third_party/xla/xla/service/cpu/ir_emitter2.cc index eeac5bcc390993..e7b671268093fc 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2.cc @@ -40,6 +40,7 @@ limitations under the License. #include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" @@ -47,7 +48,7 @@ limitations under the License. #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" -#include "llvm/Support/Casting.h" +#include "llvm/Support/CodeGen.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -128,7 +129,9 @@ class IrEmitter2::ElementalIrEmitter : public xla::ElementalIrEmitter { ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b, const HloModule* hlo_module, IrEmitter* nested_ir_emitter, bool fast_min_max) - : xla::ElementalIrEmitter(module, b), + : xla::ElementalIrEmitter( + module, b, + Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/true}), hlo_module_(hlo_module), nested_ir_emitter_(nested_ir_emitter), fast_min_max_(fast_min_max) {} @@ -221,6 +224,13 @@ IrEmitter2::IrEmitter2(const HloModule& hlo_module, llvm::Module* module, bool IrEmitter2::fast_min_max() const { return hlo_module_.config().debug_options().xla_cpu_enable_fast_min_max(); } +IrEmitter2::KernelInfo::KernelInfo(KernelPrototype prototype, + const se::BlockDim& block_dims, + const se::ThreadDim& thread_dims) + : name(prototype.function->getName().str()), + block_dims(block_dims), + thread_dims(thread_dims), + invariant_buffers(std::move(prototype.invariant_buffers)) {} absl::StatusOr IrEmitter2::EmitElementalHostKernel( const HloInstruction* instr) { @@ -249,8 +259,8 @@ absl::StatusOr IrEmitter2::EmitElementalHostKernel( se::ThreadDim thread_dims, EmitElementalLoops(b, instr, kernel_prototype, element_generator)); - return kernels_.emplace_back(KernelInfo{ - kernel_prototype.function->getName().str(), se::BlockDim(), thread_dims}); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); } absl::StatusOr IrEmitter2::EmitPadHostKernel( @@ -280,8 +290,7 @@ absl::StatusOr IrEmitter2::EmitPadHostKernel( nested_ir_emitter_->PopComputeFunction(); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitFusionHostKernel( @@ -325,9 +334,8 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( const_cast(fusion), kernel_prototype.results[0], &fused_emitter, &b)); - return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), + se::BlockDim(), se::ThreadDim())); } // Emit plain elemental loops for the fusion operation. @@ -339,8 +347,8 @@ absl::StatusOr IrEmitter2::EmitFusionHostKernel( se::ThreadDim thread_dims, EmitElementalLoops(b, fusion, kernel_prototype, element_generator)); - return kernels_.emplace_back(KernelInfo{ - kernel_prototype.function->getName().str(), se::BlockDim(), thread_dims}); + return kernels_.emplace_back( + KernelInfo(std::move(kernel_prototype), se::BlockDim(), thread_dims)); } absl::StatusOr IrEmitter2::EmitReductionHostKernel( @@ -392,8 +400,7 @@ absl::StatusOr IrEmitter2::EmitDotHostKernel( /*allow_runtime_calls=*/false)); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitConcatenateHostKernel( @@ -413,9 +420,8 @@ absl::StatusOr IrEmitter2::EmitConcatenateHostKernel( llvm_ir::IrArray output_array = kernel_prototype.results[0]; TF_RETURN_IF_ERROR(::xla::cpu::EmitFastConcatenate( instr, kernel_prototype.arguments, output_array, module_, ir_builder)); - return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), + se::BlockDim(), se::ThreadDim())); } VLOG(1) << "Could not emit fast concatenate for " << instr->ToString() << ": " << fast_impl_reason.message(); @@ -476,8 +482,7 @@ absl::StatusOr IrEmitter2::EmitDotFusionHostKernel( /*allow_runtime_calls=*/false)); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr IrEmitter2::EmitSliceToDynamicHostKernel( @@ -495,8 +500,7 @@ absl::StatusOr IrEmitter2::EmitSliceToDynamicHostKernel( TF_RETURN_IF_ERROR(nested_ir_emitter_->EmitSliceToDynamic( instr, kernel_prototype.arguments, output_array)); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr @@ -513,8 +517,7 @@ IrEmitter2::EmitSelectAndScatterHostKernel(const HloInstruction* instr) { output_array)); return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + KernelInfo(std::move(kernel_prototype), se::BlockDim(), se::ThreadDim())); } absl::StatusOr @@ -534,9 +537,8 @@ IrEmitter2::EmitDynamicUpdateSliceHostKernel(const HloInstruction* instr) { kernel_prototype.arguments, kernel_prototype.results.front(), llvm_ir::IrName(instr, "in_place"), &b)); - return kernels_.emplace_back( - KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(), - se::ThreadDim()}); + return kernels_.emplace_back(KernelInfo(std::move(kernel_prototype), + se::BlockDim(), se::ThreadDim())); } return EmitElementalHostKernel(instr); @@ -564,6 +566,10 @@ absl::StatusOr IrEmitter2::EmitSortComparator( /*is_top_level_computation=*/true, schedule, /*allow_reassociation=*/false)); + // Generate unwind information so that GDB can crawl through the stack frames + // created by the JIT compiled code. + comparator_function->setUWTableKind(llvm::UWTableKind::Default); + return comparators_.emplace_back( ComparatorInfo{comparator_function->getName().str()}); } @@ -710,6 +716,15 @@ llvm_ir::IrArray IrEmitter2::EmitKernelArgument(llvm::IRBuilder<>& b, // emit metadata to allow LLVM to use that information for optimization. llvm_ir::SetAlignmentMetadataForLoad(data, cpu_function_runtime::MinAlign()); + // All buffers pointers passed to host kernels are expected to be + // dereferenceable. + IrEmitter::AttachDereferenceableMetadataForLoad(data, ByteSizeOf(shape)); + + // All buffers pointers passed to host kernels are expected to be invariant + // over the whole program. Note the metadata is attached only to loading + // buffer pointers, not to loading actual buffers. + AttachInvariantLoadMetadataForLoad(data); + return llvm_ir::IrArray(data, llvm_ir::ShapeToIrType(shape, module_), shape); } @@ -780,10 +795,6 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( // Collect a set of invariant (read-only) buffer slices. If a buffer slice is // not a part of result set, then it must be a read-only buffer. - // - // TODO(ezhulenev): Pass this information to KernelThunk and add an extra run - // time check to verify that this property holds, as otherwise it can lead to - // hard to debug errors. absl::flat_hash_set invariant_slices; for (const KernelParameter& argument : arguments) { if (!result_slices.contains(argument.slice)) { @@ -791,11 +802,15 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( } } - // Create a kernel function with HostKernel API. - llvm::Function* function = llvm::dyn_cast( - module_->getOrInsertFunction(name, KernelFunctionTy(ctx)).getCallee()); + // Create a kernel function with HostKernel API. We use external linkage + // because we'll be resolving this function from the XLA runtime. + llvm::Function* function = llvm::Function::Create( + KernelFunctionTy(ctx), llvm::GlobalValue::ExternalLinkage, name, module_); function->setCallingConv(llvm::CallingConv::C); - function->setDoesNotThrow(); + + // Generate unwind information so that GDB can crawl through the stack frames + // created by the JIT compiled code. + function->setUWTableKind(llvm::UWTableKind::Default); // Set prefer-vector-width attribute to allow LLVM to use wider vector // registers (by default LLVM uses at most 256-bit registers). @@ -860,7 +875,8 @@ absl::StatusOr IrEmitter2::EmitKernelPrototype( kernel_thread_dims, kernel_thread, std::move(ir_arguments), - std::move(ir_results)}; + std::move(ir_results), + std::move(invariant_slices)}; } absl::StatusOr IrEmitter2::EmitKernelPrototype( @@ -1027,4 +1043,17 @@ absl::StatusOr IrEmitter2::EmitElementalLoops( return se::ThreadDim(); } +// This is a convenience function taken from IrEmitter, it uses module_ class +// field. If there will be more functions that use module_, we should consider +// refactoring (like we did for compute_function_ and builder_). +int64_t IrEmitter2::ByteSizeOf(const Shape& shape) const { + return llvm_ir::ByteSizeOf(shape, module_->getDataLayout()); +} + +void IrEmitter2::AttachInvariantLoadMetadataForLoad( + llvm::LoadInst* instr) const { + nested_ir_emitter_->AttachInvariantLoadMetadataForLoad(instr, + hlo_module_.config()); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/ir_emitter2.h b/third_party/xla/xla/service/cpu/ir_emitter2.h index c998840a24b330..b10f9034d19d2b 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2.h +++ b/third_party/xla/xla/service/cpu/ir_emitter2.h @@ -23,11 +23,13 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -61,6 +63,12 @@ namespace xla::cpu { // // WARNING: This is under construction and will eventually replace IrEmitter. class IrEmitter2 { + public: + friend class IrEmitter2Test; + + private: + struct KernelPrototype; + public: IrEmitter2(const HloModule& hlo_module, llvm::Module* module, IrEmitter* nested_ir_emitter); @@ -87,28 +95,16 @@ class IrEmitter2 { llvm::Value* z; }; - // A kernel function prototype with all the LLVM values that might be needed - // to emit the actual kernel body. - struct KernelPrototype { - llvm::Function* function; - llvm::BasicBlock* return_block; - - // LLVM values identifying kernel invocation thread coordinates. - KernelThreadDims thread_dims; - KernelThread thread; - - // LLVM values corresponding to the kernel arguments and results arrays. All - // tuples are flattened as we do not have any tuples at run time and only - // read and write data from/to leaf arrays. - std::vector arguments; - std::vector results; - }; - // Emitted kernel information that defines how to launch it at run time. struct KernelInfo { + explicit KernelInfo(KernelPrototype prototype, + const se::BlockDim& block_dims, + const se::ThreadDim& thread_dims); + std::string name; se::BlockDim block_dims; se::ThreadDim thread_dims; + absl::flat_hash_set invariant_buffers; }; // Emitted comparator function information (for sort operation). @@ -165,6 +161,30 @@ class IrEmitter2 { absl::StatusOr EmitSortComparator( const HloInstruction* instr); + private: + class ElementalIrEmitter; + + // A kernel function prototype with all the LLVM values that might be needed + // to emit the actual kernel body. + struct KernelPrototype { + llvm::Function* function; + llvm::BasicBlock* return_block; + + // LLVM values identifying kernel invocation thread coordinates. + KernelThreadDims thread_dims; + KernelThread thread; + + // LLVM values corresponding to the kernel arguments and results arrays. All + // tuples are flattened as we do not have any tuples at run time and only + // read and write data from/to leaf arrays. + std::vector arguments; + std::vector results; + + // Set containing all invariant (read-only) buffers. A buffer is read-only + // if it is not aliased with any result. + absl::flat_hash_set invariant_buffers; + }; + // Emits a host kernel prototype and prepares function for emitting kernel // body into it. absl::StatusOr EmitKernelPrototype( @@ -175,9 +195,6 @@ class IrEmitter2 { absl::StatusOr EmitKernelPrototype( const HloInstruction* instr); - private: - class ElementalIrEmitter; - // Parallel partition bounds for parallelized outer dimensions: // vector<[i64 lower_bound, i64 upper_bound]> using ParallelPartitionBounds = @@ -240,6 +257,13 @@ class IrEmitter2 { bool fast_min_max() const; + // Returns the number of bytes within the shape. + int64_t ByteSizeOf(const Shape& shape) const; + + // Given a load instruction, annotate the load's result with the invariant + // load metadata. + void AttachInvariantLoadMetadataForLoad(llvm::LoadInst* instr) const; + const HloModule& hlo_module_; llvm::Module* module_; diff --git a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc index 539facd06c1dca..b2e8414a344983 100644 --- a/third_party/xla/xla/service/cpu/ir_emitter2_test.cc +++ b/third_party/xla/xla/service/cpu/ir_emitter2_test.cc @@ -17,8 +17,11 @@ limitations under the License. #include #include +#include #include +#include "absl/memory/memory.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "llvm/IR/Constants.h" #include "llvm/IR/IRBuilder.h" @@ -42,9 +45,63 @@ limitations under the License. #include "tsl/platform/test.h" namespace xla::cpu { -namespace { -using IrEmitter2Test = HloTestBase; +class IrEmitter2Test : public HloTestBase { + public: + // This is a proxy function that allows us call private method + // IrEmitter2::EmitKernelPrototype. + static auto EmitKernelPrototype( + IrEmitter2& ir_emitter, + const std::vector& arguments, + const std::vector& results) { + return ir_emitter.EmitKernelPrototype("test", arguments, results); + } + + absl::StatusOr MakeIrEmitter2(llvm::Module& module, + const HloModule& hlo) { + TF_ASSIGN_OR_RETURN( + buffer_assignment_, + BufferAssigner::Run( + &hlo, std::make_unique(&hlo), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return /*alignment=*/1; })); + + target_machine_ = + std::make_unique( + [](int64_t size) { return 1; }); + + nested_ir_emitter_ = absl::WrapUnique( + new IrEmitter(nullptr, hlo, *buffer_assignment_, &module, {}, {}, {}, + target_machine_.get(), false)); + + return IrEmitter2(hlo, &module, nested_ir_emitter_.get()); + } + + // TODO(abanas): This function could be static. It requires making the + // underlying FindInstruction function static first. + absl::StatusOr EmitElementalHostKernel( + IrEmitter2& ir_emitter, HloModule& hlo, + std::string_view instruction_name) { + HloInstruction* instruction = FindInstruction(&hlo, instruction_name); + + if (instruction == nullptr) { + return absl::InternalError("Instruction not found"); + } + TF_ASSIGN_OR_RETURN(IrEmitter2::KernelInfo kernel, + ir_emitter.EmitElementalHostKernel(instruction)); + return kernel; + } + + private: + // Dependencies of IrEmitter2. These are created in MakeIrEmitter2 and kept + // alive for the duration of the test, because IrEmitter2 does not take + // ownership of them. + std::unique_ptr buffer_assignment_; + std::unique_ptr target_machine_; + std::unique_ptr nested_ir_emitter_; +}; + +namespace { TEST_F(IrEmitter2Test, BuildKernelPrototype) { auto hlo = std::make_unique("test", HloModuleConfig()); @@ -66,9 +123,8 @@ TEST_F(IrEmitter2Test, BuildKernelPrototype) { {shape, res1}}; IrEmitter2 ir_emitter(*hlo, module.get(), /*nested_ir_emitter=*/nullptr); - TF_ASSERT_OK_AND_ASSIGN( - IrEmitter2::KernelPrototype prototype, - ir_emitter.EmitKernelPrototype("test", arguments, results)); + TF_ASSERT_OK_AND_ASSIGN(auto prototype, + EmitKernelPrototype(ir_emitter, arguments, results)); llvm::IRBuilder<> b(context); b.SetInsertPoint(prototype.function->getEntryBlock().getTerminator()); @@ -85,45 +141,45 @@ TEST_F(IrEmitter2Test, BuildKernelPrototype) { ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( CHECK: define ptr @test(ptr %0) #0 { - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 0 - CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 0 - CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 1 - CHECK: getelementptr inbounds %SE_HOST_KernelThreadDim, {{.*}} i32 2 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThreadDim, {{.*}} i32 2 CHECK: load i64 CHECK: load i64 CHECK: load i64 - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 1 - CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 0 - CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 1 - CHECK: getelementptr inbounds %SE_HOST_KernelThread, {{.*}} i32 2 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 0 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 1 + CHECK: getelementptr inbounds nuw %SE_HOST_KernelThread, {{.*}} i32 2 CHECK: load i64 CHECK: load i64 CHECK: load i64 - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 0, i32 0 - CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT:.+]] + CHECK: %[[ARG0:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0:.+]], !dereferenceable ![[DEREF_BYTES:.*]], !align ![[ALIGNMENT:.+]] - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 1, i32 0 - CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]] + CHECK: %[[ARG1:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 2, i32 0 - CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]] + CHECK: %[[ARG2:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] - CHECK-NEXT: getelementptr inbounds %SE_HOST_KernelCallFrame, {{.*}} i32 3 + CHECK-NEXT: getelementptr inbounds nuw %SE_HOST_KernelCallFrame, {{.*}} i32 3 CHECK: load ptr CHECK: getelementptr %SE_HOST_KernelArg, {{.*}} i32 3, i32 0 - CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !align ![[ALIGNMENT]] + CHECK: %[[ARG3:.+]] = load ptr, {{.*}}, !invariant.load ![[SCOPE0]], !dereferenceable ![[DEREF_BYTES]], !align ![[ALIGNMENT]] CHECK-NEXT: %[[PTR0:.+]] = getelementptr inbounds float, ptr %[[ARG0]] CHECK: load float, ptr %[[PTR0]], align 4, - CHECK-SAME: !invariant.load ![[SCOPE0:.+]], + CHECK-SAME: !invariant.load ![[SCOPE0]], CHECK-SAME: !noalias ![[SCOPE1:.+]] CHECK-NEXT: %[[PTR1:.+]] = getelementptr inbounds float, ptr %[[ARG1]] @@ -142,6 +198,8 @@ TEST_F(IrEmitter2Test, BuildKernelPrototype) { CHECK: ret ptr null CHECK: } + #0 = { uwtable "frame-pointer"="all" "prefer-vector-width"="256" } + CHECK-DAG: ![[DEREF_BYTES]] = !{i64 32} CHECK-DAG: ![[ALIGNMENT]] = !{i64 16} CHECK-DAG: ![[SCOPE0]] = !{} CHECK-DAG: ![[SCOPE1]] = !{![[RES0:.+]], ![[RES1:.+]]} @@ -165,25 +223,9 @@ TEST_F(IrEmitter2Test, EmitElementalKernel) { })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - HloInstruction* convert = FindInstruction(hlo.get(), "convert"); - ASSERT_NE(convert, nullptr); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr buffer_assignment, - BufferAssigner::Run( - hlo.get(), std::make_unique(hlo.get()), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return /*alignment=*/1; })); - - TargetMachineFeaturesWithFakeAlignmentLogic target_machine( - [](int64_t size) { return 1; }); - - IrEmitter nested_ir_emitter(nullptr, *hlo, *buffer_assignment, module.get(), - {}, {}, {}, &target_machine, false); - - IrEmitter2 ir_emitter(*hlo, module.get(), &nested_ir_emitter); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - ir_emitter.EmitElementalHostKernel(convert)); + EmitElementalHostKernel(ir_emitter, *hlo, "convert")); ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( CHECK: define ptr @convert(ptr %0) #0 { @@ -205,25 +247,9 @@ TEST_F(IrEmitter2Test, EmitParallelKernel) { })"; TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); - HloInstruction* convert = FindInstruction(hlo.get(), "convert"); - ASSERT_NE(convert, nullptr); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr buffer_assignment, - BufferAssigner::Run( - hlo.get(), std::make_unique(hlo.get()), - backend().compiler()->BufferSizeBytesFunction(), - [](LogicalBuffer::Color) { return /*alignment=*/1; })); - - TargetMachineFeaturesWithFakeAlignmentLogic target_machine( - [](int64_t size) { return 1; }); - - IrEmitter nested_ir_emitter(nullptr, *hlo, *buffer_assignment, module.get(), - {}, {}, {}, &target_machine, false); - - IrEmitter2 ir_emitter(*hlo, module.get(), &nested_ir_emitter); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, - ir_emitter.EmitElementalHostKernel(convert)); + EmitElementalHostKernel(ir_emitter, *hlo, "convert")); ASSERT_TRUE(*RunFileCheck(llvm_ir::DumpToString(module.get()), R"( CHECK: @convert_parallel_bounds = private constant [8 x [4 x [2 x i64]]] @@ -242,5 +268,66 @@ TEST_F(IrEmitter2Test, EmitParallelKernel) { )")); } +using IrEmitter2InvariantBuffersTest = IrEmitter2Test; + +TEST_F(IrEmitter2InvariantBuffersTest, AllInvariantBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m + ENTRY main { + p0 = f32[2,2] parameter(0) + ROOT add.0 = f32[2,2] add(p0, p0) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, + EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); + + ASSERT_EQ(kernel.invariant_buffers.size(), 1); +} + +TEST_F(IrEmitter2InvariantBuffersTest, NoInvariantBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m, input_output_alias={ {}: (0, {}, must-alias) } + ENTRY main { + p0 = f32[2,2] parameter(0) + ROOT add.0 = f32[2,2] add(p0, p0) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, + EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); + + ASSERT_EQ(kernel.invariant_buffers.size(), 0); +} + +TEST_F(IrEmitter2InvariantBuffersTest, MixedBuffers) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m, input_output_alias={ {}: (1, {}, must-alias) } + ENTRY main { + p0 = f32[2,2] parameter(0) + p1 = f32[2,2] parameter(1) + ROOT add.0 = f32[2,2] add(p0, p1) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2 ir_emitter, MakeIrEmitter2(*module, *hlo)); + TF_ASSERT_OK_AND_ASSIGN(IrEmitter2::KernelInfo kernel, + EmitElementalHostKernel(ir_emitter, *hlo, "add.0")); + + // TODO(abanas): Verify also which buffer is read-only, not only the count. + ASSERT_EQ(kernel.invariant_buffers.size(), 1); +} + } // namespace } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/ir_emitter_test.cc b/third_party/xla/xla/service/cpu/ir_emitter_test.cc new file mode 100644 index 00000000000000..7102d20421df4b --- /dev/null +++ b/third_party/xla/xla/service/cpu/ir_emitter_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/ir_emitter.h" + +#include +#include +#include + +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/Support/Casting.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/buffer_assignment.h" +#include "xla/service/cpu/ir_function.h" +#include "xla/service/cpu/target_machine_features_fake.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_ordering.h" +#include "xla/service/hlo_parser.h" +#include "xla/service/logical_buffer.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +using IrEmitterTest = HloTestBase; + +static std::pair CreateFunction( + llvm::LLVMContext& context, llvm::Module* module, llvm::IRBuilder<>* b) { + llvm::PointerType* ptrtype = llvm::PointerType::getUnqual(context); + llvm::FunctionType* ftype = llvm::FunctionType::get(ptrtype, ptrtype, false); + + llvm::Function* function = llvm::dyn_cast( + module->getOrInsertFunction("func2", ftype).getCallee()); + + llvm::BasicBlock* return_block = + llvm::BasicBlock::Create(context, "", function); + b->SetInsertPoint(return_block); + [[maybe_unused]] llvm::ReturnInst* ret = b->CreateRet( + llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(context))); + + return std::make_pair(function, return_block); +} + +TEST_F(IrEmitterTest, ComputeFuncStack) { + llvm::LLVMContext context; + auto module = std::make_unique("test", context); + + const char* hlo_text = R"( + HloModule m + ENTRY main { + ROOT %zero = f32[] constant(0) + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto hlo, ParseAndReturnUnverifiedModule(hlo_text)); + const HloInstruction* zero = FindInstruction(hlo.get(), "zero"); + ASSERT_NE(zero, nullptr); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr buffer_assignment, + BufferAssigner::Run( + hlo.get(), std::make_unique(hlo.get()), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return /*alignment=*/1; })); + + TargetMachineFeaturesWithFakeAlignmentLogic target_machine( + [](int64_t size) { return 1; }); + + IrEmitter ir_emitter(nullptr, *hlo, *buffer_assignment, module.get(), {}, {}, + {}, &target_machine, false); + + llvm::IRBuilder<>* b = ir_emitter.b(); + ASSERT_NE(b, nullptr); + + const std::pair fb = + CreateFunction(context, module.get(), b); + + llvm::Function* function = fb.first; + llvm::BasicBlock* return_block = fb.second; + + ASSERT_NE(function, nullptr); + ASSERT_NE(return_block, nullptr); + + const auto funcname = "func1"; + const auto linkagetype = llvm::GlobalValue::LinkageTypes::ExternalLinkage; + const HloModuleConfig module_config; + ir_emitter.PushComputeFunction(funcname, linkagetype, module_config, + module.get(), 0); + ASSERT_EQ(ir_emitter.compute_function()->function()->getName().str(), + funcname); + + ir_emitter.PushComputeFunction(b, module.get(), 0, function, nullptr, + return_block); + ASSERT_EQ(ir_emitter.compute_function()->function(), function); + + ir_emitter.PopComputeFunction(); + ASSERT_EQ(ir_emitter.compute_function()->function()->getName().str(), + funcname); + + ir_emitter.PopComputeFunction(); +} + +} // namespace +} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc similarity index 88% rename from third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc rename to third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc index 45c6bc17a41b20..19122b393ce23b 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.cc +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.cc @@ -17,7 +17,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_contraction_rewriter.h" #include "xla/executable_run_options.h" #include "xla/hlo/evaluator/hlo_evaluator.h" @@ -27,6 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/onednn_config.pb.h" +#include "xla/service/cpu/onednn_convolution.h" #include "xla/service/cpu/onednn_matmul.h" #include "xla/service/cpu/onednn_memory_util.h" #include "xla/service/cpu/onednn_pattern_utils.h" @@ -364,7 +365,8 @@ inline auto OptionalConvertAndBitcast(HloInstruction** optional_convert, } // namespace -bool OneDnnMatMulRewriter::ShouldRewrite(const HloInstruction* dot_instr) { +bool OneDnnContractionRewriter::ShouldRewriteDot( + const HloInstruction* dot_instr, bool before_layout_assignment) { // Currently, blocking control dependencies if (dot_instr->HasControlDependencies()) return false; if (!IsSupportedType(dot_instr->shape().element_type())) return false; @@ -396,12 +398,13 @@ bool OneDnnMatMulRewriter::ShouldRewrite(const HloInstruction* dot_instr) { // Layout should be row-major, contraction dimensions captures transpose // scenarios in last two dimensions. - // Col-major layouts are corrected to row-majow for BatchDot operation as - // part of the layout-pass. - if (!IsBatchDot(*dot_instr) && - (!IsRowMajor(lhs_shape) || !IsRowMajor(rhs_shape) || - !IsRowMajor(output_shape))) { - return false; + // Col-major layouts are corrected to row-major for BatchDot operation as + // part of the layout-assignment pass. + // Skip row-major layout check before layout-assignment pass + if (!before_layout_assignment) { + bool row_major = IsRowMajor(lhs_shape) && IsRowMajor(rhs_shape) && + IsRowMajor(output_shape); + if (!row_major) return false; } auto dot_dim_numbers = dot_instr->dot_dimension_numbers(); @@ -424,7 +427,37 @@ bool OneDnnMatMulRewriter::ShouldRewrite(const HloInstruction* dot_instr) { return (num_flops >= flops_threshold); } -class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { +bool OneDnnContractionRewriter::ShouldRewriteConv( + const HloInstruction* conv_instr) { + if (conv_instr->HasControlDependencies()) return false; + if (!IsSupportedType(conv_instr->shape().element_type())) return false; + if (conv_instr->batch_group_count() != 1) return false; + + // TODO(intel-tf): Remove this restriction after enabling backward weights + // support + if (conv_instr->operand(1)->opcode() == HloOpcode::kReverse) return false; + + const Shape& inp_shape = conv_instr->operand(0)->shape(); + const Shape& ker_shape = conv_instr->operand(1)->shape(); + const Shape& out_shape = conv_instr->shape(); + if (ShapeUtil::IsZeroElementArray(inp_shape) || + ShapeUtil::IsZeroElementArray(ker_shape) || + ShapeUtil::IsZeroElementArray(out_shape)) { + return false; + } + + auto dims = conv_instr->window().dimensions().size(); + if (dims >= 4 || dims <= 0) return false; + + if (inp_shape.rank() != ker_shape.rank() || + inp_shape.rank() != out_shape.rank()) { + return false; + } + + return true; +} + +class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor { public: // Matches patterns for possible MatMul fusions that are supported by oneDNN // library. Matched HLO instruction(s) are replaced by custom call. @@ -435,8 +468,10 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR( ValidateDotDimensionNumbers(dot_instr->dot_dimension_numbers())); - if (!OneDnnMatMulRewriter::ShouldRewrite(dot_instr)) + if (!OneDnnContractionRewriter::ShouldRewriteDot(dot_instr)) { + TF_RETURN_IF_ERROR(UpcastDotToF32(dot_instr)); return absl::OkStatus(); + } TF_ASSIGN_OR_RETURN(dot_instr, ReconfigureDotDimensions(dot_instr)); auto dot_dim_numbers = dot_instr->dot_dimension_numbers(); const Shape& lhs_shape = dot_instr->operand(0)->shape(); @@ -464,6 +499,72 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { return absl::OkStatus(); } + absl::Status HandleConvolution(HloInstruction* conv) override { + if (!OneDnnContractionRewriter::ShouldRewriteConv(conv)) { + return absl::OkStatus(); + } + + const Shape& conv_shape = conv->shape(); + auto dims = conv->window().dimensions().size(); + const ConvolutionDimensionNumbers& conv_dims = + conv->convolution_dimension_numbers(); + + BackendConfig backend_config; + OneDnnConvolutionConfig* conv_config = + backend_config.mutable_onednn_conv_config(); + + conv_config->set_dims(conv_shape.rank()); + conv_config->set_feature_groups(conv->feature_group_count()); + conv_config->mutable_input()->mutable_data()->set_batch_dim( + conv_dims.input_batch_dimension()); + conv_config->mutable_kernel()->mutable_filter()->set_input_feature_dim( + conv_dims.kernel_input_feature_dimension()); + conv_config->mutable_output()->mutable_data()->set_batch_dim( + conv_dims.output_batch_dimension()); + conv_config->mutable_input()->mutable_data()->set_feature_dim( + conv_dims.input_feature_dimension()); + conv_config->mutable_kernel()->mutable_filter()->set_output_feature_dim( + conv_dims.kernel_output_feature_dimension()); + conv_config->mutable_output()->mutable_data()->set_feature_dim( + conv_dims.output_feature_dimension()); + + const Shape& output_shape = conv->shape(); + + for (auto it = conv->window().dimensions().begin(); + it != conv->window().dimensions().end(); it++) { + if ((*it).padding_low() < 0 || (*it).padding_high() < 0 || + (*it).stride() < 0 || (*it).base_dilation() != 1 || + (*it).window_reversal()) { + return absl::OkStatus(); + } + // Changing the input subspace of uint repeated fields from whole numbers + // to natural nummbers to avoid misinterpretation of buffer values. + conv_config->mutable_window()->add_pad_left((*it).padding_low() + 1); + conv_config->mutable_window()->add_pad_right((*it).padding_high() + 1); + conv_config->mutable_window()->add_strides((*it).stride() + 1); + conv_config->mutable_window()->add_window_dilations( + (*it).window_dilation() + 1); + } + + for (int i = 0; i < dims; i++) { + conv_config->mutable_input()->mutable_data()->add_spatial_dims( + conv_dims.input_spatial_dimensions()[i] + 1); + conv_config->mutable_kernel()->mutable_filter()->add_spatial_dims( + conv_dims.kernel_spatial_dimensions()[i] + 1); + conv_config->mutable_output()->mutable_data()->add_spatial_dims( + conv_dims.output_spatial_dimensions()[i] + 1); + } + + HloInstruction* custom_call = + conv->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)}, + "__onednn$convolution")); + + TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call)); + return absl::OkStatus(); + } + absl::Status HandleAdd(HloInstruction* instr) override { // Try to do a fusion for Dot(onednn-matmul) + Add. However, // HLO Add instruction might receive the addends after additional @@ -917,6 +1018,34 @@ class OneDnnMatMulRewriteVisitor : public DfsHloRewriteVisitor { TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr)); return adjusted_dot; } + + // This function upcasts BF16 dots to F32 if we are unable to rewrite them to + // oneDNN custom calls. + absl::Status UpcastDotToF32(HloInstruction* dot_instr) { + if (dot_instr->shape().element_type() != BF16) return absl::OkStatus(); + std::vector new_operands; + auto bf16_operands = dot_instr->operands(); + + std::for_each( + bf16_operands.begin(), bf16_operands.end(), + [&new_operands](HloInstruction* instr) { + new_operands.push_back( + instr->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(instr->shape(), F32), instr))); + }); + + HloInstruction* f32_dot = + dot_instr->AddInstruction(dot_instr->CloneWithNewOperands( + ShapeUtil::ChangeElementType(dot_instr->shape(), F32), + new_operands)); + + HloInstruction* replacement_instr = + f32_dot->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(f32_dot->shape(), BF16), f32_dot)); + + TF_RETURN_IF_ERROR(ReplaceInstruction(dot_instr, replacement_instr)); + return absl::OkStatus(); + } }; class OneDnnPostRewriteVisitor : public DfsHloRewriteVisitor { @@ -1108,10 +1237,10 @@ EMIT_SET_BACKEND_CONFIG_SPECIALIZATION(SetUserScratch, OneDnnMatMulConfig, onednn_matmul_config, optimization_config, user_scratchpad); -absl::StatusOr OneDnnMatMulRewriter::Run( +absl::StatusOr OneDnnContractionRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - OneDnnMatMulRewriteVisitor visitor; + OneDnnContractionRewriteVisitor visitor; TF_ASSIGN_OR_RETURN(auto result, visitor.RunOnModule(module, execution_threads)); diff --git a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h similarity index 64% rename from third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h rename to third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h index 7ad7f76ab3a5b5..7864dae961386b 100644 --- a/third_party/xla/xla/service/cpu/onednn_matmul_rewriter.h +++ b/third_party/xla/xla/service/cpu/onednn_contraction_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_ -#define XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_ +#ifndef XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_ +#define XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_ #if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) #include @@ -29,23 +29,27 @@ limitations under the License. namespace xla { namespace cpu { -// This pass pattern-matches HLO Dot instructions and rewrites into custom -// calls. -class OneDnnMatMulRewriter : public HloModulePass { +// This pass pattern-matches HLO Dot and Convolution instructions and rewrites +// them into custom calls. +class OneDnnContractionRewriter : public HloModulePass { public: - OneDnnMatMulRewriter(int intra_op_parallelism, - const tsl::thread::ThreadPool* compile_threadpool) + OneDnnContractionRewriter(int intra_op_parallelism, + const tsl::thread::ThreadPool* compile_threadpool) : intra_op_parallelism_(intra_op_parallelism), compile_threadpool_(compile_threadpool) {} - OneDnnMatMulRewriter() = default; - absl::string_view name() const override { return "onednn-matmul-rewriter"; } + OneDnnContractionRewriter() = default; + absl::string_view name() const override { + return "onednn-contraction-rewriter"; + } using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; - static bool ShouldRewrite(const HloInstruction* dot_instr); + static bool ShouldRewriteDot(const HloInstruction* dot_instr, + bool before_layout_assignment = false); + static bool ShouldRewriteConv(const HloInstruction* conv_instr); private: int intra_op_parallelism_; @@ -56,4 +60,4 @@ class OneDnnMatMulRewriter : public HloModulePass { } // namespace xla #endif // INTEL_MKL && ENABLE_ONEDNN_V3 -#endif // XLA_SERVICE_CPU_ONEDNN_MATMUL_REWRITER_H_ +#endif // XLA_SERVICE_CPU_ONEDNN_CONTRACTION_REWRITER_H_ diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc deleted file mode 100644 index 0c65c5d3dd2ff2..00000000000000 --- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.cc +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - -#include "xla/service/cpu/onednn_convolution_rewriter.h" - -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/cpu/backend_config.pb.h" -#include "xla/service/cpu/onednn_config.pb.h" -#include "xla/service/cpu/onednn_memory_util.h" -#include "xla/service/cpu/onednn_util.h" -#include "xla/service/pattern_matcher.h" -#include "xla/status_macros.h" - -namespace xla { -namespace cpu { - -namespace { -namespace m = match; -} // namespace - -bool OneDnnConvolutionRewriter::ShouldRewrite(const HloInstruction* conv) { - if (conv->HasControlDependencies()) return false; - if (!IsSupportedType(conv->shape().element_type())) return false; - if (conv->batch_group_count() != 1) return false; - - if (conv->operand(1)->opcode() == HloOpcode::kReverse) return false; - - const Shape& inp_shape = conv->operand(0)->shape(); - const Shape& ker_shape = conv->operand(1)->shape(); - const Shape& out_shape = conv->shape(); - if (ShapeUtil::IsZeroElementArray(inp_shape) || - ShapeUtil::IsZeroElementArray(ker_shape) || - ShapeUtil::IsZeroElementArray(out_shape)) { - return false; - } - - auto dims = conv->window().dimensions().size(); - if (dims >= 4 || dims <= 0) return false; - - if (inp_shape.rank() != ker_shape.rank() || - inp_shape.rank() != out_shape.rank()) { - return false; - } - - return true; -} - -class OneDnnConvolutionRewriterVisitor : public DfsHloRewriteVisitor { - public: - absl::Status HandleConvolution(HloInstruction* conv) override { - auto pattern = match::Op(&conv).WithOpcode(HloOpcode::kConvolution); - if (!Match(conv, pattern)) return absl::OkStatus(); - if (!OneDnnConvolutionRewriter::ShouldRewrite(conv)) { - return absl::OkStatus(); - } - - const Shape& conv_shape = conv->shape(); - auto dims = conv->window().dimensions().size(); - const ConvolutionDimensionNumbers& conv_ddata = - conv->convolution_dimension_numbers(); - - BackendConfig backend_config; - OneDnnConvolutionConfig* conv_config = - backend_config.mutable_onednn_conv_config(); - - conv_config->set_dims(conv_shape.rank()); - conv_config->set_feature_groups(conv->feature_group_count()); - conv_config->mutable_input()->mutable_data()->set_batch_dim( - conv_ddata.input_batch_dimension()); - conv_config->mutable_kernel()->mutable_filter()->set_input_feature_dim( - conv_ddata.kernel_input_feature_dimension()); - conv_config->mutable_output()->mutable_data()->set_batch_dim( - conv_ddata.output_batch_dimension()); - conv_config->mutable_input()->mutable_data()->set_feature_dim( - conv_ddata.input_feature_dimension()); - conv_config->mutable_kernel()->mutable_filter()->set_output_feature_dim( - conv_ddata.kernel_output_feature_dimension()); - conv_config->mutable_output()->mutable_data()->set_feature_dim( - conv_ddata.output_feature_dimension()); - - const Shape& output_shape = conv->shape(); - - for (auto it = conv->window().dimensions().begin(); - it != conv->window().dimensions().end(); it++) { - if ((*it).padding_low() < 0 || (*it).padding_high() < 0 || - (*it).stride() < 0) { - return absl::OkStatus(); - } - conv_config->mutable_window()->add_pad_left((*it).padding_low() + 1); - conv_config->mutable_window()->add_pad_right((*it).padding_high() + 1); - conv_config->mutable_window()->add_strides((*it).stride() + 1); - conv_config->mutable_window()->add_window_dilations( - (*it).window_dilation() + 1); - if ((*it).base_dilation() != 1 || (*it).window_reversal()) { - return absl::OkStatus(); - } - } - - for (int i = 0; i < dims; i++) { - conv_config->mutable_input()->mutable_data()->add_spatial_dims( - conv_ddata.input_spatial_dimensions()[i] + 1); - conv_config->mutable_kernel()->mutable_filter()->add_spatial_dims( - conv_ddata.kernel_spatial_dimensions()[i] + 1); - conv_config->mutable_output()->mutable_data()->add_spatial_dims( - conv_ddata.output_spatial_dimensions()[i] + 1); - } - - HloInstruction* custom_call = - conv->AddInstruction(HloInstruction::CreateCustomCall( - output_shape, {conv->mutable_operand(0), conv->mutable_operand(1)}, - "__onednn$convolution")); - - TF_RETURN_IF_ERROR(custom_call->set_backend_config(backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(conv, custom_call)); - return absl::OkStatus(); - } -}; - -absl::StatusOr OneDnnConvolutionRewriter::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - OneDnnConvolutionRewriterVisitor visitor; - return visitor.RunOnModule(module, execution_threads); -} - -} // namespace cpu -} // namespace xla - -#endif // INTEL_MKL && ENABLE_ONEDNN_V3 diff --git a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h b/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h deleted file mode 100644 index 2dbd3a66eb48c2..00000000000000 --- a/third_party/xla/xla/service/cpu/onednn_convolution_rewriter.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_ -#define XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_ -#if defined(INTEL_MKL) && defined(ENABLE_ONEDNN_V3) - -#include - -#include "absl/algorithm/container.h" -#include "absl/status/statusor.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/hlo_pass_interface.h" - -namespace xla { -namespace cpu { - -// This pass converts hlo convolution instructions into a single oneDNN -// operation and rewrites into custom calls. -class OneDnnConvolutionRewriter : public HloModulePass { - public: - absl::string_view name() const override { - return "onednn-convolution-rewriter"; - } - - using HloPassInterface::Run; - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override; - - static bool ShouldRewrite(const HloInstruction* instr); -}; - -} // namespace cpu -} // namespace xla - -#endif // INTEL_MKL && ENABLE_ONEDNN_V3 -#endif // XLA_SERVICE_CPU_ONEDNN_CONVOLUTION_REWRITER_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/conv_impl.cc b/third_party/xla/xla/service/cpu/runtime/conv_impl.cc deleted file mode 100644 index 199a97919fa53e..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/conv_impl.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#define EIGEN_USE_THREADS - -#include "xla/service/cpu/runtime/conv_impl.h" - -namespace tensorflow::xla { - -// Instantiate Conv2D template for all supported devices and data types. -#define CONV2D_INSTANTIATE_TEMPLATE(EigenDevice, ScalarType) \ - template void EigenConv2DImpl( \ - const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ - Eigen::Index input_y, Eigen::Index input_channels, \ - Eigen::Index kernel_x, Eigen::Index kernel_y, \ - Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ - Eigen::Index output_x, Eigen::Index output_y, Eigen::Index x_stride, \ - Eigen::Index y_stride, Eigen::Index padding_x_before, \ - Eigen::Index padding_x_after, Eigen::Index padding_y_before, \ - Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \ - Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \ - Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count, \ - std::optional> done_callback) - -CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); -CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); -CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); -CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); - -#undef CONV2D_INSTANTIATE_TEMPLATE - -// Instantiate Conv3D template for all supported devices and data types. -#define CONV3D_INSTANTIATE_TEMPLATE(EigenDevice, ScalarType) \ - template void EigenConv3DImpl( \ - const EigenDevice& device, ScalarType* out, ScalarType* lhs, \ - ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \ - Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \ - Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \ - Eigen::Index kernel_channels, Eigen::Index kernel_filters, \ - Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, \ - Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, \ - Eigen::Index padding_x_before, Eigen::Index padding_x_after, \ - Eigen::Index padding_y_before, Eigen::Index padding_y_after, \ - Eigen::Index padding_z_before, Eigen::Index padding_z_after, \ - Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, \ - Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \ - Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \ - Eigen::Index feature_group_count, \ - std::optional> done_callback) - -CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half); -CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float); -CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half); -CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float); - -} // namespace tensorflow::xla diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc deleted file mode 100644 index 7ac73162df4d9b..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.cc +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/runtime/kernel_thunk.h" - -#define EIGEN_USE_THREADS - -#include -#include -#include -#include -#include -#include - -#include "absl/base/optimization.h" -#include "absl/memory/memory.h" -#include "absl/numeric/bits.h" -#include "absl/status/status.h" -#include "absl/strings/str_format.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "unsupported/Eigen/CXX11/Tensor" -#include "llvm/ADT/SmallVector.h" -#include "xla/runtime/buffer_use.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/host/host_kernel.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/tsl/concurrency/async_value_ref.h" -#include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/statusor.h" -#include "tsl/profiler/lib/traceme.h" - -namespace xla::cpu { - -absl::StatusOr> KernelThunk::Create( - Info info, absl::Span arguments_buffers, - absl::Span results_buffers, - std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment) { - if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) { - return Internal("Host kernel %s minimum alignment %d is not a power of 2", - info.op_name, *min_alignment); - } - - return absl::WrapUnique( - new KernelThunk(std::move(info), arguments_buffers, results_buffers, - std::move(kernel_name), thread_dim, min_alignment)); -} - -KernelThunk::KernelThunk( - Info info, absl::Span arguments_buffers, - absl::Span results_buffers, - std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment) - : Thunk(Kind::kKernel, std::move(info)), - arguments_buffers_(arguments_buffers.begin(), arguments_buffers.end()), - results_buffers_(results_buffers.begin(), results_buffers.end()), - num_kernel_args_(arguments_buffers.size() + results_buffers.size()), - kernel_name_(std::move(kernel_name)), - thread_dim_(thread_dim), - min_alignment_(min_alignment), - call_once_(thread_dim_ == se::ThreadDim()), - kernel_ptr_(nullptr) {} - -tsl::AsyncValueRef KernelThunk::Execute( - const ExecuteParams& params) { - tsl::profiler::TraceMe trace([&] { return TraceMeEncode(); }); - - VLOG(3) << absl::StreamFormat( - "Launch host kernel %s with %d arguments buffers and %d results buffers: " - "#threads=%s", - kernel_name_, arguments_buffers_.size(), results_buffers_.size(), - thread_dim_.ToString()); - - // We use `llvm::SmallVector` instead of `absl::InlinedVector` because - // it allows to resize a vector without zero-initializing storage. - llvm::SmallVector kernel_args; - kernel_args.resize_for_overwrite(num_kernel_args_); - - SE_HOST_KernelArg* kernel_args_ptr = kernel_args.data(); - const BufferAllocations* allocations = params.buffer_allocations; - - for (BufferAllocation::Slice& buffer : arguments_buffers_) { - if constexpr (ShouldCheckBufferSlices()) { - TF_ASSIGN_OR_RETURN(auto mem, allocations->GetDeviceAddress(buffer)); - *kernel_args_ptr++ = SE_HOST_KernelArg{mem.opaque(), mem.size()}; - } else { - auto mem = allocations->GetDeviceAddressUnchecked(buffer); - *kernel_args_ptr++ = SE_HOST_KernelArg{mem.opaque(), mem.size()}; - } - } - - for (BufferAllocation::Slice& buffer : results_buffers_) { - if constexpr (ShouldCheckBufferSlices()) { - TF_ASSIGN_OR_RETURN(auto mem, allocations->GetDeviceAddress(buffer)); - *kernel_args_ptr++ = SE_HOST_KernelArg{mem.opaque(), mem.size()}; - } else { - auto mem = allocations->GetDeviceAddressUnchecked(buffer); - *kernel_args_ptr++ = SE_HOST_KernelArg{mem.opaque(), mem.size()}; - } - } - - if (ABSL_PREDICT_FALSE(VLOG_IS_ON(3))) { - VlogKernelArgs(kernel_args); - } - - // Сheck that all resolved buffers are properly aligned. - if constexpr (ShouldCheckBufferSlices()) { - TF_RETURN_IF_ERROR(CheckBufferAlignment(kernel_args)); - } - - // TODO(ezhulenev): Kernel ptr should be loaded as a part of Thunk - // initialization stage. - se::host::HostKernel* kernel = kernel_ptr_.load(); - - // Because thunks are owned by a parent CpuExecutable, we can safely assume - // that kernel pointer will not change after we find it the first time. - if (ABSL_PREDICT_FALSE(kernel == nullptr)) { - TF_ASSIGN_OR_RETURN(SE_HOST_Kernel * kernel_fn, - params.function_registry->FindKernel(kernel_name_)); - - absl::MutexLock lock(&mutex_); - kernel_.emplace(num_kernel_args_, kernel_fn, nullptr); - kernel_ptr_.store(kernel = &kernel_.value()); - } - - // Use a fast path if kernel called just once. - if (ABSL_PREDICT_TRUE(call_once_)) { - TF_RETURN_IF_ERROR(kernel->CallOnce(kernel_args)); - return OkExecuteEvent(); - } - - // If intra-op thread pool is not nullptr, we launch HostKernel in async mode - // by scheduling tasks into it. HostKernel launch completion will - // automatically signal KernelThunk execute completion. - if (ABSL_PREDICT_TRUE(params.intra_op_threadpool)) { - return kernel->Launch( - thread_dim_, kernel_args, [¶ms](se::host::HostKernel::Task task) { - params.intra_op_threadpool->getPool()->Schedule(std::move(task)); - }); - } - - TF_RETURN_IF_ERROR(kernel->Launch(thread_dim_, kernel_args)); - return OkExecuteEvent(); -} - -absl::Status KernelThunk::CheckBufferAlignment( - absl::Span kernel_args) { - if (min_alignment_.has_value()) { - for (int64_t i = 0; i < num_kernel_args_; ++i) { - auto ptr = reinterpret_cast(kernel_args[i].data); - if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) { - return Internal( - "Host kernel %s buffer argument #%d (%p) is not aligned to a " - "required minimum alignment of %d bytes", - info().op_name, i, kernel_args[i].data, *min_alignment_); - } - } - } - return absl::OkStatus(); -} - -void KernelThunk::VlogKernelArgs( - absl::Span kernel_args) { - for (int64_t i = 0; i < arguments_buffers_.size(); ++i) { - VLOG(3) << absl::StreamFormat(" arg #%d: %s (%p)", i, - arguments_buffers_[i].ToString(), - kernel_args[i].data); - } - for (int64_t i = 0; i < results_buffers_.size(); ++i) { - VLOG(3) << absl::StreamFormat( - " res #%d: %s (%p)", i, results_buffers_[i].ToString(), - kernel_args[arguments_buffers_.size() + i].data); - } -} - -KernelThunk::BufferUses KernelThunk::buffer_uses() const { - BufferUses buffer_uses; - for (const BufferAllocation::Slice& buffer : arguments_buffers_) { - buffer_uses.emplace_back(buffer, BufferUse::kRead); - } - for (const BufferAllocation::Slice& buffer : results_buffers_) { - buffer_uses.emplace_back(buffer, BufferUse::kWrite); - } - return buffer_uses; -} - -} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h b/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h deleted file mode 100644 index 80bf16a4573916..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk.h +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_ -#define XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/base/thread_annotations.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/synchronization/mutex.h" -#include "absl/types/span.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/stream_executor/host/host_kernel.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/tsl/concurrency/async_value_ref.h" - -namespace xla::cpu { - -// Launches compiled host kernel on the caller thread. -class KernelThunk final : public Thunk { - public: - static absl::StatusOr> Create( - Info info, absl::Span arguments_buffers, - absl::Span results_buffers, - std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment = std::nullopt); - - tsl::AsyncValueRef Execute(const ExecuteParams& params) final; - - BufferUses buffer_uses() const final; - - private: - KernelThunk(Info info, - absl::Span arguments_buffers, - absl::Span results_buffers, - std::string kernel_name, se::ThreadDim thread_dim, - std::optional min_alignment); - - // Checks that all buffers are aligned to the minimum alignment. We codegen - // with the assumption that all buffers are aligned, and if they are not, we - // will crash with a segmentation fault, or worse, produce incorrect results. - absl::Status CheckBufferAlignment( - absl::Span kernel_args); - - void VlogKernelArgs(absl::Span kernel_args); - - std::vector arguments_buffers_; - std::vector results_buffers_; - - size_t num_kernel_args_; - - std::string kernel_name_; - se::ThreadDim thread_dim_; - std::optional min_alignment_; - - // If `true`, host kernel will be called just once for a logical thread dim - // (1,1,1). This is a fast path for small host kernels that have just one - // logical thread dim. - bool call_once_; - - // Lazily loaded host kernel corresponding to `kernel_name_`. - absl::Mutex mutex_; - std::optional kernel_ ABSL_GUARDED_BY(mutex_); - std::atomic kernel_ptr_; // pointer to `kernel_` -}; - -} // namespace xla::cpu - -#endif // XLA_SERVICE_CPU_RUNTIME_KERNEL_THUNK_H_ diff --git a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc b/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc deleted file mode 100644 index 63696c0e83278c..00000000000000 --- a/third_party/xla/xla/service/cpu/runtime/kernel_thunk_test.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/cpu/runtime/kernel_thunk.h" - -#include -#include -#include -#include - -#include "absl/status/statusor.h" -#include "absl/strings/match.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/cpu/runtime/buffer_allocations.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/maybe_owning_device_memory.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/launch_dim.h" -#include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/platform/statusor.h" -#include "tsl/platform/test.h" - -namespace xla::cpu { -namespace { - -class AddF32HostKernel : public Thunk::FunctionRegistry { - public: - absl::StatusOr FindKernel(std::string_view name) override { - return +[](const SE_HOST_KernelCallFrame* call_frame) { - const SE_HOST_KernelArg& in = call_frame->args[0]; - const SE_HOST_KernelArg& out = call_frame->args[1]; - - float* in_ptr = reinterpret_cast(in.data); - float* out_ptr = reinterpret_cast(out.data); - - uint64_t i = call_frame->thread->x; - *(out_ptr + i) = *(in_ptr + i) + *(in_ptr + i); - - return static_cast(nullptr); - }; - } -}; - -TEST(KernelThunkTest, CheckAlignment) { - auto thunk = KernelThunk::Create({"test"}, {}, {}, "test", se::ThreadDim(), - /*min_alignment=*/3); - EXPECT_TRUE(absl::StrContains(thunk.status().message(), - "minimum alignment 3 is not a power of 2")); -} - -TEST(KernelThunkTest, AddF32) { - std::vector buffers; - std::vector in = {1.0, 2.0, 3.0, 4.0}; - std::vector out(4, 0.0); - - size_t size_in_bytes = in.size() * sizeof(float); - buffers.emplace_back(se::DeviceMemoryBase(in.data(), size_in_bytes)); - buffers.emplace_back(se::DeviceMemoryBase(out.data(), size_in_bytes)); - - BufferAllocations allocations(buffers); - - BufferAllocation in_alloc(0, size_in_bytes, 0); - BufferAllocation out_alloc(1, size_in_bytes, 0); - - BufferAllocation::Slice in_slice(&in_alloc, 0, size_in_bytes); - BufferAllocation::Slice out_slice(&out_alloc, 0, size_in_bytes); - - TF_ASSERT_OK_AND_ASSIGN( - auto thunk, KernelThunk::Create({"add_f32"}, {in_slice}, {out_slice}, - "add_f32", se::ThreadDim(4))); - - AddF32HostKernel host_kernels; - Thunk::ExecuteParams params = {&host_kernels, &allocations}; - - auto execute_event = thunk->Execute(params); - tsl::BlockUntilReady(execute_event); - ASSERT_FALSE(execute_event.IsError()); - - std::vector expected = {2.0, 4.0, 6.0, 8.0}; - EXPECT_EQ(out, expected); -} - -} // namespace -} // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/runtime_conv2d.cc b/third_party/xla/xla/service/cpu/runtime_conv2d.cc index 907f0f57346020..696f556b20fd7a 100644 --- a/third_party/xla/xla/service/cpu/runtime_conv2d.cc +++ b/third_party/xla/xla/service/cpu/runtime_conv2d.cc @@ -15,11 +15,13 @@ limitations under the License. #include "xla/service/cpu/runtime_conv2d.h" +#include + #define EIGEN_USE_THREADS #include "absl/base/dynamic_annotations.h" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" #include "xla/executable_run_options.h" -#include "xla/service/cpu/runtime/conv_impl.h" #include "xla/service/cpu/runtime_lightweight_check.h" ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF32( @@ -35,13 +37,13 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF32( const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tensorflow::xla::EigenConv2DImpl( + xla::cpu::internal::EigenConv2D( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, padding_top, padding_bottom, padding_left, padding_right, lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation, - feature_group_count); + feature_group_count, std::nullopt); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16( @@ -57,11 +59,11 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv2DF16( const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tensorflow::xla::EigenConv2DImpl( + xla::cpu::internal::EigenConv2D( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, padding_top, padding_bottom, padding_left, padding_right, lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation, - feature_group_count); + feature_group_count, std::nullopt); } diff --git a/third_party/xla/xla/service/cpu/runtime_conv3d.cc b/third_party/xla/xla/service/cpu/runtime_conv3d.cc index ad86203609e1aa..fee2293d73fd97 100644 --- a/third_party/xla/xla/service/cpu/runtime_conv3d.cc +++ b/third_party/xla/xla/service/cpu/runtime_conv3d.cc @@ -15,11 +15,13 @@ limitations under the License. #include "xla/service/cpu/runtime_conv3d.h" +#include + #define EIGEN_USE_THREADS #include "absl/base/dynamic_annotations.h" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" #include "xla/executable_run_options.h" -#include "xla/service/cpu/runtime/conv_impl.h" #include "xla/service/cpu/runtime_lightweight_check.h" ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF32( @@ -37,14 +39,14 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF32( const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tensorflow::xla::EigenConv3DImpl( + xla::cpu::internal::EigenConv3D( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_x, input_y, input_z, input_channels, kernel_x, kernel_y, kernel_z, kernel_channels, kernel_filters, output_x, output_y, output_z, x_stride, y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before, padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation, lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation, - rhs_z_dilation, feature_group_count); + rhs_z_dilation, feature_group_count, std::nullopt); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16( @@ -62,12 +64,12 @@ ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConv3DF16( const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); XLA_LIGHTWEIGHT_CHECK(run_options->intra_op_thread_pool() != nullptr); - tensorflow::xla::EigenConv3DImpl( + xla::cpu::internal::EigenConv3D( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_x, input_y, input_z, input_channels, kernel_x, kernel_y, kernel_z, kernel_channels, kernel_filters, output_x, output_y, output_z, x_stride, y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before, padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation, lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation, - rhs_z_dilation, feature_group_count); + rhs_z_dilation, feature_group_count, std::nullopt); } diff --git a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc index 3c9c8d1e3d3eb6..6b07b41aad00cc 100644 --- a/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc +++ b/third_party/xla/xla/service/cpu/runtime_handle_ffi_call.cc @@ -142,9 +142,9 @@ static absl::Status BuildAndCallFfi( // Forward executable run options to the FFI handlers via the call options. ffi::CallOptions call_options = { - run_options->device_ordinal(), run_options->stream(), - run_options->allocator(), /*called_computation=*/nullptr, - run_options->ffi_execution_context()}; + run_options->device_ordinal(), + ffi::CallOptions::CpuOptions{run_options->intra_op_thread_pool()}, + /*called_computation=*/nullptr, run_options->ffi_execution_context()}; ffi::CallFrame call_frame = builder.Build(); return ffi::Call(registration->bundle.execute, call_frame, call_options); diff --git a/third_party/xla/xla/service/cpu/runtime_single_threaded_conv2d.cc b/third_party/xla/xla/service/cpu/runtime_single_threaded_conv2d.cc index 999e53cc296025..bc749f5c42be20 100644 --- a/third_party/xla/xla/service/cpu/runtime_single_threaded_conv2d.cc +++ b/third_party/xla/xla/service/cpu/runtime_single_threaded_conv2d.cc @@ -15,8 +15,10 @@ limitations under the License. #include "xla/service/cpu/runtime_single_threaded_conv2d.h" +#include + #include "absl/base/dynamic_annotations.h" -#include "xla/service/cpu/runtime/conv_impl.h" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedConv2DF16( @@ -29,13 +31,13 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF16( int64_t padding_left, int64_t padding_right, int64_t lhs_row_dilation, int64_t lhs_col_dilation, int64_t rhs_row_dilation, int64_t rhs_col_dilation, int64_t feature_group_count) { - tensorflow::xla::EigenConv2DImpl( + xla::cpu::internal::EigenConv2D( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, padding_top, padding_bottom, padding_left, padding_right, lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation, - feature_group_count); + feature_group_count, std::nullopt); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -49,11 +51,11 @@ __xla_cpu_runtime_EigenSingleThreadedConv2DF32( int64_t padding_right, int64_t lhs_row_dilation, int64_t lhs_col_dilation, int64_t rhs_row_dilation, int64_t rhs_col_dilation, int64_t feature_group_count) { - tensorflow::xla::EigenConv2DImpl( + xla::cpu::internal::EigenConv2D( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, padding_top, padding_bottom, padding_left, padding_right, lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation, - feature_group_count); + feature_group_count, std::nullopt); } diff --git a/third_party/xla/xla/service/cpu/runtime_single_threaded_conv3d.cc b/third_party/xla/xla/service/cpu/runtime_single_threaded_conv3d.cc index 91dd6c87948712..d0d807aeb26e69 100644 --- a/third_party/xla/xla/service/cpu/runtime_single_threaded_conv3d.cc +++ b/third_party/xla/xla/service/cpu/runtime_single_threaded_conv3d.cc @@ -15,8 +15,10 @@ limitations under the License. #include "xla/service/cpu/runtime_single_threaded_conv3d.h" +#include + #include "absl/base/dynamic_annotations.h" -#include "xla/service/cpu/runtime/conv_impl.h" +#include "xla/backends/cpu/runtime/convolution_thunk_internal.h" ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedConv3DF32( @@ -31,14 +33,14 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF32( int64_t lhs_y_dilation, int64_t lhs_z_dilation, int64_t rhs_x_dilation, int64_t rhs_y_dilation, int64_t rhs_z_dilation, int64_t feature_group_count) { - tensorflow::xla::EigenConv3DImpl( + xla::cpu::internal::EigenConv3D( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_x, input_y, input_z, input_channels, kernel_x, kernel_y, kernel_z, kernel_channels, kernel_filters, output_x, output_y, output_z, x_stride, y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before, padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation, lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation, - rhs_z_dilation, feature_group_count); + rhs_z_dilation, feature_group_count, std::nullopt); } ABSL_ATTRIBUTE_NO_SANITIZE_MEMORY void @@ -54,12 +56,12 @@ __xla_cpu_runtime_EigenSingleThreadedConv3DF16( int64_t lhs_y_dilation, int64_t lhs_z_dilation, int64_t rhs_x_dilation, int64_t rhs_y_dilation, int64_t rhs_z_dilation, int64_t feature_group_count) { - tensorflow::xla::EigenConv3DImpl( + xla::cpu::internal::EigenConv3D( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_x, input_y, input_z, input_channels, kernel_x, kernel_y, kernel_z, kernel_channels, kernel_filters, output_x, output_y, output_z, x_stride, y_stride, z_stride, padding_x_before, padding_x_after, padding_y_before, padding_y_after, padding_z_before, padding_z_after, lhs_x_dilation, lhs_y_dilation, lhs_z_dilation, rhs_x_dilation, rhs_y_dilation, - rhs_z_dilation, feature_group_count); + rhs_z_dilation, feature_group_count, std::nullopt); } diff --git a/third_party/xla/xla/service/cpu/tests/BUILD b/third_party/xla/xla/service/cpu/tests/BUILD index 40a97495501877..be51e842c6b5e0 100644 --- a/third_party/xla/xla/service/cpu/tests/BUILD +++ b/third_party/xla/xla/service/cpu/tests/BUILD @@ -293,8 +293,8 @@ xla_cc_test( "//xla/service:hlo_module_config", "//xla/service/cpu:cpu_compiler", "//xla/service/cpu:test_header_helper", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -361,6 +361,7 @@ xla_cc_test( name = "onednn_matmul_test", srcs = ["onednn_matmul_test.cc"], copts = tsl_copts(), + shard_count = 4, tags = [ "no_oss", "notap", @@ -372,7 +373,7 @@ xla_cc_test( "//xla:test_helpers", "//xla/hlo/utils:hlo_matchers", "//xla/service:cpu_plugin", - "//xla/service/cpu:onednn_matmul_rewriter", + "//xla/service/cpu:onednn_contraction_rewriter", "//xla/service/cpu:onednn_util", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -393,7 +394,7 @@ xla_cc_test( "//xla:test_helpers", "//xla/hlo/utils:hlo_matchers", "//xla/service:cpu_plugin", - "//xla/service/cpu:onednn_matmul_rewriter", + "//xla/service/cpu:onednn_contraction_rewriter", "//xla/service/cpu:onednn_util", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", diff --git a/third_party/xla/xla/service/cpu/tests/cpu_spmd_compile_test.cc b/third_party/xla/xla/service/cpu/tests/cpu_spmd_compile_test.cc index 6dc8cb9f7bb089..077e7eef9cc4f7 100644 --- a/third_party/xla/xla/service/cpu/tests/cpu_spmd_compile_test.cc +++ b/third_party/xla/xla/service/cpu/tests/cpu_spmd_compile_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/service/cpu/tests/cpu_codegen_test.h" #include "xla/service/executable.h" #include "xla/service/hlo_module_config.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc index 6428e31c1d2fbb..6bceebc7343c8e 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_convolution_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" -#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_contraction_rewriter.h" #include "xla/service/cpu/onednn_util.h" #include "xla/shape_util.h" #include "xla/test.h" @@ -34,6 +34,12 @@ namespace cpu { class ConvolutionTest : public HloTestBase { protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } + const char* conv_rewrite_str_ = R"( ; CHECK: custom_call_target="__onednn$convolution", ; CHECK: backend_config={ diff --git a/third_party/xla/xla/service/cpu/tests/onednn_layer_norm_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_layer_norm_test.cc index 9751e207b5e5da..92ca5061724faf 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_layer_norm_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_layer_norm_test.cc @@ -24,6 +24,12 @@ namespace { class LayerNormTest : public HloTestBase { protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } + const char* onednn_layer_norm_ = R"( ; CHECK: custom_call_target="__onednn$layernorm", @@ -95,7 +101,7 @@ TEST_F(LayerNormTest, LayerNormTest0_FP32) { common_hlo_region_ + R"( ENTRY main { Arg_0.1 = f32[84,197,768]{2,1,0} parameter(0), sharding={replicated} - + )" + common_hlo_entry_computation_block_ + R"( ROOT add.338 = f32[84,197,768]{2,1,0} add(multiply.331, subtract.337) @@ -219,7 +225,7 @@ TEST_F(LayerNormTest, LayerNormTest2_F16) { ROOT add_0 = f32[] add(Arg_0, Arg_1) } ENTRY main { - Arg_2= f16[2,4,8] parameter(0), sharding={replicated} + Arg_2 = f16[2,4,8] parameter(0), sharding={replicated} convert_0 = f32[2,4,8] convert(Arg_2) constant_0 = f32[] constant(0) convert_1 = f32[] convert(constant_0) @@ -241,7 +247,7 @@ TEST_F(LayerNormTest, LayerNormTest2_F16) { constant_3 = s32[] constant(8) convert_6 = f32[] convert(constant_3) broadcast_2 = f32[2,4] broadcast(convert_6), dimensions={} - divide_1= f32[2,4] divide(reduce_1, broadcast_2) + divide_1 = f32[2,4] divide(reduce_1, broadcast_2) convert_7 = f16[2,4] convert(divide_1) reshape_2 = f16[2,4,1] reshape(convert_7) rsqrt_0 = f16[2,4,1] rsqrt(reshape_2) @@ -249,13 +255,13 @@ TEST_F(LayerNormTest, LayerNormTest2_F16) { broadcast_3 = f16[2,4,8] broadcast(reshape_3), dimensions={0,1} constant_4 = f16[8] constant({1,1,1,1,1,1,1,1}) broadcast_4 = f16[2,4,8] broadcast(constant_4), dimensions={2} - multiply_1 = f16[2,4,8] multiply(broadcast3, broadcast_4) + multiply_1 = f16[2,4,8] multiply(broadcast_3, broadcast_4) multiply_2 = f16[2,4,8] multiply(multiply_1, Arg_2) constant_5 = f16[8] constant({1,1,1,1,1,1,1,1}) broadcast_5 = f16[2,4,8] broadcast(constant_5), dimensions={2} reshape_4 = f16[2,4] reshape(reshape_0) - broadcast_5 = f16[2,4,8] broadcast(reshape_4), dimensions={0,1} - multiply_3 = f16[2,4,8] multiply(multiply_1, broadcast_5) + broadcast_6 = f16[2,4,8] broadcast(reshape_4), dimensions={0,1} + multiply_3 = f16[2,4,8] multiply(multiply_1, broadcast_6) subtract_1 = f16[2,4,8] subtract(broadcast_5, multiply_3) ROOT add_1 = f16[2,4,8] add(multiply_2, subtract_1) } diff --git a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc index 389716c4ddef95..57f7c09aba11e8 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_matmul_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" -#include "xla/service/cpu/onednn_matmul_rewriter.h" +#include "xla/service/cpu/onednn_contraction_rewriter.h" #include "xla/service/cpu/onednn_util.h" #include "xla/shape_util.h" #include "xla/test.h" @@ -36,6 +36,12 @@ namespace cpu { class MatmulTest : public HloTestBase { protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } + const char* fused_matmul_bias_ = R"( ; CHECK: custom_call_target="__onednn$matmul", ; CHECK: backend_config={ @@ -225,7 +231,7 @@ TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion1) { TEST_F(MatmulTest, SimpleTestF32WithBiasAddFusion2) { const char* matmul_module_str = R"( HloModule matmul.biasadd.test.f32 - + ENTRY matmul.biasadd.test.f32 { arg0.1 = f32[400,300] parameter(0), parameter_replication={false} reshape.2 = f32[400,300] reshape(arg0.1) @@ -797,7 +803,9 @@ TEST_F(MatmulTest, TestNonScalarConstantEltwiseLinearF32) { ; CHECK: backend_config={ ; CHECK-DAG: "outer_dimension_partitions":[], ; CHECK-DAG: "onednn_matmul_config":{ - ; CHECK-NOT: "fused_ops":["LINEAR"] + ; CHECK-NOT: "fusions":{ + ; CHECK-NOT: "ops":["LINEAR"] + ; CHECK-NOT: } ; CHECK-DAG: } ; CHECK: } )"); @@ -1128,7 +1136,7 @@ TEST_F(MatmulTest, SIGMOIDTestF32) { const.0 = f32[32]{0} constant(5) bcast.0 = f32[32,32,4,32] broadcast(const.0), dimensions={3} add.0 = f32[32,32,4,32] add(onednn.matmul.0, bcast.0) - + const.1 = f32[] constant(1) bcast.1 = f32[32,32,4,32] broadcast(const.1), dimensions={} negate.0 = f32[32,32,4,32] negate(add.0) @@ -1149,7 +1157,7 @@ TEST_F(MatmulTest, SIGMOIDTestBF16) { } const char* matmul_module_str = R"( HloModule matmul.bias.sigmoid.test.bf16 - + ENTRY matmul.bias.sigmoid.test.bf16 { arg.0 = f32[32,32,4,16] parameter(0), parameter_replication={false} convert.0 = bf16[32,32,4,16] convert(arg.0) @@ -1180,7 +1188,7 @@ TEST_F(MatmulTest, SIGMOIDTestF16) { } const char* matmul_module_str = R"( HloModule matmul.bias.sigmoid.test.f16 - + ENTRY matmul.bias.sigmoid.test.f16 { arg.0 = f32[32,32,4,16] parameter(0), parameter_replication={false} convert.0 = f16[32,32,4,16] convert(arg.0) @@ -1230,7 +1238,7 @@ TEST_F(MatmulTest, SimpleTestBF16Gemv2) { const char* matmul_module_str = R"( HloModule matmul.test.bf16 - + ENTRY matmul.test.bf16 { arg.0 = bf16[100,300,300] parameter(0) arg.1 = bf16[300] parameter(1) @@ -1493,47 +1501,44 @@ TEST_F(MatmulTest, WeightsPrepackAndScratch) { )"); } -TEST_F(MatmulTest, ConsecutiveBinaryAdd) { - const char* matmul_module_str = R"( - HloModule matmul.test.f32 - region_0.22 { - Arg_0.23 = f32[] parameter(0) - Arg_1.24 = f32[] parameter(1) - ROOT add.25 = f32[] add(Arg_0.23, Arg_1.24) +TEST_F(MatmulTest, ColMajorBF16DotBeforeLayoutAssignment) { + if (!IsSupportedType(PrimitiveType::BF16)) { + GTEST_SKIP() << "CPU does not support BF16."; } - region_1.29 { - Arg_0.30 = f32[] parameter(0) - Arg_1.31 = f32[] parameter(1) - ROOT add.32 = f32[] add(Arg_0.30, Arg_1.31) - } + const char* matmul_module_str = R"( + HloModule matmul.colmajor.test + ENTRY matmul.colmajor.test.bf16 { + arg.0 = bf16[500,500]{0,1} parameter(0) + arg.1 = bf16[500,500]{1,0} parameter(1) + transpose.0 = bf16[500,500]{0,1} transpose(arg.1), dimensions={1,0} + ROOT dot.0 = bf16[500,500]{1,0} dot(arg.0, arg.1), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + })"; - ENTRY main { - constant.2 = f32[] constant(1e-06) - broadcast.3 = f32[1000000] broadcast(constant.2), dimensions={} - constant.7 = f32[] constant(1) - broadcast.8 = f32[1000000,3] broadcast(constant.7), dimensions={} - Arg_0.1 = f32[3] parameter(0) - reshape.10 = f32[1,3] reshape(Arg_0.1) - broadcast.11 = f32[1,3] broadcast(reshape.10), dimensions={0,1} - reshape.12 = f32[3] reshape(broadcast.11) - broadcast.13 = f32[1000000,3] broadcast(reshape.12), dimensions={1} - subtract.14 = f32[1000000,3] subtract(broadcast.8, broadcast.13) - constant.4 = f32[] constant(0) - broadcast.5 = f32[3,3] broadcast(constant.4), dimensions={} - dot.15 = f32[1000000,3] dot(subtract.14, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={0} - dot.16 = f32[1000000,3] dot(broadcast.3, dot.15), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={} - dot.17 = f32[1000000,3] dot(broadcast.3, subtract.14), lhs_batch_dims={0}, lhs_contracting_dims={}, rhs_batch_dims={0}, rhs_contracting_dims={} - dot.18 = f32[1000000,3] dot(dot.17, broadcast.5), lhs_contracting_dims={1}, rhs_contracting_dims={1} - add.19 = f32[1000000,3] add(dot.16, dot.18) - constant.9 = f32[3] constant({1, 2, 3}) - dot.20 = f32[1000000,3] dot(broadcast.3, constant.9), lhs_contracting_dims={}, rhs_contracting_dims={} - add.21 = f32[1000000,3] add(add.19, dot.20) - constant.6 = f32[] constant(0) - reduce.26 = f32[3] reduce(add.21, constant.6), dimensions={0}, to_apply=region_0.22 - reshape.27 = f32[1,3] reshape(reduce.26) - negate.28 = f32[1,3] negate(reshape.27) - ROOT reduce.33 = f32[3] reduce(negate.28, constant.6), dimensions={0}, to_apply=region_1.29 + EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec(1e-2, 1e-2))); + MatchOptimizedHlo(matmul_module_str, + R"( + ; CHECK: (bf16[500,500]{1,0}, u8[{{.*}}]{0}) + ; CHECK-SAME: custom_call_target="__onednn$matmul" + )"); +} + +TEST_F(MatmulTest, ConsecutiveBinaryAdd) { + const char* matmul_module_str = R"( + HloModule matmul.test.f32 + ENTRY matmul.test.f32 { + arg0.1 = f32[128,32,4,4] parameter(0) + arg0.2 = f32[128,32,4,4] parameter(1) + dot.7 = f32[128,32,4,4] dot(arg0.1, arg0.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} + const.0 = f32[128,32] constant({...}) + bcast.1 = f32[128,32,4,4] broadcast(const.0), dimensions={0,1} + add.0 = f32[128,32,4,4] add(dot.7,bcast.1) + const.1 = f32[4] constant({1,2,3,4}) + bcast.2 = f32[128,32,4,4] broadcast(const.1), dimensions={3} + add.1 = f32[128,32,4,4] add(add.0, bcast.2) + tuple.12 = (f32[128,32,4,4]) tuple(add.1) + ROOT get-tuple-element.13 = f32[128,32,4,4] get-tuple-element(tuple.12), index=0 })"; EXPECT_TRUE(RunAndCompare(matmul_module_str, ErrorSpec{1e-4, 1e-4})); @@ -1541,7 +1546,7 @@ TEST_F(MatmulTest, ConsecutiveBinaryAdd) { TEST_F(MatmulTest, BroadcastedAddAfterFusion) { const char* matmul_module_str = R"( - HloModule matmul.nonscalar.test.1 + HloModule matmul.nonscalar.test ENTRY matmul.nonscalar.test.f32 { arg.0 = f32[16,400,500] parameter(0) arg.1 = f32[16,500,3] parameter(1) diff --git a/third_party/xla/xla/service/cpu/tests/onednn_softmax_test.cc b/third_party/xla/xla/service/cpu/tests/onednn_softmax_test.cc index 124b4472024c17..1fff5d88a736e5 100644 --- a/third_party/xla/xla/service/cpu/tests/onednn_softmax_test.cc +++ b/third_party/xla/xla/service/cpu/tests/onednn_softmax_test.cc @@ -47,13 +47,63 @@ class OneDnnSoftmaxTest : public HloTestBase, public ::testing::WithParamInterface> { protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_use_thunk_runtime(false); + return debug_options; + } + const char* onednn_softmax_ = R"( ; CHECK: custom_call_target="__onednn$softmax" )"; + // Get raw HLO text for generic softmax pattern, after replacing $0 with + // datatype and $1 with batch size. + const std::string GetGenericSoftmaxHLORawText(PrimitiveType data_type, + int batch_size) { + const std::string softmax_hlo_template_string = R"( + HloModule softmax_module + region_max { + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT maximum = $0[] maximum(Arg_0, Arg_1) + } + region_add { + Arg_0 = $0[] parameter(0) + Arg_1 = $0[] parameter(1) + ROOT add = $0[] add(Arg_0, Arg_1) + } + ENTRY main { + Arg_0 = $0[$1,128,30522]{2,1,0} parameter(0) + neg_inf = $0[] constant(-inf) + reduce_max = $0[$1,128]{1,0} reduce(Arg_0, neg_inf), dimensions={2}, to_apply=region_max + reshape.0 = $0[$1,128,1]{2,1,0} reshape(reduce_max) + broadcast.0 = $0[$1,128,1]{2,1,0} broadcast(reshape.0), dimensions={0,1,2} + reshape.1 = $0[$1,128]{1,0} reshape(broadcast.0) + broadcast.1 = $0[$1,128,30522]{2,1,0} broadcast(reshape.1), dimensions={0,1} + subtract.0 = $0[$1,128,30522]{2,1,0} subtract(Arg_0, broadcast.1) + exponential = $0[$1,128,30522]{2,1,0} exponential(subtract.0) + const_zero = $0[] constant(0) + reduce_add = $0[$1,128]{1,0} reduce(exponential, const_zero), dimensions={2}, to_apply=region_add + reshape.2 = $0[$1,128,1]{2,1,0} reshape(reduce_add) + broadcast.2 = $0[$1,128,1]{2,1,0} broadcast(reshape.2), dimensions={0,1,2} + reshape.3 = $0[$1,128]{1,0} reshape(broadcast.2) + broadcast.3 = $0[$1,128,30522]{2,1,0} broadcast(reshape.3), dimensions={0,1} + ROOT divide = $0[$1,128,30522]{2,1,0} divide(exponential, broadcast.3) + } + )"; + + const std::string softmax_hlo_string = absl::Substitute( + softmax_hlo_template_string, + primitive_util::LowercasePrimitiveTypeName(data_type), batch_size); + + return softmax_hlo_string; + } + // Test pattern match with OneDnnOpsRewriter pass - void TestSoftmax(std::string input_hlo_string, int expected_softmax_axis) { + void TestSoftmaxPatternMatching(std::string input_hlo_string, + int expected_softmax_axis) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(input_hlo_string)); OneDnnOpsRewriter softmax_rewrite_pass; @@ -74,6 +124,7 @@ class OneDnnSoftmaxTest }; // Softmax test with last dimension as axis. In this case, axis = 2 +// This test is to make sure the pattern matching works as expected TEST_P(OneDnnSoftmaxTest, SoftmaxGenericTest) { PrimitiveType data_type; int batch_size; @@ -82,44 +133,44 @@ TEST_P(OneDnnSoftmaxTest, SoftmaxGenericTest) { GTEST_SKIP() << "CPU does not support " << primitive_util::LowercasePrimitiveTypeName(data_type); } + const std::string softmax_hlo_string = + GetGenericSoftmaxHLORawText(data_type, batch_size); - const std::string softmax_hlo_template_string = R"( + TestSoftmaxPatternMatching(softmax_hlo_string, /*expected_softmax_axis*/ 2); +} + +// Generic Softmax test with last dimension as axis. In this case, axis = 2 +// This test to make sure the accuracy is fine with onednn softmax custom call +TEST_P(OneDnnSoftmaxTest, SoftmaxGenericNumericalCorrectnessTest) { + PrimitiveType data_type; + int batch_size; + std::tie(data_type, batch_size) = GetParam(); + if (!IsSupportedType(data_type)) { + GTEST_SKIP() << "CPU does not support " + << primitive_util::LowercasePrimitiveTypeName(data_type); + } + + const std::string onednn_softmax_hlo_template_string = R"( HloModule softmax_module - region_max { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT maximum = $0[] maximum(Arg_0, Arg_1) - } - region_add { - Arg_0 = $0[] parameter(0) - Arg_1 = $0[] parameter(1) - ROOT add = $0[] add(Arg_0, Arg_1) - } ENTRY main { Arg_0 = $0[$1,128,30522]{2,1,0} parameter(0) - neg_inf = $0[] constant(-inf) - reduce_max = $0[$1,128]{1,0} reduce(Arg_0, neg_inf), dimensions={2}, to_apply=region_max - reshape.0 = $0[$1,128,1]{2,1,0} reshape(reduce_max) - broadcast.0 = $0[$1,128,1]{2,1,0} broadcast(reshape.0), dimensions={0,1,2} - reshape.1 = $0[$1,128]{1,0} reshape(broadcast.0) - broadcast.1 = $0[$1,128,30522]{2,1,0} broadcast(reshape.1), dimensions={0,1} - subtract.0 = $0[$1,128,30522]{2,1,0} subtract(Arg_0, broadcast.1) - exponential = $0[$1,128,30522]{2,1,0} exponential(subtract.0) - const_zero = $0[] constant(0) - reduce_add = $0[$1,128]{1,0} reduce(exponential, const_zero), dimensions={2}, to_apply=region_add - reshape.2 = $0[$1,128,1]{2,1,0} reshape(reduce_add) - broadcast.2 = $0[$1,128,1]{2,1,0} broadcast(reshape.2), dimensions={0,1,2} - reshape.3 = $0[$1,128]{1,0} reshape(broadcast.2) - broadcast.3 = $0[$1,128,30522]{2,1,0} broadcast(reshape.3), dimensions={0,1} - ROOT divide = $0[$1,128,30522]{2,1,0} divide(exponential, broadcast.3) + ROOT custom-call = $0[$1,128,30522]{2,1,0} custom-call(Arg_0), custom_call_target="$2", backend_config={"onednn_softmax_config":{"softmax_axis":2}} } )"; - const std::string softmax_hlo_string = absl::Substitute( - softmax_hlo_template_string, - primitive_util::LowercasePrimitiveTypeName(data_type), batch_size); + auto onednn_softmax_hlo_string = + absl::Substitute(onednn_softmax_hlo_template_string, + primitive_util::LowercasePrimitiveTypeName(data_type), + batch_size, "__onednn$softmax"); + const std::string hlo_string_ref = + GetGenericSoftmaxHLORawText(data_type, batch_size); + + float atol = (data_type == F32) ? 1e-4 : 1e-2; + float rtol = (data_type == F32) ? 1e-4 : 1e-2; - TestSoftmax(softmax_hlo_string, /*expected_softmax_axis*/ 2); + EXPECT_TRUE(RunAndCompareTwoModules(onednn_softmax_hlo_string, hlo_string_ref, + ErrorSpec{atol, rtol}, + /*run_hlo_passes=*/false)); } INSTANTIATE_TEST_SUITE_P(OneDnnSoftmaxTestSuite, OneDnnSoftmaxTest, @@ -163,7 +214,7 @@ TEST_F(OneDnnSoftmaxTest, SoftmaxFP32OnAxisZero) { } )"; - TestSoftmax(softmax_hlo_string, /*expected_softmax_axis*/ 0); + TestSoftmaxPatternMatching(softmax_hlo_string, /*expected_softmax_axis*/ 0); } TEST_F(OneDnnSoftmaxTest, SoftmaxWithBF16ConvertOutputFP32Pattern) { @@ -204,7 +255,7 @@ TEST_F(OneDnnSoftmaxTest, SoftmaxWithBF16ConvertOutputFP32Pattern) { } )"; - TestSoftmax(softmax_hlo_string, /*expected_softmax_axis=*/2); + TestSoftmaxPatternMatching(softmax_hlo_string, /*expected_softmax_axis=*/2); } } // namespace cpu diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.cc b/third_party/xla/xla/service/cpu/thunk_emitter.cc index 7d3c9c558021d4..d2be391fa9fe8f 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.cc +++ b/third_party/xla/xla/service/cpu/thunk_emitter.cc @@ -28,6 +28,29 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/all_gather_thunk.h" +#include "xla/backends/cpu/runtime/all_reduce_thunk.h" +#include "xla/backends/cpu/runtime/all_to_all_thunk.h" +#include "xla/backends/cpu/runtime/call_thunk.h" +#include "xla/backends/cpu/runtime/collective_permute_thunk.h" +#include "xla/backends/cpu/runtime/collective_thunk.h" +#include "xla/backends/cpu/runtime/conditional_thunk.h" +#include "xla/backends/cpu/runtime/convolution_thunk.h" +#include "xla/backends/cpu/runtime/copy_thunk.h" +#include "xla/backends/cpu/runtime/custom_call_thunk.h" +#include "xla/backends/cpu/runtime/dot_thunk.h" +#include "xla/backends/cpu/runtime/fft_thunk.h" +#include "xla/backends/cpu/runtime/infeed_thunk.h" +#include "xla/backends/cpu/runtime/kernel_thunk.h" +#include "xla/backends/cpu/runtime/logical_id_thunk.h" +#include "xla/backends/cpu/runtime/outfeed_thunk.h" +#include "xla/backends/cpu/runtime/reduce_scatter_thunk.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/rng_state_thunk.h" +#include "xla/backends/cpu/runtime/sort_thunk.h" +#include "xla/backends/cpu/runtime/thunk.h" +#include "xla/backends/cpu/runtime/topk_thunk.h" +#include "xla/backends/cpu/runtime/while_thunk.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -42,29 +65,6 @@ limitations under the License. #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/ir_emission_utils.h" #include "xla/service/cpu/ir_emitter2.h" -#include "xla/service/cpu/runtime/all_gather_thunk.h" -#include "xla/service/cpu/runtime/all_reduce_thunk.h" -#include "xla/service/cpu/runtime/all_to_all_thunk.h" -#include "xla/service/cpu/runtime/call_thunk.h" -#include "xla/service/cpu/runtime/collective_permute_thunk.h" -#include "xla/service/cpu/runtime/collective_thunk.h" -#include "xla/service/cpu/runtime/conditional_thunk.h" -#include "xla/service/cpu/runtime/convolution_thunk.h" -#include "xla/service/cpu/runtime/copy_thunk.h" -#include "xla/service/cpu/runtime/custom_call_thunk.h" -#include "xla/service/cpu/runtime/dot_thunk.h" -#include "xla/service/cpu/runtime/fft_thunk.h" -#include "xla/service/cpu/runtime/infeed_thunk.h" -#include "xla/service/cpu/runtime/kernel_thunk.h" -#include "xla/service/cpu/runtime/logical_id_thunk.h" -#include "xla/service/cpu/runtime/outfeed_thunk.h" -#include "xla/service/cpu/runtime/reduce_scatter_thunk.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/rng_state_thunk.h" -#include "xla/service/cpu/runtime/sort_thunk.h" -#include "xla/service/cpu/runtime/thunk.h" -#include "xla/service/cpu/runtime/topk_thunk.h" -#include "xla/service/cpu/runtime/while_thunk.h" #include "xla/service/cpu/target_machine_features.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" @@ -516,9 +516,9 @@ absl::StatusOr ThunkEmitter::EmitConcatenateKernelThunk( ir_emitter_.EmitConcatenateHostKernel(concatenate)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitGetDimensionSizeThunk( @@ -609,9 +609,9 @@ absl::StatusOr ThunkEmitter::EmitElementalKernelThunk( ir_emitter_.EmitElementalHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitPadKernelThunk( @@ -620,9 +620,9 @@ absl::StatusOr ThunkEmitter::EmitPadKernelThunk( TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitPadHostKernel(padInstr)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(padInstr)); - return ThunkSequence::Of( - ThunkInfo(padInstr), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + padInstr, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitFusionKernelThunk( @@ -631,9 +631,9 @@ absl::StatusOr ThunkEmitter::EmitFusionKernelThunk( TF_ASSIGN_OR_RETURN(auto kernel, ir_emitter_.EmitFusionHostKernel(fusion)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitReductionKernelThunk( @@ -642,9 +642,9 @@ absl::StatusOr ThunkEmitter::EmitReductionKernelThunk( ir_emitter_.EmitReductionHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitRngThunk( @@ -755,9 +755,19 @@ absl::StatusOr ThunkEmitter::EmitWhileThunk( TF_ASSIGN_OR_RETURN(ThunkSequence body_thunk, EmitHloComputation(instruction->while_body())); + // Check if while loop has a statically known trip count. + TF_ASSIGN_OR_RETURN( + auto loop_config, + instruction->backend_config()); + + std::optional trip_count; + if (loop_config.has_known_trip_count()) { + trip_count = loop_config.known_trip_count().n(); + } + return ThunkSequence::Of(ThunkInfo(instruction), cond_buffer, std::move(cond_thunk), - std::move(body_thunk)); + std::move(body_thunk), trip_count); } absl::StatusOr ThunkEmitter::EmitDotThunk( @@ -789,9 +799,7 @@ absl::StatusOr ThunkEmitter::EmitDotThunk( TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of(ThunkInfo(instruction), - buffers.arguments, buffers.results, - kernel.name, kernel.thread_dims); + return MakeKernelThunkSequence(instruction, buffers, kernel); } // Emit DotThunk implementing dot instruction as a library call. @@ -970,9 +978,9 @@ absl::StatusOr ThunkEmitter::EmitSliceToDynamicThunk( ir_emitter_.EmitSliceToDynamicHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of( - ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, - kernel.thread_dims, /*min_alignment=*/cpu_function_runtime::MinAlign()); + return MakeKernelThunkSequence( + instruction, buffers, kernel, + /*min_alignment=*/cpu_function_runtime::MinAlign()); } absl::StatusOr ThunkEmitter::EmitSelectAndScatterThunk( @@ -981,9 +989,7 @@ absl::StatusOr ThunkEmitter::EmitSelectAndScatterThunk( ir_emitter_.EmitSelectAndScatterHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of(ThunkInfo(instruction), - buffers.arguments, buffers.results, - kernel.name, kernel.thread_dims); + return MakeKernelThunkSequence(instruction, buffers, kernel); } absl::StatusOr ThunkEmitter::EmitSliceThunk( @@ -1000,9 +1006,7 @@ absl::StatusOr ThunkEmitter::EmitDynamicUpdateSliceThunk( auto kernel, ir_emitter_.EmitDynamicUpdateSliceHostKernel(instruction)); TF_ASSIGN_OR_RETURN(auto buffers, GetHostKernelAllocationSlices(instruction)); - return ThunkSequence::Of(ThunkInfo(instruction), - buffers.arguments, buffers.results, - kernel.name, kernel.thread_dims); + return MakeKernelThunkSequence(instruction, buffers, kernel); } absl::StatusOr ThunkEmitter::EmitSortThunk( @@ -1088,4 +1092,14 @@ absl::Status ThunkEmitter::ElementTypesSameAndSupported( return absl::OkStatus(); } +absl::StatusOr ThunkEmitter::MakeKernelThunkSequence( + const HloInstruction* instruction, + const ThunkEmitter::HostKernelAllocationSlices& buffers, + const IrEmitter2::KernelInfo& kernel, + std::optional min_alignment) { + return ThunkSequence::Of( + ThunkInfo(instruction), buffers.arguments, buffers.results, kernel.name, + kernel.thread_dims, kernel.invariant_buffers, min_alignment); +} + } // namespace xla::cpu diff --git a/third_party/xla/xla/service/cpu/thunk_emitter.h b/third_party/xla/xla/service/cpu/thunk_emitter.h index 6921f76e75179b..ad6eb8863b5ee6 100644 --- a/third_party/xla/xla/service/cpu/thunk_emitter.h +++ b/third_party/xla/xla/service/cpu/thunk_emitter.h @@ -16,21 +16,23 @@ limitations under the License. #ifndef XLA_SERVICE_CPU_THUNK_EMITTER_H_ #define XLA_SERVICE_CPU_THUNK_EMITTER_H_ +#include #include +#include #include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/backends/cpu/runtime/resource_use.h" +#include "xla/backends/cpu/runtime/thunk.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/buffer_assignment.h" #include "xla/service/cpu/ir_emitter2.h" -#include "xla/service/cpu/runtime/resource_use.h" -#include "xla/service/cpu/runtime/thunk.h" #include "xla/service/cpu/target_machine_features.h" #include "xla/service/hlo_module_config.h" #include "xla/shape_util.h" @@ -195,6 +197,13 @@ class ThunkEmitter { absl::Span operands, absl::Span supported_types); + // Convenience function that creates a thunk sequence containing given kernel. + static absl::StatusOr MakeKernelThunkSequence( + const HloInstruction* instruction, + const ThunkEmitter::HostKernelAllocationSlices& buffers, + const IrEmitter2::KernelInfo& kernel, + std::optional min_alignment = std::nullopt); + IrEmitter2& ir_emitter_; const BufferAssignment& buffer_assignment_; diff --git a/third_party/xla/xla/service/cpu/xfeed_manager_test.cc b/third_party/xla/xla/service/cpu/xfeed_manager_test.cc index 5c6be64e6a7dd1..5b682e207386d8 100644 --- a/third_party/xla/xla/service/cpu/xfeed_manager_test.cc +++ b/third_party/xla/xla/service/cpu/xfeed_manager_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/service/cpu/cpu_runtime.h" #include "xla/shape_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc index 0d4317baeda146..534b3590c74f0a 100644 --- a/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc +++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier.cc @@ -31,6 +31,7 @@ absl::Status VerifyS4U4Usage(HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kBitcast: case HloOpcode::kBroadcast: + case HloOpcode::kCall: case HloOpcode::kConstant: case HloOpcode::kConcatenate: case HloOpcode::kConvert: diff --git a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc index 277143e89b7189..ec4d07e35106c1 100644 --- a/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc +++ b/third_party/xla/xla/service/cpu_gpu_shape_verifier_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -65,10 +65,15 @@ TEST_F(CpuGpuShapeVerifierTest, Int4SupportedInstruction) { const char* const hlo_string = R"( HloModule Module - ENTRY main { + bcast { p0 = u4[] parameter(0) ROOT out = u4[3, 3] broadcast(p0), dimensions={} } + + ENTRY main { + p0 = u4[] parameter(0) + ROOT out = u4[3, 3] call(p0), to_apply=bcast + } )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo_string)); diff --git a/third_party/xla/xla/service/dfs_hlo_visitor_with_default_test.cc b/third_party/xla/xla/service/dfs_hlo_visitor_with_default_test.cc index df05bef5a4397f..2fe22688ee2018 100644 --- a/third_party/xla/xla/service/dfs_hlo_visitor_with_default_test.cc +++ b/third_party/xla/xla/service/dfs_hlo_visitor_with_default_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/dot_as_convolution_util.cc b/third_party/xla/xla/service/dot_as_convolution_util.cc index e22dddcf7cee68..25d6b6a48c9d48 100644 --- a/third_party/xla/xla/service/dot_as_convolution_util.cc +++ b/third_party/xla/xla/service/dot_as_convolution_util.cc @@ -129,6 +129,9 @@ bool SpatialIsContracting(int64_t lhs_spatial_size, int64_t rhs_spatial_size, } } + dims.lhs_shape_rank = conv->operand(0)->shape().rank(); + dims.rhs_shape_rank = conv->operand(1)->shape().rank(); + dims.output_shape_rank = conv->shape().rank(); return dims; } @@ -224,6 +227,10 @@ DotConvolutionDimsInfo ParseDotGeneralFromDot(const HloInstruction* dot) { dnums.rhs_non_contracting_dims.back().spatial_dim = -1; } } + + dnums.lhs_shape_rank = dot->operand(0)->shape().rank(); + dnums.rhs_shape_rank = dot->operand(1)->shape().rank(); + dnums.output_shape_rank = dot->shape().rank(); return dnums; } diff --git a/third_party/xla/xla/service/dot_as_convolution_util.h b/third_party/xla/xla/service/dot_as_convolution_util.h index 01236f8c7ec9d1..9bed16990fc204 100644 --- a/third_party/xla/xla/service/dot_as_convolution_util.h +++ b/third_party/xla/xla/service/dot_as_convolution_util.h @@ -16,8 +16,8 @@ limitations under the License. #ifndef XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ #define XLA_SERVICE_DOT_AS_CONVOLUTION_UTIL_H_ +#include #include -#include #include #include "xla/hlo/ir/hlo_instruction.h" @@ -55,6 +55,10 @@ struct DotConvolutionDimsInfo { std::vector lhs_non_contracting_dims; std::vector rhs_non_contracting_dims; std::vector conv_spatial_dims; + + int64_t lhs_shape_rank; + int64_t rhs_shape_rank; + int64_t output_shape_rank; }; // Parses a convolution and returns a DotGeneralAsConvolutionDimsInfo. If it can diff --git a/third_party/xla/xla/service/dot_merger_test.cc b/third_party/xla/xla/service/dot_merger_test.cc index 97b9da0d0c279d..786970e7904f96 100644 --- a/third_party/xla/xla/service/dot_merger_test.cc +++ b/third_party/xla/xla/service/dot_merger_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/dump.cc b/third_party/xla/xla/service/dump.cc index 13b3032ac4ca69..3aa3a8862011a3 100644 --- a/third_party/xla/xla/service/dump.cc +++ b/third_party/xla/xla/service/dump.cc @@ -50,10 +50,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/hlo_proto_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "tsl/lib/io/zlib_compression_options.h" #include "tsl/lib/io/zlib_outputbuffer.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" @@ -64,6 +64,24 @@ limitations under the License. namespace xla { +absl::Status CreateDirIfNeeded(const std::string& dir, tsl::Env* env) { + if (!env->IsDirectory(dir).ok()) { + absl::Status status = env->RecursivelyCreateDir(dir); + // Two threads can race to observe the absence of the dump directory and + // simultaneously try to create it, causing the "losing" thread to get a + // "directory already exists" error. We can work around this by checking + // again whether the dir exists. + if (!status.ok()) { + status = env->IsDirectory(dir); + if (!status.ok()) { + LOG(ERROR) << "Could not create directory " << dir; + return status; + } + } + } + return absl::OkStatus(); +} + std::string RenderGraph(absl::string_view label, const HloModule& module, RenderedGraphFormat format, bool show_fusion_subcomputations) { @@ -299,17 +317,8 @@ static std::optional GetDumpFilePath( VLOG(1) << "Dumping " << filename << " to " << dir; tsl::Env* env = tsl::Env::Default(); - // Two threads can race to observe the absence of the dump directory and - // simultaneously try to create it, causing the "losing" thread to get a - // "directory already exists" error. We can work around this by checking - // again whether the dir exists. - if (!env->IsDirectory(dir).ok()) { - auto status = env->RecursivelyCreateDir(dir); - if (!status.ok() && !env->IsDirectory(dir).ok()) { - LOG(ERROR) << "Could not create directory " << dir - << " for dumping XLA debug data: " << status; - return std::nullopt; - } + if (!CreateDirIfNeeded(dir, env).ok()) { + return std::nullopt; } // Make sure we are not going to dump more modules than the user has asked. @@ -677,15 +686,7 @@ void DumpProtobufToFile(const tsl::protobuf::Message& proto, if (dir.empty()) { return; } - if (!env->IsDirectory(dir).ok()) { - auto status = env->RecursivelyCreateDir(dir); - if (!status.ok()) { - LOG(ERROR) << "Could not create directory " << dir - << " for dumping: " << status; - return; - } - } - if (!env->IsDirectory(dir).ok()) { + if (!CreateDirIfNeeded(dir, env).ok()) { return; } const std::string path = tsl::io::JoinPath(dir, filename); @@ -884,4 +885,20 @@ void DumpHloModuleMetadataIfEnabled(const std::vector& modules) { } } +absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message, + const std::string& directory, + const std::string& file_name, + std::string* full_path) { + tsl::Env* env = tsl::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + TF_RETURN_IF_ERROR(CreateDirIfNeeded(directory, env)); + std::string safe_file_name = SanitizeFileName(file_name) + ".pb"; + std::string full_path_impl; + if (!full_path) { + full_path = &full_path_impl; + } + *full_path = tsl::io::JoinPath(directory, safe_file_name); + return tsl::WriteBinaryProto(env, *full_path, message); +} + } // namespace xla diff --git a/third_party/xla/xla/service/dump.h b/third_party/xla/xla/service/dump.h index 0b1a6d2891d0ad..623e7298fb9306 100644 --- a/third_party/xla/xla/service/dump.h +++ b/third_party/xla/xla/service/dump.h @@ -43,6 +43,10 @@ constexpr char kAfterOptimizationsDumpName[] = "after_optimizations"; class BufferAssignment; class HloSnapshot; +// Creates dir if doesn't exist (analogue of `mkdir -p`), tries to get around +// race conditions by trying again on collision. +absl::Status CreateDirIfNeeded(const std::string& dir, tsl::Env* env); + // Get a timestamp which we can use as a filename prefix specific to this // module. std::string TimestampFor(const HloModule& module); @@ -173,6 +177,18 @@ inline bool DumpingEnabledForHloModule(const HloModule& module) { // writing to two files, but you don't want to print twice. bool DumpingToStdout(const DebugOptions& opts); +// Writes the given message in binary proto to the path formed by joining +// 'directory/file_name.pb'. The 'directory' is recursively created if it +// doesn't already exist, and the 'file_name' is sanitized by replacing +// illegal characters with underscore '_'. +// +// If 'full_name' is not null then it is set to the name of the file the +// protobuf was written to. +absl::Status DumpProtoToDirectory(const tsl::protobuf::Message& message, + const std::string& directory, + const std::string& file_name, + std::string* full_path = nullptr); + } // namespace xla #endif // XLA_SERVICE_DUMP_H_ diff --git a/third_party/xla/xla/service/dump_test.cc b/third_party/xla/xla/service/dump_test.cc index 78adadf3ea8d1b..6df547d96fdfce 100644 --- a/third_party/xla/xla/service/dump_test.cc +++ b/third_party/xla/xla/service/dump_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "absl/strings/match.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc index 502f0079000948..9dc9de161aa4bd 100644 --- a/third_party/xla/xla/service/dynamic_dimension_inference_test.cc +++ b/third_party/xla/xla/service/dynamic_dimension_inference_test.cc @@ -30,8 +30,8 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/service/dynamic_dimension_simplifier_test.cc b/third_party/xla/xla/service/dynamic_dimension_simplifier_test.cc index 94e48eca1104e3..2131c6c002a3e5 100644 --- a/third_party/xla/xla/service/dynamic_dimension_simplifier_test.cc +++ b/third_party/xla/xla/service/dynamic_dimension_simplifier_test.cc @@ -37,10 +37,10 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/dynamic_padder_test.cc b/third_party/xla/xla/service/dynamic_padder_test.cc index 3e3efa1a1832c6..972bc38ae8c40b 100644 --- a/third_party/xla/xla/service/dynamic_padder_test.cc +++ b/third_party/xla/xla/service/dynamic_padder_test.cc @@ -53,9 +53,9 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/llvm_irgen_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/dynamic_parameter_binding_test.cc b/third_party/xla/xla/service/dynamic_parameter_binding_test.cc index 11dfbcdbec9617..94eaf4e5166bce 100644 --- a/third_party/xla/xla/service/dynamic_parameter_binding_test.cc +++ b/third_party/xla/xla/service/dynamic_parameter_binding_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/elemental_ir_emitter.cc b/third_party/xla/xla/service/elemental_ir_emitter.cc index be2101eb275c30..a959444d99c4c0 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter.cc @@ -30,10 +30,14 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/Support/MathExtras.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -60,6 +64,11 @@ limitations under the License. namespace xla { using absl::StrCat; +using llvm::PatternMatch::m_BitCast; +using llvm::PatternMatch::m_Intrinsic; +using llvm::PatternMatch::m_Select; +using llvm::PatternMatch::m_Value; +using llvm::PatternMatch::match; using llvm_ir::IrArray; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; @@ -713,6 +722,48 @@ absl::StatusOr ElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType from_type = op->operand(0)->shape().element_type(); PrimitiveType to_type = op->shape().element_type(); CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type; + + // LLVM optimizes away `fpcast` and `fpext` operations and optimized + // LLVM IR has arithmetic operations on `bfloat16` that are not natively + // supported on any of the CPUs, and LLVM inserts very expensive calls to + // fp conversion functions around bf16 operations. To avoid this, we use + // bitcasts and shifts to convert bf16 to f32 and back using truncation + // with rounding, and suppress LLVM optimizations that hurt performance. + // This is enabled explicitly by a flag only for XLA:CPU backend. + if (options_.xla_cpu_use_truncate_f32_to_bf16_conversion) { + if (from_type == F32 && to_type == BF16) { + // This implementation is based on Eigen `float_to_bfloat16_rtne` with + // a special case for nans. + auto* i32 = b_->CreateBitCast(operand_value, b_->getInt32Ty()); + + // Rounding bias for non-nan values. + auto* lsb = + b_->CreateAnd(b_->CreateLShr(i32, 16), + llvm::ConstantInt::get(b_->getInt32Ty(), 1)); + auto* rounding_bias = b_->CreateAdd( + llvm::ConstantInt::get(b_->getInt32Ty(), 0x7fff), lsb); + + // For nan values, we simply truncate the original value. + auto* is_nan = + b_->createIsFPClass(operand_value, llvm::FPClassTest::fcNan); + auto* i16 = b_->CreateTrunc( + b_->CreateLShr( + b_->CreateSelect(is_nan, i32, + b_->CreateAdd(i32, rounding_bias)), + 16), + b_->getInt16Ty()); + + return b_->CreateBitCast(i16, b_->getBFloatTy()); + } + if (from_type == BF16 && to_type == F32) { + auto* i16 = b_->CreateBitCast(operand_value, b_->getInt16Ty()); + auto* i32 = b_->CreateZExt(i16, b_->getInt32Ty()); + auto* i32s = b_->CreateShl(i32, 16); + auto* f32 = b_->CreateBitCast(i32s, b_->getFloatTy()); + return f32; + } + } + if (from_type == to_type) { return operand_value; } diff --git a/third_party/xla/xla/service/elemental_ir_emitter.h b/third_party/xla/xla/service/elemental_ir_emitter.h index bdc4559c6f3498..bb74bb80eaf4f1 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter.h +++ b/third_party/xla/xla/service/elemental_ir_emitter.h @@ -37,11 +37,21 @@ namespace xla { class ElementalIrEmitter : public IrBuilderMixin { public: + struct Options { + // Instead of relying on builtin `fpext` and `fpcast` emit a bitcast and + // truncate to convert f32 to bf16 (and emit extend to convert bf16 to f32). + bool xla_cpu_use_truncate_f32_to_bf16_conversion = false; + }; + using HloToElementGeneratorMap = absl::flat_hash_map; + ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b, + const Options& options) + : b_(b), module_(module), options_(options) {} + ElementalIrEmitter(llvm::Module* module, llvm::IRBuilder<>* b) - : b_(b), module_(module) {} + : ElementalIrEmitter(module, b, Options()) {} virtual ~ElementalIrEmitter() = default; @@ -321,6 +331,8 @@ class ElementalIrEmitter : public IrBuilderMixin { llvm::Module* module_; + Options options_; + friend class ElementalIrEmitterForTests; }; diff --git a/third_party/xla/xla/service/executable.cc b/third_party/xla/xla/service/executable.cc index ed86114607cf6f..aa81fce3e80e1c 100644 --- a/third_party/xla/xla/service/executable.cc +++ b/third_party/xla/xla/service/executable.cc @@ -25,7 +25,6 @@ limitations under the License. #include "xla/service/maybe_owning_device_memory.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/service/flatten_call_graph_test.cc b/third_party/xla/xla/service/flatten_call_graph_test.cc index 57498209c756c1..0a8be831355a5e 100644 --- a/third_party/xla/xla/service/flatten_call_graph_test.cc +++ b/third_party/xla/xla/service/flatten_call_graph_test.cc @@ -23,9 +23,9 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/fuzzy_matcher.h b/third_party/xla/xla/service/fuzzy_matcher.h new file mode 100644 index 00000000000000..6e5cd3e09eee5e --- /dev/null +++ b/third_party/xla/xla/service/fuzzy_matcher.h @@ -0,0 +1,109 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_FUZZY_MATCHER_H_ +#define XLA_SERVICE_FUZZY_MATCHER_H_ + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/pattern_matcher.h" + +namespace xla { + +// Fuzzy matchers for HLOs. +namespace fm { + +// TODO(b/355972677): Extend this to support opcodes other than convert +template +auto OptConvert(Pattern pattern) { + auto shared = match::SharedSubpattern(pattern); + return match::AnyOf(match::Convert(shared), shared); +} + +#define XLA_FUZZY_UNOP_PATTERN(NAME) \ + template \ + inline auto NAME(HloInstructionType** matched_inst) { \ + return OptConvert(match::Op(matched_inst).WithOpcode(HloOpcode::k##NAME)); \ + } \ + \ + template \ + inline auto NAME(Arg&& arg) { \ + return OptConvert(match::Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) { \ + return OptConvert(match::Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg))); \ + } +XLA_FUZZY_UNOP_PATTERN(Tanh) +XLA_FUZZY_UNOP_PATTERN(Exp) +XLA_FUZZY_UNOP_PATTERN(Broadcast) +#undef XLA_FUZZY_UNOP_PATTERN + +#define XLA_FUZZY_BINOP_PATTERN(NAME) \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \ + return OptConvert(match::Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))); \ + } \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \ + return OptConvert(match::Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs))); \ + } +XLA_FUZZY_BINOP_PATTERN(Dot) +XLA_FUZZY_BINOP_PATTERN(Divide) +XLA_FUZZY_BINOP_PATTERN(Subtract) +XLA_FUZZY_BINOP_PATTERN(Multiply) +// Currently we only use binary matcher for reduce. +XLA_FUZZY_BINOP_PATTERN(Reduce) +#undef XLA_FUZZY_BINOP_PATTERN + +#define XLA_FUZZY_TERNOP_PATTERN(NAME) \ + template \ + inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { \ + return OptConvert(match::Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ + Arg1&& arg1, Arg2&& arg2) { \ + return OptConvert(match::Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg0)) \ + .WithOperand(1, std::forward(arg1)) \ + .WithOperand(2, std::forward(arg2))); \ + } +XLA_FUZZY_TERNOP_PATTERN(Select); +#undef XLA_FUZZY_TERNOP_PATTERN + +} // namespace fm + +} // namespace xla + +#endif // XLA_SERVICE_FUZZY_MATCHER_H_ diff --git a/third_party/xla/xla/service/fuzzy_matcher_test.cc b/third_party/xla/xla/service/fuzzy_matcher_test.cc new file mode 100644 index 00000000000000..ac97d13233aa52 --- /dev/null +++ b/third_party/xla/xla/service/fuzzy_matcher_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/fuzzy_matcher.h" + +#include +#include "xla/service/pattern_matcher.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +using FuzzyMatcherTest = HloTestBase; + +TEST_F(FuzzyMatcherTest, IgnoreConvert) { + constexpr char kModuleStr[] = R"( + HloModule test_module + ENTRY test { + x = f16[8,3] parameter(0) + y = f16[8,3] parameter(1) + div = f16[8,3] divide(x, y) + ROOT convert = f32[8,3] convert(div) + })"; + TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, + ParseAndReturnVerifiedModule(kModuleStr)); + auto* root = hlo_module->entry_computation()->root_instruction(); + EXPECT_TRUE( + Match(root, fm::Divide(match::Parameter(0), match::Parameter(1)))); +} + +} // namespace + +} // namespace xla diff --git a/third_party/xla/xla/service/generic_transfer_manager_test.cc b/third_party/xla/xla/service/generic_transfer_manager_test.cc index 41ea92d46a0385..eb8cb7afa85004 100644 --- a/third_party/xla/xla/service/generic_transfer_manager_test.cc +++ b/third_party/xla/xla/service/generic_transfer_manager_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/literal_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index b1675f63aef9b9..2bf4315a0c9cf5 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -36,7 +36,6 @@ load("//xla/tests:build_defs.bzl", "xla_test") load( "//xla/tsl:tsl.bzl", "if_google", - "if_oss", "internal_visibility", "tsl_copts", "tsl_gpu_library", @@ -65,6 +64,8 @@ filegroup( ]), ) +exports_files(srcs = ["gpu_compiler_test_autotune_db.textproto"]) + tf_proto_library( name = "backend_configs", srcs = ["backend_configs.proto"], @@ -168,6 +169,7 @@ xla_test( "//xla/stream_executor/gpu:gpu_types_header", "//xla/tests:client_library_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -175,7 +177,6 @@ xla_test( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ] + if_cuda_is_configured([ @@ -312,7 +313,6 @@ cc_library( ":execution_stream_assignment", ":gpu_asm_opts_util", ":gpu_conv_runner", - ":gpu_fused_mha_runner", ":gpu_norm_runner", ":hlo_fusion_analysis", ":ir_emission_utils", @@ -355,9 +355,9 @@ cc_library( "//xla/service/gpu/runtime:conditional_thunk", "//xla/service/gpu/runtime:convolution_thunk", "//xla/service/gpu/runtime:copy_thunk", + "//xla/service/gpu/runtime:cudnn_thunk", "//xla/service/gpu/runtime:custom_call_thunk", "//xla/service/gpu/runtime:fft_thunk", - "//xla/service/gpu/runtime:fused_mha_thunk", "//xla/service/gpu/runtime:gemm_thunk", "//xla/service/gpu/runtime:gpublas_lt_matmul_thunk", "//xla/service/gpu/runtime:infeed_thunk", @@ -483,135 +483,6 @@ cc_library( ], ) -cc_library( - name = "gemm_fusion_autotuner", - srcs = if_cuda_is_configured(["gemm_fusion_autotuner.cc"]), - hdrs = if_cuda_is_configured(["gemm_fusion_autotuner.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = if_cuda_is_configured([ - ":autotuner_compile_util", - ":autotuner_util", - ":backend_configs_cc", - ":buffer_comparator", - ":gemm_rewriter", - ":gpu_float_support", - ":gpu_fusible", - ":instruction_fusion", - ":ir_emission_utils", - ":matmul_utils", - ":split_k_gemm_rewriter", - ":stream_executor_util", - ":cudnn_fusion_compiler", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_config_cuda//cuda:cuda_headers", - "//xla:autotuning_proto_cc", - "//xla:shape_util", - "//xla:status_macros", - "//xla/tools:hlo_decomposer_lib", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:algorithm_util", - "//xla/service:dump", - "//xla/service:executable", - "//xla/service:float_normalization", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/service:shaped_buffer", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory", - "//xla/stream_executor", - "//xla/stream_executor/gpu:redzone_allocator", - "@local_tsl//tsl/lib/core:bits", - "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - "//xla/tsl/util/proto:proto_utils", - "//xla/service/gpu:hlo_traversal", - ":fusion_wrapper", - ":priority_fusion", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/stream_executor:stream_executor_memory_allocator", - "@local_tsl//tsl/platform:path", - ]), -) - -xla_test( - name = "gemm_fusion_autotuner_test", - srcs = if_cuda_is_configured(["gemm_fusion_autotuner_test.cc"]), - backend_tags = {"gpu": [ - "requires-gpu-sm80", - ]}, - backends = [ - "gpu", - ], - tags = [ - "nomac", - ], - deps = [ - ":autotuner_util", - ":backend_configs_cc", - ":gemm_fusion", - ":gemm_fusion_autotuner", - ":gemm_rewriter", - ":ir_emission_utils", - ":matmul_utils", - "//xla:autotuning_proto_cc", - "//xla:error_spec", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:call_inliner", - "//xla/service:dump", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass_pipeline", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_description_proto_cc", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "//xla/tools:hlo_decomposer_lib", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]), -) - cc_library( name = "triton_call", srcs = if_gpu_is_configured(["triton_call.cc"]), @@ -784,26 +655,22 @@ cc_library( "//xla/hlo/ir:backend_config", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", - "//xla/service:hlo_parser", - "//xla/service/llvm_ir:buffer_assignment_util", "//xla/service/llvm_ir:llvm_type_conversion_util", "//xla/service/llvm_ir:llvm_util", - "//xla/translate/mhlo_to_hlo:location_exporter", - "//xla/translate/mhlo_to_hlo:type_to_shape", + "//xla/stream_executor:device_description", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", "@llvm-project//llvm:TargetParser", - "@local_tsl//tsl/lib/strings:proto_serialization", - "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", ], ) @@ -817,13 +684,14 @@ xla_cc_test( ":ir_emission_utils", "//xla:literal", "//xla:literal_util", + "//xla:shape_util", "//xla:types", - "//xla:util", "//xla/hlo/ir:backend_config", "//xla/hlo/ir:hlo", "//xla/service:buffer_assignment", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", @@ -835,6 +703,7 @@ cc_library( name = "reduction_utils", srcs = ["reduction_utils.cc"], hdrs = ["reduction_utils.h"], + compatible_with = get_compatible_with_portable(), local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":ir_emission_utils", @@ -915,45 +784,6 @@ build_cub_sort_kernels( ]), ) -cc_library( - name = "gemm_rewriter", - srcs = ["gemm_rewriter.cc"], - hdrs = ["gemm_rewriter.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":ir_emission_utils", - ":matmul_utils", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/evaluator:hlo_evaluator", - "//xla/hlo/ir:hlo", - "//xla/service:algorithm_util", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_description", - "//xla/stream_executor/gpu:gpu_blas_lt", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", - ], -) - cc_library( name = "triton_tiling_propagation", srcs = ["triton_tiling_propagation.cc"], @@ -1019,9 +849,9 @@ xla_cc_test( name = "triton_fusion_analysis_test", srcs = ["triton_fusion_analysis_test.cc"], deps = [ - ":gemm_fusion", ":triton_fusion_analysis", "//xla/hlo/ir:hlo", + "//xla/service/gpu/transforms:gemm_fusion", "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", @@ -1032,100 +862,6 @@ xla_cc_test( ], ) -cc_library( - name = "gemm_fusion", - srcs = ["gemm_fusion.cc"], - hdrs = ["gemm_fusion.h"], - deps = [ - ":backend_configs_cc", - ":cublas_padding_requirements", - ":ir_emission_utils", - ":matmul_utils", - ":triton_fusion_analysis", - ":triton_tiling_propagation", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:tensor_float_32_utils", - ], -) - -xla_cc_test( - name = "gemm_fusion_test", - srcs = ["gemm_fusion_test.cc"], - deps = [ - ":cublas_padding_requirements", - ":gemm_fusion", - ":triton_fusion_analysis", - "//xla:autotuning_proto_cc", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/stream_executor:device_description", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gemv_rewriter", - srcs = ["gemv_rewriter.cc"], - hdrs = ["gemv_rewriter.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gemv_rewriter_test", - srcs = ["gemv_rewriter_test.cc"], - deps = [ - ":gemv_rewriter", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "split_k_gemm_rewriter", srcs = ["split_k_gemm_rewriter.cc"], @@ -1176,10 +912,10 @@ xla_cc_test( "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", @@ -1187,230 +923,20 @@ xla_cc_test( ) cc_library( - name = "softmax_rewriter_triton", - srcs = ["softmax_rewriter_triton.cc"], - hdrs = ["softmax_rewriter_triton.h"], + name = "matmul_utils", + srcs = ["matmul_utils.cc"], + hdrs = ["matmul_utils.h"], + compatible_with = get_compatible_with_portable(), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ ":backend_configs_cc", - ":hlo_traversal", ":ir_emission_utils", + "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/service/gpu/model:fusion_analysis_cache", - "//xla/service/gpu/model:gpu_indexing_performance_model", - "//xla/service/gpu/model:symbolic_tile_analysis", - "//xla/service/gpu/model:tiled_hlo_computation", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gemm_algorithm_picker", - srcs = if_gpu_is_configured(["gemm_algorithm_picker.cc"]), - hdrs = if_gpu_is_configured(["gemm_algorithm_picker.h"]), - deps = if_gpu_is_configured([ - ":backend_configs_cc", - ":buffer_comparator", - ":cublas_cudnn", - ":gpu_asm_opts_util", - ":gpu_conv_runner", - ":ir_emission_utils", - ":matmul_utils", - ":stream_executor_util", - ":variant_visitor", - ":autotuner_compile_util", - ":autotuner_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "//xla:autotune_results_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla:status_macros", - "//xla/stream_executor", - "//xla/stream_executor:blas", - "//xla/stream_executor/gpu:gpu_blas_lt", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor/gpu:redzone_allocator", - "//xla/tsl/util/proto:proto_utils", - "//xla:util", - "//xla:autotuning_proto_cc", - "//xla:shape_util", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - ]) + ["@com_google_absl//absl/status"], -) - -cc_library( - name = "autotuner_util", - srcs = if_gpu_is_configured(["autotuner_util.cc"]), - hdrs = if_gpu_is_configured(["autotuner_util.h"]), - deps = if_gpu_is_configured([ - ":gpu_asm_opts_util", - ":stream_executor_util", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Core", - "@llvm-project//llvm:Support", - "//xla:autotune_results_proto_cc", - "//xla:autotuning_proto_cc", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:compilation_environments", - "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor", - "//xla/stream_executor/gpu:redzone_allocator", - "@local_tsl//tsl/platform:base64", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:statusor", - ]), -) - -# We need a separate target, as runtime executable cannot depend on compilation -# pipeline. -cc_library( - name = "autotuner_compile_util", - srcs = if_gpu_is_configured(["autotuner_compile_util.cc"]), - hdrs = if_gpu_is_configured(["autotuner_compile_util.h"]), - deps = if_gpu_is_configured([ - ":autotuner_util", - ":gpu_executable_run_options", - ":ir_emission_utils", - "@com_google_absl//absl/functional:any_invocable", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "//xla/hlo/ir:hlo", - "//xla/service:compiler", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service:maybe_owning_device_memory", - "//xla/service:shaped_buffer", - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_stream_header", - "//xla/stream_executor/gpu:redzone_allocator", - "//xla:executable_run_options", - "//xla:shape_util", - "//xla:util", - "//xla:xla_proto_cc", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ]) + ["@com_google_absl//absl/status"], -) - -xla_test( - name = "autotuner_compile_util_test", - srcs = if_gpu_is_configured(["autotuner_compile_util_test.cc"]), - backends = ["gpu"], - deps = if_gpu_is_configured( - [ - ":autotuner_compile_util", - ":autotuner_util", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "//xla/hlo/ir:hlo", - "//xla/service:platform_util", - "//xla/stream_executor:platform", - "//xla/tests:hlo_test_base", - "@local_tsl//tsl/platform:statusor", - ], - if_false = [ - "@com_google_googletest//:gtest_main", # b/317293391 - ], - ), -) - -xla_test( - name = "gemm_algorithm_picker_test", - srcs = if_gpu_is_configured(["gemm_algorithm_picker_test.cc"]), - backends = [ - "gpu_v100", - "gpu_amd_any", - ], - deps = [ - ":autotuner_util", - ":backend_configs_cc", - ":gemm_algorithm_picker", - ":gemm_rewriter", - ":variant_visitor", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:platform_util", - "//xla/stream_executor:device_description", - "//xla/stream_executor:platform", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - "@local_tsl//tsl/protobuf:dnn_proto_cc", - ], -) - -cc_library( - name = "matmul_utils", - srcs = ["matmul_utils.cc"], - hdrs = ["matmul_utils.h"], - compatible_with = get_compatible_with_portable(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = [ - ":backend_configs_cc", - ":ir_emission_utils", - "//xla:autotuning_proto_cc", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", + "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", @@ -1467,4616 +993,1906 @@ xla_cc_test( ) cc_library( - name = "dot_dimension_sorter", - srcs = ["dot_dimension_sorter.cc"], - hdrs = ["dot_dimension_sorter.h"], + name = "gpu_conv_runner", + srcs = ["gpu_conv_runner.cc"], + hdrs = ["gpu_conv_runner.h"], deps = [ - "//xla:permutation_util", + ":backend_configs_cc", + ":cublas_cudnn", + ":stream_executor_util", "//xla:shape_util", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "//xla/stream_executor:lazy_op_runner", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_test( - name = "dot_dimension_sorter_test", - srcs = ["dot_dimension_sorter_test.cc"], - backends = ["gpu"], - deps = [ - ":dot_dimension_sorter", - "//xla:error_spec", - "//xla/hlo/ir:hlo", - "//xla/service/gpu/tests:gpu_codegen_test", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_googletest//:gtest", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "dot_sparsity_rewriter", - srcs = ["dot_sparsity_rewriter.cc"], - hdrs = ["dot_sparsity_rewriter.h"], + name = "gpu_norm_runner", + srcs = ["gpu_norm_runner.cc"], + hdrs = ["gpu_norm_runner.h"], deps = [ + ":backend_configs_cc", + ":cublas_cudnn", + ":stream_executor_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "//xla/stream_executor:lazy_op_runner", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "dot_sparsity_rewriter_test", - srcs = ["dot_sparsity_rewriter_test.cc"], - deps = [ - ":dot_sparsity_rewriter", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:statusor", - ], + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]), ) cc_library( - name = "gpu_async_collective_annotator", - srcs = ["gpu_async_collective_annotator.cc"], - hdrs = ["gpu_async_collective_annotator.h"], + name = "cusolver_context", + srcs = if_gpu_is_configured(["cusolver_context.cc"]), + hdrs = if_gpu_is_configured(["cusolver_context.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ - ":backend_configs_cc", + "//xla:comparison_util", + "//xla:types", "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", + "//xla:xla_data_proto_cc", + "//xla/stream_executor", + "//xla/stream_executor:blas", + "//xla/stream_executor/gpu:gpu_stream", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_async_collective_annotator_test", - srcs = ["gpu_async_collective_annotator_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_async_collective_annotator", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/tests:hlo_test_base", - "//xla/tests:test_macros_header", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "//xla/tsl/cuda:cusolver", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + "//xla/stream_executor/rocm:rocblas_wrapper", + "//xla/stream_executor/rocm:rocsolver_wrapper", + "//xla/stream_executor/rocm:hipsolver_wrapper", + ]), +) + +tf_proto_library( + name = "fusion_process_dump_proto", + srcs = ["fusion_process_dump.proto"], + cc_api_version = 2, + protodeps = [ + "//xla/stream_executor:device_description_proto", ], ) cc_library( - name = "gpu_convert_async_collectives_to_sync", - srcs = ["gpu_convert_async_collectives_to_sync.cc"], - hdrs = ["gpu_convert_async_collectives_to_sync.h"], + name = "fusion_process_dump", + srcs = ["fusion_process_dump.cc"], + hdrs = ["fusion_process_dump.h"], deps = [ - ":backend_configs_cc", + ":fusion_process_dump_proto_cc", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:convert_async_collectives_to_sync", + "//xla/service:hlo_graph_dumper", + "//xla/stream_executor:device_description", + "//xla/tools:hlo_module_loader", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "gpu_convert_async_collectives_to_sync_test", - srcs = ["gpu_convert_async_collectives_to_sync_test.cc"], + name = "fusion_process_dump_test", + srcs = ["fusion_process_dump_test.cc"], deps = [ - ":backend_configs_cc", - ":gpu_convert_async_collectives_to_sync", - "//xla:util", + ":fusion_process_dump", + ":fusion_process_dump_proto_cc", + ":gpu_device_info_for_tests", + "//xla:test", "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "conv_algorithm_picker", - srcs = if_gpu_is_configured(["conv_algorithm_picker.cc"]), - hdrs = if_gpu_is_configured(["conv_algorithm_picker.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = if_gpu_is_configured([ - ":autotuner_compile_util", - ":autotuner_util", - ":backend_configs_cc", - ":buffer_comparator", + name = "cudnn_support_utils", + srcs = ["cudnn_support_utils.cc"], + hdrs = ["cudnn_support_utils.h"], + deps = [ ":cublas_cudnn", - ":gpu_asm_opts_util", - ":gpu_autotuning_proto_cc", - ":gpu_conv_runner", - ":hlo_algorithm_denylist", - ":stream_executor_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@local_config_cuda//cuda:cudnn_header", - "//xla:autotune_results_proto_cc", - "//xla:autotuning_proto_cc", - "//xla:debug_options_flags", - "//xla:literal_util", "//xla:shape_util", "//xla:util", - "//xla:xla_data_proto_cc", + "//xla:window_util", "//xla/hlo/ir:hlo", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/service:slow_operation_alarm", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/stream_executor:numeric_options", - "//xla/stream_executor:scratch_allocator", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:lazy_op_runner", - "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/stream_executor/gpu:redzone_allocator", - "//xla/stream_executor/rocm:rocm_platform_id", - "@local_tsl//tsl/platform:errors", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", - "//xla/tsl/util:env_var", "@local_tsl//tsl/platform:statusor", - "//xla/tsl/util/proto:proto_utils", - "@local_tsl//tsl/platform:status", - ]), + ], ) -xla_test( - name = "conv_algorithm_picker_test", - srcs = if_gpu_is_configured(["conv_algorithm_picker_test.cc"]), - backends = [ - "gpu_v100", - "gpu_amd_any", - ], - tags = [ - "noasan", - "nomsan", - ], +xla_cc_test( + name = "cudnn_support_utils_test", + srcs = ["cudnn_support_utils_test.cc"], deps = [ - ":autotuner_util", - ":conv_algorithm_picker", - ":gpu_conv_rewriter", - ":stream_executor_util", - "//xla:debug_options_flags", + ":cudnn_support_utils", + "//xla:shape_util", + "//xla:test", + "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:platform_util", - "//xla/service:tuple_simplifier", + "//xla/service:hlo_parser", "//xla/stream_executor:device_description", - "//xla/stream_executor:platform", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "gpu_conv_runner", - srcs = ["gpu_conv_runner.cc"], - hdrs = ["gpu_conv_runner.h"], + name = "cublas_padding_requirements", + srcs = ["cublas_padding_requirements.cc"], + hdrs = ["cublas_padding_requirements.h"], deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":stream_executor_util", + ":variant_visitor", "//xla:shape_util", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", + "//xla/stream_executor:device_description", + ], +) + +tf_proto_library( + name = "executable_proto", + srcs = ["executable.proto"], + cc_api_version = 2, + protodeps = [ + "//xla/service:hlo_proto", + "//xla:xla_proto", ], ) cc_library( - name = "gpu_norm_runner", - srcs = ["gpu_norm_runner.cc"], - hdrs = ["gpu_norm_runner.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":stream_executor_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]), + name = "target_constants", + hdrs = ["target_constants.h"], ) cc_library( - name = "gpu_fused_mha_runner", - srcs = ["gpu_fused_mha_runner.cc"], - hdrs = ["gpu_fused_mha_runner.h"], + name = "gpu_transfer_manager", + srcs = ["gpu_transfer_manager.cc"], + hdrs = ["gpu_transfer_manager.h"], deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":stream_executor_util", + ":io_feed_manager", + ":target_constants", + "//xla:literal", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/service:compiler", + "//xla/service:generic_transfer_manager", + "//xla/service:shaped_buffer", + "//xla/service:transfer_manager", "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", + "//xla/stream_executor:event", + "//xla/stream_executor:memory_allocation", + "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/rocm:rocm_platform_id", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", - "@eigen_archive//:eigen3", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Core", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", "@local_tsl//tsl/platform:statusor", ], + alwayslink = True, # Contains per-platform transfer manager registration ) cc_library( - name = "gpu_conv_rewriter", - srcs = ["gpu_conv_rewriter.cc"], - hdrs = ["gpu_conv_rewriter.h"], + name = "gpu_float_support", + srcs = ["gpu_float_support.cc"], + hdrs = ["gpu_float_support.h"], deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:util", - "//xla:window_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", + "//xla/service:float_support", + "//xla/service/gpu/fusions/triton:triton_support", "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", + "@com_google_absl//absl/log:check", ], ) cc_library( - name = "gpu_sort_rewriter", - srcs = if_gpu_is_configured( - ["gpu_sort_rewriter.cc"], - ["gpu_sort_rewriter_stub.cc"], - ), - hdrs = ["gpu_sort_rewriter.h"], + name = "compile_module_to_llvm_ir", + srcs = [ + "compile_module_to_llvm_ir.cc", + ], + hdrs = [ + "compile_module_to_llvm_ir.h", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ - ":cublas_cudnn", - "//xla:comparison_util", + ":executable_proto_cc", + ":execution_stream_assignment", + ":gpu_constants", + ":gpu_executable", + ":gpu_memory_space_assignment", + ":ir_emitter_context", + ":ir_emitter_unnested", + ":metrics", + ":runtime_intrinsics", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:stable_sort_expander", - "//xla/service/gpu/runtime:cub_sort_thunk", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/service:buffer_assignment", + "//xla/service:buffer_value", + "//xla/service:dump", + "//xla/service:hlo_dataflow_analysis", + "//xla/service:hlo_ordering", + "//xla/service:hlo_proto_cc", + "//xla/service:logical_buffer", + "//xla/service/gpu/runtime:sequential_thunk", + "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor/rocm:rocm_platform_id", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:AsmParser", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", ], ) cc_library( - name = "move_copy_to_users", - srcs = ["move_copy_to_users.cc"], - hdrs = ["move_copy_to_users.h"], + name = "fusion_pipeline", + srcs = ["fusion_pipeline.cc"], + hdrs = ["fusion_pipeline.h"], deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", + "//xla:xla_proto_cc", + "//xla/service:cpu_gpu_shape_verifier", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_cse", + "//xla/service:hlo_dce", "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "move_copy_to_users_test", - srcs = ["move_copy_to_users_test.cc"], - deps = [ - ":move_copy_to_users", + "//xla/service:hlo_pass_pipeline", + "//xla/service:hlo_verifier", "//xla/service:layout_assignment", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "gpu_conv_rewriter_test", - srcs = ["gpu_conv_rewriter_test.cc"], - deps = [ - ":cublas_cudnn", - ":gpu_conv_rewriter", - "//xla:array4d", - "//xla:literal_util", - "//xla:protobuf_util", - "//xla:shape_util", - "//xla:test", - "//xla:test_helpers", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:shape_inference", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/transforms:fusion_merger", + "//xla/service/gpu/transforms:horizontal_input_fusion", + "//xla/service/gpu/transforms:horizontal_loop_fusion", + "//xla/service/gpu/transforms:instruction_fusion", + "//xla/service/gpu/transforms:multi_output_fusion", + "//xla/service/gpu/transforms:priority_fusion", + "//xla/service/gpu/transforms:variadic_op_splitter", "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:env", ], ) -xla_test( - name = "gpu_sort_rewriter_test", - srcs = if_cuda_is_configured(["gpu_sort_rewriter_test.cc"]), - backends = ["gpu"], - tags = ["no_oss"], +cc_library( + name = "prepare_hlo_for_ir_emitting_pipeline", + srcs = ["prepare_hlo_for_ir_emitting_pipeline.cc"], + hdrs = ["prepare_hlo_for_ir_emitting_pipeline.h"], deps = [ - ":cublas_cudnn", - ":gpu_sort_rewriter", - "//xla:error_spec", - "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + "//xla/service:copy_insertion", + "//xla/service:cpu_gpu_shape_verifier", + "//xla/service:hlo_dataflow_analysis", + "//xla/service:hlo_dce", + "//xla/service:hlo_pass_pipeline", + "//xla/service:hlo_verifier", + "//xla/service:layout_assignment", + "//xla/service:loop_schedule_linearizer", + "//xla/service/gpu/transforms:alias_passthrough_params", + "//xla/service/gpu/transforms:copy_fusion", + "//xla/service/gpu/transforms:horizontal_loop_fusion", + "//xla/service/gpu/transforms:sanitize_constant_names", ], ) cc_library( - name = "cusolver_context", - srcs = if_gpu_is_configured(["cusolver_context.cc"]), - hdrs = if_gpu_is_configured(["cusolver_context.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", + name = "gpu_compiler", + srcs = if_gpu_is_configured([ + "gpu_compiler.cc", ]), - deps = [ - "//xla:comparison_util", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/stream_executor", - "//xla/stream_executor:blas", - "//xla/stream_executor/gpu:gpu_stream", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "//xla/tsl/cuda:cusolver", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - "//xla/stream_executor/rocm:rocblas_wrapper", - "//xla/stream_executor/rocm:rocsolver_wrapper", - "//xla/stream_executor/rocm:hipsolver_wrapper", + hdrs = if_gpu_is_configured([ + "gpu_compiler.h", ]), -) - -cc_library( - name = "cusolver_rewriter", - srcs = if_gpu_is_configured(["cusolver_rewriter.cc"]), - hdrs = if_gpu_is_configured(["cusolver_rewriter.h"]), deps = if_gpu_is_configured([ - ":cusolver_context", + # go/keep-sorted start prefix_order=":,, + ":buffer_sharing", + ":compile_module_to_llvm_ir", + ":conv_layout_normalization", + ":executable_proto_cc", + ":execution_stream_assignment", + ":fusion_pipeline", + ":gpu_constants", + ":gpu_executable", + ":gpu_float_support", + ":gpu_hlo_schedule", + ":gpu_latency_hiding_scheduler", + ":gpu_p2p_pipeliner", + ":gpu_spmd_pipeline", + ":hlo_fusion_stats", ":ir_emission_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "//xla:comparison_util", - "//xla:literal", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/stream_executor", - "//xla/stream_executor:blas", - "//xla/stream_executor:device_memory_allocator", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ]), -) - -cc_library( - name = "instruction_fusion", - srcs = ["instruction_fusion.cc"], - hdrs = ["instruction_fusion.h"], - deps = [ - ":gpu_fusible", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:fusion_node_indexing_evaluation", - "//xla/service:fusion_queue", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/stream_executor:device_description", + ":ir_emitter", + ":ir_emitter_context", + ":ir_emitter_unnested", + ":kernel_reuse_cache", + ":matmul_utils", + ":metrics", + ":prepare_hlo_for_ir_emitting_pipeline", + ":reduction_utils", + ":runtime_intrinsics", + ":stream_executor_util", + "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "instruction_fusion_test", - srcs = ["instruction_fusion_test.cc"], - tags = [ - "nomsan", - "not_run:arm", - ], - deps = [ - ":gpu_device_info_for_tests", - ":gpu_fusible", - ":instruction_fusion", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -tf_proto_library( - name = "fusion_process_dump_proto", - srcs = ["fusion_process_dump.proto"], - cc_api_version = 2, - protodeps = [ - "//xla/stream_executor:device_description_proto", - ], -) - -cc_library( - name = "fusion_process_dump", - srcs = ["fusion_process_dump.cc"], - hdrs = ["fusion_process_dump.h"], - deps = [ - ":fusion_process_dump_proto_cc", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_graph_dumper", - "//xla/stream_executor:device_description", - "//xla/tools:hlo_module_loader", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "fusion_process_dump_test", - srcs = ["fusion_process_dump_test.cc"], - deps = [ - ":fusion_process_dump", - ":fusion_process_dump_proto_cc", - ":gpu_device_info_for_tests", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "priority_fusion", - srcs = ["priority_fusion.cc"], - hdrs = ["priority_fusion.h"], - deps = [ - ":backend_configs_cc", - ":fusion_process_dump_proto_cc", - ":gpu_fusible", - ":hlo_fusion_analysis", - ":hlo_traversal", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:dump", - "//xla/service:fusion_queue", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_graph_dumper", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/service/gpu/model:fusion_analysis_cache", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/service/gpu/model:gpu_performance_model", - "//xla/service/gpu/model:gpu_performance_model_base", - "//xla/service/gpu/model:symbolic_tile_analysis", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "priority_fusion_test", - srcs = ["priority_fusion_test.cc"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = ["no_pip"], - deps = [ - ":backend_configs_cc", - ":gpu_device_info_for_tests", - ":gpu_fusible", - ":hlo_fusion_analysis", - ":priority_fusion", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_cost_analysis", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "multi_output_fusion", - srcs = ["multi_output_fusion.cc"], - hdrs = ["multi_output_fusion.h"], - deps = [ - ":gpu_fusible", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_dfs_reachability", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_graph_dumper", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/service/gpu/model:gpu_performance_model", - "//xla/service/gpu/model:gpu_performance_model_base", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "multi_output_fusion_test", - srcs = ["multi_output_fusion_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":gpu_device_info_for_tests", - ":gpu_fusible", - ":multi_output_fusion", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_cost_analysis", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "rename_fusions", - srcs = ["rename_fusions.cc"], - hdrs = ["rename_fusions.h"], - deps = [ - ":hlo_traversal", - ":ir_emission_utils", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "rename_fusions_test", - srcs = ["rename_fusions_test.cc"], - deps = [ - ":rename_fusions", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - ], -) - -xla_cc_test( - name = "softmax_rewriter_triton_test", - srcs = ["softmax_rewriter_triton_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_device_info_for_tests", - ":softmax_rewriter_triton", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:instruction_fusion", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_sanitize_constant_names", - srcs = ["gpu_sanitize_constant_names.cc"], - hdrs = ["gpu_sanitize_constant_names.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:name_uniquer", - "//xla/service/llvm_ir:buffer_assignment_util", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "gpu_sanitize_constant_names_test", - srcs = ["gpu_sanitize_constant_names_test.cc"], - deps = [ - ":gpu_sanitize_constant_names", - "//xla:literal_util", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - -cc_library( - name = "fusion_merger", - srcs = ["fusion_merger.cc"], - hdrs = ["fusion_merger.h"], - deps = [ - ":gpu_fusible", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_graph_dumper", - "//xla/service:hlo_pass", - "//xla/service:instruction_fusion", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/service/gpu/model:gpu_performance_model", - "//xla/service/gpu/model:gpu_performance_model_base", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "fusion_merger_test", - srcs = ["fusion_merger_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":fusion_merger", - ":gpu_device_info_for_tests", - ":gpu_fusible", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_cost_analysis", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "gpu_conv_padding_legalization", - srcs = ["gpu_conv_padding_legalization.cc"], - hdrs = ["gpu_conv_padding_legalization.h"], - deps = [ - ":cublas_cudnn", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:shape_inference", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_conv_padding_legalization_test", - srcs = ["gpu_conv_padding_legalization_test.cc"], - deps = [ - ":cublas_cudnn", - ":gpu_conv_padding_legalization", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@local_tsl//tsl/platform:test", - ], -) - -cc_library( - name = "cudnn_support_utils", - srcs = ["cudnn_support_utils.cc"], - hdrs = ["cudnn_support_utils.h"], - deps = [ - ":cublas_cudnn", - "//xla:shape_util", - "//xla:util", - "//xla:window_util", - "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "cudnn_support_utils_test", - srcs = ["cudnn_support_utils_test.cc"], - deps = [ - ":cudnn_support_utils", - "//xla:shape_util", - "//xla:test", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "cudnn_pad_for_convolutions", - srcs = ["cudnn_pad_for_convolutions.cc"], - hdrs = ["cudnn_pad_for_convolutions.h"], - deps = [ - ":cublas_cudnn", - ":cudnn_support_utils", - ":stream_executor_util", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:bind_front", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "cudnn_pad_for_convolutions_test", - srcs = ["cudnn_pad_for_convolutions_test.cc"], - deps = [ - ":cublas_cudnn", - ":cudnn_pad_for_convolutions", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "cudnn_vectorize_convolutions", - srcs = ["cudnn_vectorize_convolutions.cc"], - hdrs = ["cudnn_vectorize_convolutions.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":cudnn_support_utils", - ":stream_executor_util", - "//xla:shape_util", - "//xla:util", - "//xla/client:xla_builder", - "//xla/client:xla_computation", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "cudnn_vectorize_convolutions_test", - srcs = ["cudnn_vectorize_convolutions_test.cc"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":cudnn_vectorize_convolutions", - "//xla:util", - "//xla/service:call_inliner", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/stream_executor:device_description", - "//xla/stream_executor:dnn", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "cudnn_simplify_padding", - srcs = ["cudnn_simplify_padding.cc"], - hdrs = ["cudnn_simplify_padding.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - "//xla:literal", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "cudnn_simplify_padding_test", - srcs = ["cudnn_simplify_padding_test.cc"], - deps = [ - ":cudnn_pad_for_convolutions", - ":cudnn_simplify_padding", - ":cudnn_vectorize_convolutions", - "//xla:literal", - "//xla:util", - "//xla/service:algebraic_simplifier", - "//xla/service:call_inliner", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:reshape_mover", - "//xla/service:tuple_simplifier", - "//xla/stream_executor:device_description", - "//xla/stream_executor:dnn", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "cublas_pad_for_gemms", - srcs = ["cublas_pad_for_gemms.cc"], - hdrs = ["cublas_pad_for_gemms.h"], - deps = [ - ":gemm_fusion", - ":ir_emission_utils", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "cublas_padding_requirements", - srcs = ["cublas_padding_requirements.cc"], - hdrs = ["cublas_padding_requirements.h"], - deps = [ - ":variant_visitor", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_description", - ], -) - -xla_cc_test( - name = "cublas_pad_for_gemms_test", - srcs = ["cublas_pad_for_gemms_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":cublas_pad_for_gemms", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "cudnn_fusion_compiler", - srcs = if_cuda_is_configured(["cudnn_fusion_compiler.cc"]), - hdrs = if_cuda_is_configured(["cudnn_fusion_compiler.h"]), - deps = if_cuda_is_configured([ - ":backend_configs_cc", - ":cudnn_support_utils", - ":ir_emission_utils", - ":kernel_reuse_cache", - ":matmul_utils", - ":triton_fusion_analysis", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_config_cuda//cuda:cudnn_header", - "//xla:shape_util", - "//xla:comparison_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_pass", - "//xla/stream_executor:dnn", - "//xla/stream_executor:stream_executor_h", - "//xla/service:dump", - "//xla/stream_executor/cuda:cudnn_frontend_helpers", - "//xla/stream_executor/cuda:cudnn_plugin", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ]), -) - -cc_library( - name = "cudnn_workspace_rewriter", - srcs = if_cuda_is_configured(["cudnn_workspace_rewriter.cc"]), - hdrs = if_cuda_is_configured(["cudnn_workspace_rewriter.h"]), - deps = if_cuda_is_configured([ - ":backend_configs_cc", - ":ir_emission_utils", - ":gpu_fused_mha_runner", - ":cublas_cudnn", - ":stream_executor_util", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_config_cuda//cuda:cudnn_header", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/stream_executor/cuda:cudnn_frontend_helpers", - "//xla/stream_executor/cuda:cudnn_plugin", - "//xla/stream_executor:dnn", - "//xla/stream_executor:stream_executor_h", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - "//xla:status_macros", - ]), -) - -tf_proto_library( - name = "executable_proto", - srcs = ["executable.proto"], - cc_api_version = 2, - protodeps = [ - "//xla/service:hlo_proto", - "//xla:xla_proto", - ], -) - -cc_library( - name = "target_constants", - hdrs = ["target_constants.h"], -) - -cc_library( - name = "gpu_transfer_manager", - srcs = ["gpu_transfer_manager.cc"], - hdrs = ["gpu_transfer_manager.h"], - deps = [ - ":io_feed_manager", - ":target_constants", - "//xla:literal", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:compiler", - "//xla/service:generic_transfer_manager", - "//xla/service:shaped_buffer", - "//xla/service:transfer_manager", - "//xla/stream_executor", - "//xla/stream_executor:event", - "//xla/stream_executor:memory_allocation", - "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/stream_executor/rocm:rocm_platform_id", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/synchronization", - "@llvm-project//llvm:Core", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:statusor", - ], - alwayslink = True, # Contains per-platform transfer manager registration -) - -cc_library( - name = "gpu_reduce_scatter_creator", - srcs = ["gpu_reduce_scatter_creator.cc"], - hdrs = ["gpu_reduce_scatter_creator.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:collective_opt_utils", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - ], -) - -cc_library( - name = "gpu_all_gather_optimizer", - srcs = ["gpu_all_gather_optimizer.cc"], - hdrs = ["gpu_all_gather_optimizer.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -cc_library( - name = "gpu_float_support", - srcs = ["gpu_float_support.cc"], - hdrs = ["gpu_float_support.h"], - deps = [ - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:float_support", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/log:check", - ], -) - -cc_library( - name = "compile_module_to_llvm_ir", - srcs = [ - "compile_module_to_llvm_ir.cc", - ], - hdrs = [ - "compile_module_to_llvm_ir.h", - ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":executable_proto_cc", - ":execution_stream_assignment", - ":gpu_constants", - ":gpu_executable", - ":gpu_memory_space_assignment", - ":ir_emitter_context", - ":ir_emitter_unnested", - ":metrics", - ":runtime_intrinsics", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service:buffer_value", - "//xla/service:dump", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_ordering", - "//xla/service:hlo_proto_cc", - "//xla/service:logical_buffer", - "//xla/service/gpu/runtime:conditional_thunk", - "//xla/service/gpu/runtime:sequential_thunk", - "//xla/service/gpu/runtime:thunk", - "//xla/service/gpu/runtime:while_thunk", - "//xla/stream_executor", - "//xla/stream_executor:device_description", - "//xla/stream_executor/rocm:rocm_platform_id", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:AsmParser", - "@llvm-project//llvm:TransformUtils", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "command_buffer_scheduling", - srcs = ["command_buffer_scheduling.cc"], - hdrs = ["command_buffer_scheduling.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":hlo_fusion_analysis", - ":hlo_traversal", - ":ir_emission_utils", - ":variant_visitor", - "//xla:shape_util", - "//xla:util", - "//xla/ffi:ffi_api", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "command_buffer_scheduling_test", - srcs = ["command_buffer_scheduling_test.cc"], - deps = [ - ":command_buffer_scheduling", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/stream_executor:device_description", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "custom_kernel_fusion_autotuner", - srcs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.cc"]), - hdrs = if_gpu_is_configured(["custom_kernel_fusion_autotuner.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = if_gpu_is_configured([ - ":autotuner_compile_util", - ":autotuner_util", - ":backend_configs_cc", - ":buffer_comparator", - ":gemm_rewriter", - ":gpu_float_support", - ":gpu_fusible", - ":instruction_fusion", - ":ir_emission_utils", - ":matmul_utils", - ":split_k_gemm_rewriter", - "//xla/service/gpu/kernels:custom_kernel", - "//xla/service/gpu/kernels:custom_kernel_fusion", - ":stream_executor_util", - ":cudnn_fusion_compiler", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_config_cuda//cuda:cuda_headers", - "//xla:autotuning_proto_cc", - "//xla:shape_util", - "//xla:status_macros", - "//xla/tools:hlo_decomposer_lib", - "//xla:statusor", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:algorithm_util", - "//xla/service:dump", - "//xla/service:executable", - "//xla/service:float_normalization", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/service:shaped_buffer", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory", - "//xla/stream_executor", - "//xla/stream_executor/gpu:redzone_allocator", - "@local_tsl//tsl/lib/core:bits", - "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - "//xla/tsl/util/proto:proto_utils", - "//xla/service/gpu:hlo_traversal", - ]) + [ - "//xla/stream_executor:stream_executor_memory_allocator", - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:path", - ], -) - -xla_test( - name = "custom_kernel_fusion_autotuner_test", - srcs = if_cuda_is_configured(["custom_kernel_fusion_autotuner_test.cc"]), - backends = [ - "gpu", - ], - deps = [ - ":autotuner_util", - ":custom_kernel_fusion_autotuner", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass_pipeline", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:test", - ], -) - -cc_library( - name = "custom_kernel_fusion_rewriter", - srcs = ["custom_kernel_fusion_rewriter.cc"], - hdrs = ["custom_kernel_fusion_rewriter.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service/gpu/kernels:custom_fusion_library", - "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "custom_kernel_fusion_rewriter_test", - srcs = ["custom_kernel_fusion_rewriter_test.cc"], - deps = [ - ":custom_kernel_fusion_rewriter", - ":gpu_device_info_for_tests", - "//xla/hlo/ir:hlo", - "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "dynamic_slice_fusion_rewriter", - srcs = ["dynamic_slice_fusion_rewriter.cc"], - hdrs = ["dynamic_slice_fusion_rewriter.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":gpu_constants", - ":hlo_traversal", - ":ir_emission_utils", - "//xla:shape_util", - "//xla:util", - "//xla/ffi:ffi_api", - "//xla/ffi/api:c_api", - "//xla/hlo/ir:hlo", - "//xla/service:custom_call_target_registry", - "//xla/service:hlo_pass", - "//xla/service/gpu/kernels:custom_fusion_library", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "dynamic_slice_fusion_rewriter_test", - srcs = if_cuda_is_configured(["dynamic_slice_fusion_rewriter_test.cc"]), - deps = [ - ":dynamic_slice_fusion_rewriter", - ":gpu_device_info_for_tests", - "//xla:shape_util", - "//xla/client:xla_builder", - "//xla/client/lib:constants", - "//xla/ffi", - "//xla/ffi:ffi_api", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_value", - "//xla/service:custom_call_target_registry", - "//xla/service:executable", - "//xla/service:hlo_memory_scheduler", - "//xla/service:hlo_module_config", - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_types_header", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "fusion_pipeline", - srcs = ["fusion_pipeline.cc"], - hdrs = ["fusion_pipeline.h"], - deps = [ - ":fusion_merger", - ":horizontal_input_fusion", - ":horizontal_loop_fusion", - ":instruction_fusion", - ":multi_output_fusion", - ":priority_fusion", - ":variadic_op_splitter", - "//xla:xla_proto_cc", - "//xla/service:cpu_gpu_shape_verifier", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_cse", - "//xla/service:hlo_dce", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", - "//xla/service:hlo_verifier", - "//xla/service:layout_assignment", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/stream_executor:device_description", - "@local_tsl//tsl/platform:env", - ], -) - -cc_library( - name = "prepare_hlo_for_ir_emitting_pipeline", - srcs = ["prepare_hlo_for_ir_emitting_pipeline.cc"], - hdrs = ["prepare_hlo_for_ir_emitting_pipeline.h"], - deps = [ - ":alias_passthrough_params", - ":copy_fusion", - ":gpu_sanitize_constant_names", - ":horizontal_loop_fusion", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:copy_insertion", - "//xla/service:cpu_gpu_shape_verifier", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_dce", - "//xla/service:hlo_pass_pipeline", - "//xla/service:hlo_verifier", - "//xla/service:layout_assignment", - "//xla/service:loop_schedule_linearizer", - ], -) - -cc_library( - name = "gpu_compiler", - srcs = if_gpu_is_configured([ - "gpu_compiler.cc", - ]), - hdrs = if_gpu_is_configured([ - "gpu_compiler.h", - ]), - deps = if_gpu_is_configured([ - # go/keep-sorted start prefix_order=":,, - ":algorithm_checker", - ":alias_passthrough_params", - ":all_reduce_blueconnect", - ":autotuner_util", - ":buffer_sharing", - ":collective_permute_cycle_decomposer", - ":collective_permute_valid_iteration_annotator", - ":command_buffer_scheduling", - ":compile_module_to_llvm_ir", - ":conv_layout_normalization", - ":copy_fusion", - ":custom_kernel_fusion_autotuner", - ":custom_kernel_fusion_rewriter", - ":dot_dimension_sorter", - ":dot_operand_converter", - ":double_buffer_loop_unrolling", - ":dynamic_slice_fusion_rewriter", - ":executable_proto_cc", - ":execution_stream_assignment", - ":fusion_merger", - ":fusion_pipeline", - ":fusion_wrapper", - ":gemm_broadcast_folding_rewriter", - ":gemm_fusion", - ":gemm_rewriter", - ":gemv_rewriter", - ":gpu_algebraic_simplifier", - ":gpu_all_gather_optimizer", - ":gpu_async_collective_annotator", - ":gpu_constants", - ":gpu_conv_rewriter", - ":gpu_convert_async_collectives_to_sync", - ":gpu_executable", - ":gpu_float_support", - ":gpu_hlo_schedule", - ":gpu_latency_hiding_scheduler", - ":gpu_layout_assignment", - ":gpu_p2p_pipeliner", - ":gpu_reduce_scatter_creator", - ":gpu_sanitize_constant_names", - ":gpu_scatter_expander", - ":gpu_spmd_pipeline", - ":gpu_windowed_einsum_handler", - ":hlo_fusion_stats", - ":horizontal_input_fusion", - ":horizontal_loop_fusion", - ":instruction_fusion", - ":ir_emission_utils", - ":ir_emitter", - ":ir_emitter_context", - ":ir_emitter_unnested", - ":kernel_reuse_cache", - ":matmul_utils", - ":metrics", - ":move_copy_to_users", - ":multi_output_fusion", - ":pipelined_p2p_rewriter", - ":prepare_hlo_for_ir_emitting_pipeline", - ":priority_fusion", - ":reduction_degenerate_dim_remover", - ":reduction_dimension_grouper", - ":reduction_layout_normalizer", - ":reduction_splitter", - ":reduction_utils", - ":rename_fusions", - ":runtime_intrinsics", - ":scatter_slice_simplifier", - ":softmax_rewriter_triton", - ":stream_attribute_annotator", - ":stream_attribute_async_wrapper", - ":stream_executor_util", - ":topk_specializer", - ":topk_splitter", - ":tree_reduction_rewriter", - ":triton_fusion_numerics_verifier", - ":variadic_op_splitter", - "@com_google_absl//absl/base", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@com_google_absl//absl/types:variant", - "@llvm-project//llvm:AsmParser", - "@llvm-project//llvm:BitReader", - "@llvm-project//llvm:BitWriter", - "@llvm-project//llvm:Core", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:TransformUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", - "//xla/service/gpu/model:gpu_cost_model_stats_collection", - "//xla/service/gpu/model:gpu_hlo_cost_analysis", - "//xla/service/gpu/runtime:thunk", - "//xla/service/llvm_ir:llvm_util", - "//xla/service/spmd:collective_permute_motion", - "//xla/service:algebraic_simplifier", - "//xla/service:all_gather_broadcast_reorder", - "//xla/service:all_gather_combiner", - "//xla/service:all_reduce_combiner", - "//xla/service:all_reduce_contiguous", - "//xla/service:all_reduce_folder", - "//xla/service:all_reduce_promotion", - "//xla/service:all_reduce_reassociate", - "//xla/service:all_reduce_splitter", - "//xla/service:async_collective_creator", - "//xla/service:batchnorm_expander", - "//xla/service:bitcast_dtypes_expander", - "//xla/service:broadcast_canonicalizer", - "//xla/service:buffer_assignment", - "//xla/service:buffer_value", - "//xla/service:call_inliner", - "//xla/service:collective_permute_decomposer", - "//xla/service:collective_pipeliner", - "//xla/service:collective_quantizer", - "//xla/service:collectives_schedule_linearizer", - "//xla/service:comparison_expander", - "//xla/service:compiler", - "//xla/service:conditional_canonicalizer", - "//xla/service:conditional_simplifier", - "//xla/service:convert_async_collectives_to_sync", - "//xla/service:convert_memory_placement_to_internal_annotations", - "//xla/service:convert_mover", - "//xla/service:convolution_4d_expander", - "//xla/service:convolution_pred_expander", - "//xla/service:copy_insertion", - "//xla/service:cpu_gpu_shape_verifier", - "//xla/service:dot_decomposer", - "//xla/service:dot_merger", - "//xla/service:dump", - "//xla/service:dynamic_dimension_inference", - "//xla/service:dynamic_dimension_simplifier", - "//xla/service:dynamic_index_splitter", - "//xla/service:dynamic_padder", - "//xla/service:eigh_expander", - "//xla/service:executable", - "//xla/service:export_hlo", - "//xla/service:flatten_call_graph", - "//xla/service:float_normalization", - "//xla/service:float_support", - "//xla/service:gather_expander", - "//xla/service:gather_simplifier", - "//xla/service:hlo_computation_deduplicator", - "//xla/service:hlo_constant_folding", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_cse", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_dce", - "//xla/service:hlo_module_config", - "//xla/service:hlo_ordering", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", - "//xla/service:hlo_proto_cc", - "//xla/service:hlo_rematerialization", - "//xla/service:hlo_verifier", - "//xla/service:host_memory_transfer_asyncifier", - "//xla/service:host_offload_legalize", - "//xla/service:host_offloader", - "//xla/service:layout_assignment", - "//xla/service:layout_normalization", - "//xla/service:llvm_compiler", - "//xla/service:logical_buffer", - "//xla/service:logistic_expander", - "//xla/service:loop_schedule_linearizer", - "//xla/service:operand_upcaster", - "//xla/service:optimization_barrier_expander", - "//xla/service:optimize_input_output_buffer_alias", - "//xla/service:qr_expander", - "//xla/service:real_imag_expander", - "//xla/service:reduce_decomposer", - "//xla/service:reduce_scatter_combiner", - "//xla/service:reduce_scatter_reassociate", - "//xla/service:reduce_window_rewriter", - "//xla/service:reshape_decomposer", - "//xla/service:reshape_mover", - "//xla/service:result_caster", - "//xla/service:rng_bit_generator_expander", - "//xla/service:rng_expander", - "//xla/service:scatter_expander", - "//xla/service:scatter_simplifier", - "//xla/service:sharding_remover", - "//xla/service:simplify_fp_conversions", - "//xla/service:slice_sinker", - "//xla/service:slow_operation_alarm", - "//xla/service:sort_simplifier", - "//xla/service:stable_sort_expander", - "//xla/service:stochastic_convert_decomposer", - "//xla/service:sub_byte_normalization", - "//xla/service:topk_rewriter", - "//xla/service:transpose_folding", - "//xla/service:tuple_simplifier", - "//xla/service:while_loop_all_reduce_code_motion", - "//xla/service:while_loop_constant_sinking", - "//xla/service:while_loop_simplifier", - "//xla/service:while_loop_trip_count_annotator", - "//xla/service:zero_sized_hlo_elimination", - "//xla/stream_executor", - "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/integrations:device_mem_allocator", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_description_proto_cc", - "//xla/stream_executor:dnn", - "//xla/stream_executor:platform_manager", - "//xla/translate/hlo_to_mhlo:hlo_utils", - "//xla/translate/mhlo_to_hlo:location_exporter", - "//xla:autotune_results_proto_cc", - "//xla:debug_options_flags", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "@local_tsl//tsl/lib/monitoring:counter", - "@local_tsl//tsl/platform:blocking_counter", - "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:numbers", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:platform_port", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - "@local_tsl//tsl/profiler/lib:traceme", - # go/keep-sorted end - ]) + xla_internal(["service:export_hlo"]) + if_google([ - "//xla/hlo/experimental/auto_sharding", - ]), -) - -xla_test( - name = "gpu_compiler_test", - srcs = if_gpu_is_configured(["gpu_compiler_test.cc"]), - backend_tags = { - "gpu_a100": ["no_rocm"], - "gpu_v100": ["no_rocm"], - }, - backends = ["gpu"], - data = ["gpu_compiler_test_autotune_db.textproto"], - deps = [ - ":autotuner_util", - ":gpu_compiler", - ":gpu_hlo_schedule", - ":metrics", - "//xla:autotune_results_proto_cc", - "//xla:error_spec", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:xla_debug_info_manager", - "//xla/stream_executor:device_description", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:casts", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - -xla_test( - name = "gpu_offloading_test", - srcs = ["gpu_offloading_test.cc"], - backends = ["gpu"], - tags = ["no_rocm"], - deps = [ - ":backend_configs_cc", - "//xla:autotune_results_proto_cc", - "//xla:error_spec", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/service:buffer_value", - "//xla/service:hlo_cost_analysis", - "//xla/service:hlo_memory_scheduler", - "//xla/service:hlo_rematerialization", - "//xla/service/gpu:stream_attribute_annotator", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "auto_sharding_gpu_compiler_test", - srcs = ["auto_sharding_gpu_compiler_test.cc"], - backends = ["gpu"], - tags = ["no_oss"], # TODO(b/277355322): Make autosharding work in OSS - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:logging", - ], -) - -cc_library( - name = "nvptx_compiler", - srcs = [ - "nvptx_compiler_registration.cc", - ], - tags = [ - "gpu", - "manual", - "no_rocm", - ], - deps = [ - ":nvptx_compiler_impl", - "//xla/service:compiler", - "//xla/stream_executor/cuda:cuda_platform_id", - "@local_tsl//tsl/platform:path", - ], - alwayslink = True, # Contains compiler registration -) - -cc_library( - name = "nvptx_compiler_impl", - srcs = [ - "nvptx_compiler.cc", - ], - hdrs = [ - "nvptx_compiler.h", - ], - tags = [ - "gpu", - "manual", - "no_rocm", - ], - deps = [ - ":autotuner_util", - ":buffer_sharing", - ":conv_algorithm_picker", - ":cublas_pad_for_gemms", - ":cublas_padding_requirements", - ":cudnn_fused_conv_rewriter", - ":cudnn_fused_mha_rewriter", - ":cudnn_fused_mha_transpose_fusion", - ":cudnn_fusion_compiler", - ":cudnn_norm_rewriter", - ":cudnn_pad_for_convolutions", - ":cudnn_simplify_padding", - ":cudnn_vectorize_convolutions", - ":cudnn_workspace_rewriter", - ":cusolver_rewriter", - ":dot_sparsity_rewriter", - ":gemm_algorithm_picker", - ":gemm_fusion_autotuner", - ":gpu_algebraic_simplifier", - ":gpu_asm_opts_util", - ":gpu_compiler", - ":gpu_conv_padding_legalization", - ":gpu_conv_rewriter", - ":gpu_sort_rewriter", - ":ir_emission_utils", - ":metrics", - ":target_constants", - ":triangular_solve_rewriter", - "//xla:autotune_results_proto_cc", - "//xla:util", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/pjrt/distributed:key_value_store_interface", - "//xla/service:algebraic_simplifier", - "//xla/service:call_inliner", - "//xla/service:convert_mover", - "//xla/service:dot_dimension_merger", - "//xla/service:dump", - "//xla/service:float_normalization", - "//xla/service:float_support", - "//xla/service:hlo_constant_folding", - "//xla/service:hlo_cse", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_dce", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", - "//xla/service:hlo_verifier", - "//xla/service:reshape_mover", - "//xla/service:tuple_simplifier", - "//xla/service/gpu/llvm_gpu_backend", - "//xla/service/llvm_ir:llvm_util", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/stream_executor/cuda:cuda_asm_compiler", - "//xla/stream_executor/cuda:cuda_diagnostics", - "//xla/stream_executor/cuda:cuda_platform_id", - "//xla/stream_executor/cuda:ptx_compilation_method", - "//xla/stream_executor/cuda:ptx_compiler", - "//xla/stream_executor/cuda:ptx_compiler_support", - "//xla/stream_executor/cuda:ptx_linking_method", - "//xla/stream_executor/gpu:gpu_asm_opts", - "//xla/stream_executor/gpu:gpu_driver_header", - "//xla/stream_executor/gpu:gpu_executor_header", - "//xla/tsl/util:env_var", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/cleanup", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:IRReader", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@local_config_cuda//cuda:cuda_headers", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/lib:scoped_annotation", - "@local_tsl//tsl/profiler/lib:traceme", - ], -) - -xla_test( - name = "nvptx_compiler_test", - srcs = [ - "nvptx_compiler_test.cc", - ], - backends = [ - "gpu_v100", - "gpu_a100", - ], - tags = [ - "gpu", - "no_rocm", - "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. - ], - deps = [ - ":gpu_constants", - ":gpu_hlo_schedule", - ":gpu_latency_hiding_scheduler", - ":nvptx_compiler_impl", - "//xla:util", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:backend", - "//xla/service:buffer_assignment", - "//xla/service:buffer_value", - "//xla/service:hlo_ordering", - "//xla/service:logical_buffer", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "ptx_compilation_test", - srcs = [ - "ptx_compilation_test.cc", - ], - backends = [ - "gpu", - ], - tags = [ - "gpu", - "no_rocm", - "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. - ], - deps = [ - ":gpu_executable", - ":nvptx_compiler_impl", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/stream_executor:device_description", - "//xla/stream_executor/cuda:ptx_compilation_method", - "//xla/stream_executor/cuda:ptx_compiler_support", - "//xla/stream_executor/cuda:ptx_linking_method", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@llvm-project//llvm:Object", - "@llvm-project//llvm:Support", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - -xla_cc_test( - name = "gpu_aot_compilation_test", - srcs = if_gpu_is_configured([ - "gpu_aot_compilation_test.cc", - ]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - tags = [ - "gpu", - "ignore_for_dep=third_party/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h", - "no_oss", - "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. - "requires-gpu-nvidia", - ], - deps = if_cuda_is_configured([ - ":nvptx_compiler_impl", - ]) + if_rocm_is_configured([ - ":amdgpu_compiler_impl", - ]) + [ - ":gpu_transfer_manager", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_module_group", - "//xla/service:compiler", - "//xla/service:executable", - "//xla/service:gpu_plugin", - "//xla/service:platform_util", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/stream_executor", - "//xla/stream_executor:platform", - "//xla/stream_executor:platform_manager", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "amdgpu_compiler", - srcs = [ - "amdgpu_compiler_registration.cc", - ], - local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), - tags = ["manual"], - deps = [ - ":amdgpu_compiler_impl", - "//xla/service:compiler", - "//xla/stream_executor/rocm:rocm_platform_id", - ], - alwayslink = True, # Contains compiler registration -) - -cc_library( - name = "gpu_algebraic_simplifier", - srcs = [ - "gpu_algebraic_simplifier.cc", - ], - hdrs = [ - "gpu_algebraic_simplifier.h", - ], - deps = [ - ":matmul_utils", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", - "//xla/service:hlo_pass", - "//xla/service/gpu/fusions/triton:triton_support", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "gpu_algebraic_simplifier_test", - srcs = ["gpu_algebraic_simplifier_test.cc"], - deps = [ - ":gpu_algebraic_simplifier", - "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "amdgpu_compiler_impl", - srcs = [ - "amdgpu_compiler.cc", - ], - hdrs = [ - "amdgpu_compiler.h", - ], - tags = ["manual"], - deps = [ - ":autotuner_util", - ":conv_algorithm_picker", - ":cublas_pad_for_gemms", - ":cublas_padding_requirements", - ":cudnn_fused_conv_rewriter", - ":cusolver_rewriter", - ":gemm_algorithm_picker", - ":gpu_algebraic_simplifier", - ":gpu_compiler", - ":gpu_conv_padding_legalization", - ":gpu_conv_rewriter", - ":gpu_sort_rewriter", - ":target_constants", - ":triangular_solve_rewriter", - "//xla:util", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", - "//xla/service:call_inliner", - "//xla/service:convert_mover", - "//xla/service:dot_dimension_merger", - "//xla/service:float_normalization", - "//xla/service:float_support", - "//xla/service:hlo_constant_folding", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", - "//xla/service:hlo_verifier", - "//xla/service:reshape_mover", - "//xla/service:tuple_simplifier", - "//xla/service/gpu/llvm_gpu_backend", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:dnn", - "//xla/stream_executor/rocm:rocm_platform_id", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:ir_headers", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ] + if_rocm_is_configured([ - # keep sorted - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -cc_library( - name = "all_reduce_blueconnect", - srcs = ["all_reduce_blueconnect.cc"], - hdrs = ["all_reduce_blueconnect.h"], - deps = [ - "//xla:shape_util", - "//xla:status_macros", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:collective_ops_utils", - "//xla/service:computation_placer_hdr", - "//xla/service:global_device_id", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:btree", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "all_reduce_blueconnect_test", - srcs = ["all_reduce_blueconnect_test.cc"], - deps = [ - ":all_reduce_blueconnect", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:computation_placer_hdr", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "xfeed_queue", - hdrs = ["xfeed_queue.h"], - deps = [ - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:logging", - ], -) - -cc_library( - name = "io_feed_manager", - srcs = [ - "infeed_manager.cc", - "outfeed_manager.cc", - "xla_executor_state.h", - ], - hdrs = [ - "infeed_manager.h", - "outfeed_manager.h", - ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":xfeed_queue", - "//xla:literal", - "//xla:shape_tree", - "//xla:shape_util", - "//xla:util", - "//xla/stream_executor:device_memory_handle", - "//xla/stream_executor:stream_executor_h", - "//xla/stream_executor/gpu:gpu_executor_header", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:notification", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_layout_assignment", - srcs = ["gpu_layout_assignment.cc"], - hdrs = ["gpu_layout_assignment.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":matmul_utils", - ":reduction_utils", - ":stream_executor_util", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:util", - "//xla:window_util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:computation_layout", - "//xla/service:host_memory_offload_annotations_hdr", - "//xla/service:layout_assignment", - "//xla/service:logical_buffer", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "//xla/tsl/util:env_var", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_layout_assignment_test", - srcs = ["gpu_layout_assignment_test.cc"], - deps = [ - ":gpu_layout_assignment", - ":stream_executor_util", - "//xla:shape_layout", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:computation_layout", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/stream_executor:device_description", - "//xla/stream_executor:dnn", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # build_cleaner: keep - "@com_google_absl//absl/types:span", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_schedule_postprocessing", - srcs = ["gpu_schedule_postprocessing.cc"], - hdrs = ["gpu_schedule_postprocessing.h"], - deps = [ - ":backend_configs_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_schedule_postprocessing_test", - srcs = ["gpu_schedule_postprocessing_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_schedule_postprocessing", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - #"@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_hlo_schedule", - srcs = ["gpu_hlo_schedule.cc"], - hdrs = ["gpu_hlo_schedule.h"], - deps = [ - ":backend_configs_cc", - ":gpu_latency_hiding_scheduler", - ":gpu_schedule_postprocessing", - ":scheduling_instruction_annotator", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:buffer_value", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_memory_scheduler", - "//xla/service:hlo_pass_pipeline", - "//xla/service:latency_hiding_scheduler", - "//xla/service:p2p_schedule_preparation", - "//xla/service:profile_guided_latency_estimator", - "//xla/service/gpu/model:analytical_latency_estimator", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", - ], -) - -xla_test( - name = "gpu_hlo_schedule_test", - srcs = [ - "gpu_hlo_schedule_test.cc", - ], - backends = ["gpu"], - deps = [ - ":gpu_hlo_schedule", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:backend", - "//xla/service:hlo_module_config", - "//xla/service:hlo_ordering", - "//xla/stream_executor:device_description", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:test_utils", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", - ], -) - -cc_library( - name = "gpu_p2p_pipeliner", - srcs = ["gpu_p2p_pipeliner.cc"], - hdrs = ["gpu_p2p_pipeliner.h"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:collective_ops_utils", - "//xla/service:collective_pipeliner", - "//xla/service:hlo_parser", - "//xla/service:hlo_pass_pipeline", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - ], -) - -xla_cc_test( - name = "gpu_p2p_pipeliner_test", - srcs = [ - "gpu_p2p_pipeliner_test.cc", - ], - deps = [ - ":gpu_p2p_pipeliner", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", - "//xla/service:hlo_pass_pipeline", - "//xla/service:hlo_verifier", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - ], -) - -cc_library( - name = "gpu_spmd_pipeline", - srcs = ["gpu_spmd_pipeline.cc"], - hdrs = ["gpu_spmd_pipeline.h"], - deps = [ - ":gpu_algebraic_simplifier", - ":runtime_intrinsics", - "//xla/hlo/ir:hlo", - "//xla/hlo/transforms:hlo_constant_splitter", - "//xla/service:algebraic_simplifier", - "//xla/service:conditional_simplifier", - "//xla/service:gather_expander", - "//xla/service:hlo_constant_folding", - "//xla/service:hlo_dce", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", - "//xla/service:reshape_mover", - "//xla/service:scatter_expander", - "//xla/service:sharding_propagation", - "//xla/service:sort_simplifier", - "//xla/service:tuple_simplifier", - "//xla/service:while_loop_constant_sinking", - "//xla/service:while_loop_simplifier", - "//xla/service/spmd:collective_permute_motion", - "//xla/service/spmd:stateful_rng_spmd_partitioner", - "//xla/service/spmd/shardy:shardy_xla_pass", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/functional:function_ref", - "@com_google_absl//absl/log:check", - ], -) - -xla_cc_test( - name = "gpu_spmd_pipeline_test", - srcs = [ - "gpu_spmd_pipeline_test.cc", - ], - deps = [ - ":gpu_spmd_pipeline", - "//xla:shape_util", - "//xla:util", - "//xla/client:executable_build_options", - "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", - "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", - "//xla/service:hlo_pass_pipeline", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "while_transformer_test", - srcs = ["while_transformer_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - "//xla:comparison_util", - "//xla:literal_util", - "//xla:shape_util", - "//xla:test", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:while_loop_analysis", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - ], -) - -cuda_library( - name = "stream_executor_util_kernel", - srcs = if_cuda_is_configured(["stream_executor_util_kernel.cu.cc"]), - deps = ["@local_config_cuda//cuda:cuda_headers"], -) - -cc_library( - name = "stream_executor_util", - srcs = ["stream_executor_util.cc"], - hdrs = ["stream_executor_util.h"], - copts = tsl_copts(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":cublas_cudnn", - ":launch_dimensions", - ":stream_executor_util_kernel", - "//xla:autotuning_proto_cc", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/stream_executor", - "//xla/stream_executor:data_type", - "//xla/stream_executor:dnn", - "//xla/stream_executor:kernel_factory", - "//xla/stream_executor:kernel_spec", - "//xla/stream_executor:launch_dim", - "//xla/stream_executor:typed_kernel_factory", - "//xla/tsl/util:env_var", - "//xla/tsl/util/proto:proto_utils", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", - "@com_google_absl//absl/time", - "@com_google_absl//absl/types:span", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "stream_executor_util_test", - srcs = ["stream_executor_util_test.cc"], - deps = [ - ":stream_executor_util", - "//xla:autotuning_proto_cc", - "//xla/service:hlo_module_config", - "//xla/tsl/util/proto:proto_utils", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/time", - "@com_google_googletest//:gtest_main", - ], -) - -cc_library( - name = "gpu_asm_opts_util", - srcs = ["gpu_asm_opts_util.cc"], - hdrs = ["gpu_asm_opts_util.h"], - compatible_with = get_compatible_with_portable(), - copts = tsl_copts(), - deps = [ - "//xla:xla_proto_cc", - "//xla/stream_executor/gpu:gpu_asm_opts", - "@com_google_absl//absl/strings", - ], -) - -cc_library( - name = "hlo_fusion_analysis", - srcs = ["hlo_fusion_analysis.cc"], - hdrs = ["hlo_fusion_analysis.h"], - deps = [ - ":backend_configs_cc", - ":hlo_traversal", - ":ir_emission_utils", - ":reduction_utils", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - ], -) - -xla_cc_test( - name = "hlo_fusion_analysis_test", - srcs = ["hlo_fusion_analysis_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_device_info_for_tests", - ":hlo_fusion_analysis", - ":hlo_traversal", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_description_proto_cc", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "buffer_comparator", - srcs = if_gpu_is_configured(["buffer_comparator.cc"]), - hdrs = if_gpu_is_configured(["buffer_comparator.h"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = if_gpu_is_configured([ - # keep sorted - ":buffer_comparator_kernel", - ":gpu_asm_opts_util", - ":launch_dimensions", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/service:hlo_module_config", - "//xla/stream_executor", - "//xla/stream_executor:device_memory_handle", - "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:asm_compiler", - "@com_google_absl//absl/base", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@eigen_archive//:eigen3", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", - ]) + if_rocm_is_configured([ - # keep sorted - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -gpu_kernel_library( - name = "buffer_comparator_kernel", - srcs = if_gpu_is_configured(["buffer_comparator.cu.cc"]), - copts = rocm_copts(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -xla_test( - name = "buffer_comparator_test", - srcs = if_gpu_is_configured(["buffer_comparator_test.cc"]), - backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = [ - ":stream_executor_util", - "//xla:shape_util", - "//xla:types", - "//xla/service:hlo_module_config", - "//xla/stream_executor", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:device_memory_handle", - "//xla/stream_executor:platform_manager", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ] + if_gpu_is_configured([ - ":buffer_comparator", - "//xla/stream_executor:device_memory", - ]), -) - -cc_library( - name = "buffer_sharing", - srcs = ["buffer_sharing.cc"], - hdrs = ["buffer_sharing.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":hlo_fusion_analysis", - ":ir_emission_utils", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_description_proto_cc", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@llvm-project//llvm:Support", - ], -) - -cc_library( - name = "gpu_fusible", - srcs = ["gpu_fusible.cc"], - hdrs = ["gpu_fusible.h"], - deps = [ - ":backend_configs_cc", - ":hlo_traversal", - ":ir_emission_utils", - ":reduction_utils", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_dataflow_analysis", - "//xla/service:instruction_fusion", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/synchronization", - ], -) - -xla_cc_test( - name = "gpu_fusible_test", - srcs = ["gpu_fusible_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":gpu_fusible", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "cudnn_fused_conv_rewriter", - srcs = ["cudnn_fused_conv_rewriter.cc"], - hdrs = ["cudnn_fused_conv_rewriter.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - "//xla:comparison_util", - "//xla:debug_options_flags", - "//xla:literal", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:ml_dtypes", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "cudnn_fused_conv_rewriter_test", - srcs = ["cudnn_fused_conv_rewriter_test.cc"], - backend_tags = { - "gpu_a100": [ - "noasan", - "nomsan", - "no_rocm", - ], - }, - backends = [ - "gpu_a100", - "gpu_amd_any", - ] + if_oss(["gpu_any"]), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - shard_count = 10, - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":cudnn_fused_conv_rewriter", - ":gpu_conv_rewriter", - ":stream_executor_util", - "//xla:comparison_util", - "//xla:error_spec", - "//xla/hlo/ir:hlo", - "//xla/service:algebraic_simplifier", - "//xla/service:convert_mover", - "//xla/service:hlo_constant_folding", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service:reshape_mover", - "//xla/service/gpu/tests:gpu_codegen_test", - "//xla/stream_executor:device_description", - "//xla/stream_executor:dnn", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudnn_header", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -cc_library( - name = "cudnn_norm_rewriter", - srcs = ["cudnn_norm_rewriter.cc"], - hdrs = ["cudnn_norm_rewriter.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - "//xla:shape_util", - "//xla:types", - "//xla:util", - "//xla:window_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "//xla/stream_executor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/protobuf:dnn_proto_cc", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudnn_header", - ]) + if_google([ - "@com_google_protobuf//:wrappers_cc_proto", - ]), -) - -xla_test( - name = "cudnn_norm_rewriter_test", - srcs = ["cudnn_norm_rewriter_test.cc"], - backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":cublas_cudnn", - ":cudnn_norm_rewriter", - "//xla:error_spec", - "//xla/service/gpu/tests:gpu_codegen_test", - "//xla/stream_executor:device_description", - "//xla/tests:filecheck", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudnn_header", - ]), -) - -cc_library( - name = "cudnn_fused_mha_rewriter", - srcs = ["cudnn_fused_mha_rewriter.cc"], - hdrs = ["cudnn_fused_mha_rewriter.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":matmul_utils", - ":stream_executor_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "//xla/stream_executor", - "//xla/stream_executor:dnn", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]), -) - -cc_library( - name = "cudnn_fused_mha_transpose_fusion", - srcs = ["cudnn_fused_mha_transpose_fusion.cc"], - hdrs = ["cudnn_fused_mha_transpose_fusion.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":matmul_utils", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "cudnn_fused_mha_rewriter_test", - srcs = ["cudnn_fused_mha_rewriter_test.cc"], - backend_tags = {"gpu": [ - "requires-gpu-nvidia", - "no_rocm", - ]}, - backends = [ - "gpu", - ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", - ":cudnn_fused_mha_rewriter", - ":cudnn_fused_mha_transpose_fusion", - "//xla:error_spec", - "//xla:test_helpers", - "//xla:util", - "//xla:xla_data_proto_cc", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:variant", + "@llvm-project//llvm:AsmParser", + "@llvm-project//llvm:BitReader", + "@llvm-project//llvm:BitWriter", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:TransformUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_module_group", + "//xla/service/gpu/autotuning:autotuner_util", + "//xla/service/gpu/autotuning:custom_kernel_fusion_autotuner", + "//xla/service/gpu/model:gpu_cost_model_stats_collection", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/runtime:thunk", + "//xla/service/gpu/transforms:algebraic_simplifier", + "//xla/service/gpu/transforms:algorithm_checker", + "//xla/service/gpu/transforms:all_gather_optimizer", + "//xla/service/gpu/transforms:all_reduce_blueconnect", + "//xla/service/gpu/transforms:all_reduce_splitter", + "//xla/service/gpu/transforms:async_collective_annotator", + "//xla/service/gpu/transforms:async_wrapper", + "//xla/service/gpu/transforms:collective_permute_cycle_decomposer", + "//xla/service/gpu/transforms:collective_permute_valid_iteration_annotator", + "//xla/service/gpu/transforms:command_buffer_scheduling", + "//xla/service/gpu/transforms:conv_rewriter", + "//xla/service/gpu/transforms:convert_async_collectives_to_sync", + "//xla/service/gpu/transforms:cudnn_custom_call_converter", + "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", + "//xla/service/gpu/transforms:dot_dimension_sorter", + "//xla/service/gpu/transforms:dot_operand_converter", + "//xla/service/gpu/transforms:double_buffer_loop_unrolling", + "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter", + "//xla/service/gpu/transforms:fusion_wrapper", + "//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter", + "//xla/service/gpu/transforms:gemm_fusion", + "//xla/service/gpu/transforms:gemm_rewriter", + "//xla/service/gpu/transforms:gemv_rewriter", + "//xla/service/gpu/transforms:layout_assignment", + "//xla/service/gpu/transforms:move_copy_to_users", + "//xla/service/gpu/transforms:pipelined_p2p_rewriter", + "//xla/service/gpu/transforms:reduce_scatter_creator", + "//xla/service/gpu/transforms:reduction_degenerate_dim_remover", + "//xla/service/gpu/transforms:reduction_dimension_grouper", + "//xla/service/gpu/transforms:reduction_layout_normalizer", + "//xla/service/gpu/transforms:reduction_splitter", + "//xla/service/gpu/transforms:rename_fusions", + "//xla/service/gpu/transforms:sanitize_constant_names", + "//xla/service/gpu/transforms:scatter_expander", + "//xla/service/gpu/transforms:scatter_slice_simplifier", + "//xla/service/gpu/transforms:softmax_rewriter_triton", + "//xla/service/gpu/transforms:stream_attribute_annotator", + "//xla/service/gpu/transforms:stream_attribute_async_wrapper", + "//xla/service/gpu/transforms:topk_specializer", + "//xla/service/gpu/transforms:topk_splitter", + "//xla/service/gpu/transforms:transpose_dimension_grouper", + "//xla/service/gpu/transforms:tree_reduction_rewriter", + "//xla/service/gpu/transforms:triton_fusion_numerics_verifier", + "//xla/service/gpu/transforms:windowed_einsum_handler", + "//xla/service/llvm_ir:llvm_util", + "//xla/service/spmd:collective_permute_motion", "//xla/service:algebraic_simplifier", - "//xla/service:computation_layout", + "//xla/service:all_gather_broadcast_reorder", + "//xla/service:all_gather_combiner", + "//xla/service:all_reduce_combiner", + "//xla/service:all_reduce_contiguous", + "//xla/service:all_reduce_folder", + "//xla/service:all_reduce_promotion", + "//xla/service:all_reduce_reassociate", + "//xla/service:async_collective_creator", + "//xla/service:batchnorm_expander", + "//xla/service:bitcast_dtypes_expander", + "//xla/service:broadcast_canonicalizer", + "//xla/service:buffer_assignment", + "//xla/service:buffer_value", + "//xla/service:call_inliner", + "//xla/service:collective_permute_decomposer", + "//xla/service:collective_pipeliner", + "//xla/service:collective_quantizer", + "//xla/service:collectives_schedule_linearizer", + "//xla/service:comparison_expander", + "//xla/service:compiler", + "//xla/service:conditional_canonicalizer", + "//xla/service:conditional_simplifier", + "//xla/service:convert_async_collectives_to_sync", + "//xla/service:convert_memory_placement_to_internal_annotations", + "//xla/service:convert_mover", + "//xla/service:convolution_4d_expander", + "//xla/service:convolution_pred_expander", + "//xla/service:copy_insertion", + "//xla/service:cpu_gpu_shape_verifier", + "//xla/service:dot_decomposer", + "//xla/service:dot_merger", + "//xla/service:dump", + "//xla/service:dynamic_dimension_inference", + "//xla/service:dynamic_dimension_simplifier", + "//xla/service:dynamic_index_splitter", + "//xla/service:dynamic_padder", + "//xla/service:eigh_expander", + "//xla/service:executable", + "//xla/service:export_hlo", + "//xla/service:flatten_call_graph", + "//xla/service:float_normalization", + "//xla/service:float_support", + "//xla/service:gather_expander", + "//xla/service:gather_simplifier", + "//xla/service:hlo_computation_deduplicator", + "//xla/service:hlo_constant_folding", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_cse", + "//xla/service:hlo_dataflow_analysis", "//xla/service:hlo_dce", "//xla/service:hlo_module_config", - "//xla/service:hlo_parser", + "//xla/service:hlo_ordering", + "//xla/service:hlo_pass", + "//xla/service:hlo_pass_pipeline", + "//xla/service:hlo_proto_cc", + "//xla/service:hlo_rematerialization", "//xla/service:hlo_verifier", + "//xla/service:host_memory_transfer_asyncifier", + "//xla/service:host_offload_legalize", + "//xla/service:host_offloader", + "//xla/service:layout_assignment", "//xla/service:layout_normalization", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", + "//xla/service:llvm_compiler", + "//xla/service:logical_buffer", + "//xla/service:logistic_expander", + "//xla/service:loop_schedule_linearizer", + "//xla/service:operand_upcaster", + "//xla/service:optimization_barrier_expander", + "//xla/service:optimize_input_output_buffer_alias", + "//xla/service:qr_expander", + "//xla/service:real_imag_expander", + "//xla/service:reduce_decomposer", + "//xla/service:reduce_scatter_combiner", + "//xla/service:reduce_scatter_reassociate", + "//xla/service:reduce_window_rewriter", "//xla/service:reshape_decomposer", - "//xla/stream_executor:device_description", - "//xla/stream_executor:dnn", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_cuda//cuda:cudnn_header", - ]), -) - -xla_test( - name = "float_support_test", - srcs = ["float_support_test.cc"], - backend_tags = {"gpu": [ - "requires-gpu-sm80", - "no_rocm", - ]}, - backends = [ - "gpu", - ], - deps = [ - ":variant_visitor", - "//xla:error_spec", - "//xla:xla_proto_cc", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - ], -) - -xla_test( - name = "conv_layout_normalization_test", - srcs = ["conv_layout_normalization_test.cc"], - backends = ["gpu"], - deps = [ - "//xla:error_spec", - "//xla/hlo/ir:hlo", - "//xla/service/gpu/tests:gpu_codegen_test", # fixdeps: keep - "//xla/tests:hlo_test_base", - "//xla/tests:test_macros_header", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "variadic_op_splitter", - srcs = ["variadic_op_splitter.cc"], - hdrs = ["variadic_op_splitter.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "gpu_scatter_expander", - srcs = ["gpu_scatter_expander.cc"], - hdrs = ["gpu_scatter_expander.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", + "//xla/service:reshape_mover", + "//xla/service:result_caster", + "//xla/service:rng_bit_generator_expander", + "//xla/service:rng_expander", "//xla/service:scatter_expander", - "@com_google_absl//absl/strings:string_view", - ], -) - -xla_cc_test( - name = "variadic_op_splitter_test", - srcs = ["variadic_op_splitter_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":variadic_op_splitter", - "//xla:literal_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", - ], -) - -tf_proto_library( - name = "gpu_autotuning_proto", - srcs = ["gpu_autotuning.proto"], - cc_api_version = 2, - protodeps = [ - ":backend_configs", - "//xla:xla_data_proto", - "//xla/service:hlo_proto", - "//xla:autotuning_proto", - ], -) - -cc_library( - name = "hlo_algorithm_denylist", - srcs = ["hlo_algorithm_denylist.cc"], - hdrs = ["hlo_algorithm_denylist.h"], - deps = [ - ":backend_configs_cc", - ":gpu_autotuning_proto_cc", - "//xla:autotuning_proto_cc", - "//xla:debug_options_flags", - "//xla/hlo/ir:backend_config", - "//xla/stream_executor:dnn", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:protobuf", - "@local_tsl//tsl/platform:status", - ], -) - -xla_cc_test( - name = "hlo_algorithm_denylist_test", - srcs = ["hlo_algorithm_denylist_test.cc"], - data = ["data/hlo_algorithm_denylist.pbtxt"], - deps = [ - ":hlo_algorithm_denylist", + "//xla/service:scatter_simplifier", + "//xla/service:sharding_remover", + "//xla/service:simplify_fp_conversions", + "//xla/service:slice_sinker", + "//xla/service:slow_operation_alarm", + "//xla/service:sort_simplifier", + "//xla/service:stable_sort_expander", + "//xla/service:stochastic_convert_decomposer", + "//xla/service:sub_byte_normalization", + "//xla/service:topk_rewriter", + "//xla/service:transpose_folding", + "//xla/service:tuple_simplifier", + "//xla/service:while_loop_all_reduce_code_motion", + "//xla/service:while_loop_constant_sinking", + "//xla/service:while_loop_simplifier", + "//xla/service:while_loop_trip_count_annotator", + "//xla/service:zero_sized_hlo_elimination", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/integrations:device_mem_allocator", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:dnn", - "//xla/tests:test_utils", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:path", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -cc_library( - name = "alias_passthrough_params", - srcs = ["alias_passthrough_params.cc"], - hdrs = ["alias_passthrough_params.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -xla_cc_test( - name = "alias_passthrough_params_test", - srcs = ["alias_passthrough_params_test.cc"], - tags = [ - "nomsan", - ], - deps = [ - ":alias_passthrough_params", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:test", - ], -) - -cc_library( - name = "horizontal_loop_fusion", - srcs = ["horizontal_loop_fusion.cc"], - hdrs = ["horizontal_loop_fusion.h"], - deps = [ - ":gpu_fusible", + "//xla/stream_executor:platform_manager", + "//xla/translate/hlo_to_mhlo:hlo_utils", + "//xla/translate/mhlo_to_hlo:location_exporter", + "//xla:autotune_results_proto_cc", + "//xla:debug_options_flags", "//xla:shape_util", + "//xla:status_macros", + "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:sub_byte_normalization", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", + "//xla:xla_proto_cc", + "@local_tsl//tsl/lib/monitoring:counter", + "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", - ], + "@local_tsl//tsl/profiler/lib:scoped_annotation", + "@local_tsl//tsl/profiler/lib:traceme", + # go/keep-sorted end + ]) + xla_internal(["service:export_hlo"]) + if_google([ + "//xla/hlo/experimental/auto_sharding", + ]), ) xla_test( - name = "horizontal_loop_fusion_test", - srcs = ["horizontal_loop_fusion_test.cc"], + name = "gpu_compiler_test", + srcs = ["gpu_compiler_test.cc"], + backend_tags = { + "gpu_a100": ["no_rocm"], + "gpu_v100": ["no_rocm"], + }, backends = ["gpu"], + data = ["gpu_compiler_test_autotune_db.textproto"], deps = [ - ":gpu_device_info_for_tests", - ":horizontal_loop_fusion", - ":instruction_fusion", + ":gpu_compiler", + ":gpu_hlo_schedule", + ":metrics", + "//xla:autotune_results_proto_cc", "//xla:error_spec", + "//xla:literal", + "//xla:literal_util", "//xla:shape_util", - "//xla:test", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_dce", - "//xla/service:hlo_parser", - "//xla/service:hlo_pass", - "//xla/service:hlo_pass_pipeline", + "//xla/hlo/ir:hlo_module_group", + "//xla/service:compiler", + "//xla/service:executable", + "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/service:xla_debug_info_manager", + "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:casts", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], ) -cc_library( - name = "horizontal_input_fusion", - srcs = ["horizontal_input_fusion.cc"], - hdrs = ["horizontal_input_fusion.h"], +xla_test( + name = "gpu_offloading_test", + srcs = ["gpu_offloading_test.cc"], + backends = ["gpu"], + tags = ["no_rocm"], #TODO(rocm): weekly sync deps = [ - ":gpu_fusible", + ":backend_configs_cc", + "//xla:autotune_results_proto_cc", + "//xla:error_spec", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:buffer_value", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_memory_scheduler", + "//xla/service:hlo_rematerialization", + "//xla/service/gpu/transforms:stream_attribute_annotator", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:statusor", ], ) xla_test( - name = "horizontal_input_fusion_test", - srcs = ["horizontal_input_fusion_test.cc"], + name = "auto_sharding_gpu_compiler_test", + srcs = ["auto_sharding_gpu_compiler_test.cc"], backends = ["gpu"], + tags = ["no_oss"], # TODO(b/277355322): Make autosharding work in OSS deps = [ - ":gpu_device_info_for_tests", - ":horizontal_input_fusion", - "//xla:error_spec", - "//xla:literal_util", "//xla:shape_util", - "//xla:test", "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu/tests:gpu_codegen_test", - "//xla/stream_executor:device_description", - "//xla/tests:xla_internal_test_main", - ], -) - -xla_cc_test( - name = "gpu_float_support_test", - srcs = ["gpu_float_support_test.cc"], - deps = [ - ":backend_configs_cc", - ":gpu_float_support", - ":ir_emission_utils", - "//xla:shape_util", - "//xla:test_helpers", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:float_normalization", - "//xla/service:hlo_verifier", - "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest_main", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:logging", ], ) cc_library( - name = "reduction_degenerate_dim_remover", - srcs = ["reduction_degenerate_dim_remover.cc"], - hdrs = ["reduction_degenerate_dim_remover.h"], + name = "nvptx_compiler", + srcs = [ + "nvptx_compiler_registration.cc", + ], + tags = [ + "gpu", + "manual", + "no_rocm", + ], deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", + ":nvptx_compiler_impl", + "//xla/service:compiler", + "//xla/stream_executor/cuda:cuda_platform_id", + "@local_tsl//tsl/platform:path", ], + alwayslink = True, # Contains compiler registration ) cc_library( - name = "reduction_dimension_grouper", - srcs = ["reduction_dimension_grouper.cc"], - hdrs = ["reduction_dimension_grouper.h"], + name = "nvptx_compiler_impl", + srcs = [ + "nvptx_compiler.cc", + ], + hdrs = [ + "nvptx_compiler.h", + ], + tags = [ + "gpu", + "manual", + "no_rocm", + ], deps = [ - "//xla:shape_util", + ":buffer_sharing", + ":cublas_padding_requirements", + ":gpu_asm_opts_util", + ":gpu_compiler", + ":ir_emission_utils", + ":metrics", + ":target_constants", + "//xla:autotune_results_proto_cc", + "//xla:util", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:algebraic_simplifier", + "//xla/service:call_inliner", + "//xla/service:convert_mover", + "//xla/service:dot_dimension_merger", + "//xla/service:dump", + "//xla/service:float_normalization", + "//xla/service:float_support", + "//xla/service:hlo_constant_folding", + "//xla/service:hlo_cse", + "//xla/service:hlo_dataflow_analysis", + "//xla/service:hlo_dce", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", + "//xla/service:hlo_pass_pipeline", + "//xla/service:hlo_verifier", + "//xla/service:reshape_mover", + "//xla/service:tuple_simplifier", + "//xla/service/gpu/autotuning:autotuner_util", + "//xla/service/gpu/autotuning:conv_algorithm_picker", + "//xla/service/gpu/autotuning:gemm_algorithm_picker", + "//xla/service/gpu/autotuning:gemm_fusion_autotuner", + "//xla/service/gpu/llvm_gpu_backend", + "//xla/service/gpu/transforms:algebraic_simplifier", + "//xla/service/gpu/transforms:conv_padding_legalization", + "//xla/service/gpu/transforms:conv_rewriter", + "//xla/service/gpu/transforms:cublas_pad_for_gemms", + "//xla/service/gpu/transforms:cudnn_custom_call_compiler", + "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter", + "//xla/service/gpu/transforms:cudnn_fused_mha_rewriter", + "//xla/service/gpu/transforms:cudnn_fused_mha_transpose_fusion", + "//xla/service/gpu/transforms:cudnn_fusion_compiler", + "//xla/service/gpu/transforms:cudnn_norm_rewriter", + "//xla/service/gpu/transforms:cudnn_pad_for_convolutions", + "//xla/service/gpu/transforms:cudnn_simplify_padding", + "//xla/service/gpu/transforms:cudnn_vectorize_convolutions", + "//xla/service/gpu/transforms:dot_sparsity_rewriter", + "//xla/service/gpu/transforms:gpusolver_rewriter", + "//xla/service/gpu/transforms:sort_rewriter", + "//xla/service/gpu/transforms:triangular_solve_rewriter", + "//xla/service/llvm_ir:llvm_util", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "//xla/stream_executor/cuda:cuda_asm_compiler", + "//xla/stream_executor/cuda:cuda_diagnostics", + "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/cuda:nvjitlink", + "//xla/stream_executor/cuda:nvjitlink_support", + "//xla/stream_executor/cuda:ptx_compilation_method", + "//xla/stream_executor/cuda:ptx_compiler", + "//xla/stream_executor/cuda:ptx_compiler_support", + "//xla/stream_executor/cuda:ptx_linking_method", + "//xla/stream_executor/gpu:gpu_asm_opts", + "//xla/stream_executor/gpu:gpu_driver_header", + "//xla/stream_executor/gpu:gpu_executor_header", + "//xla/tsl/util:env_var", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/base", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:IRReader", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + "@local_tsl//tsl/profiler/lib:traceme", ], ) -cc_library( - name = "reduction_splitter", - srcs = ["reduction_splitter.cc"], - hdrs = ["reduction_splitter.h"], - deps = [ - ":reduction_utils", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", +xla_test( + name = "nvptx_compiler_test", + srcs = [ + "nvptx_compiler_test.cc", ], -) - -xla_cc_test( - name = "reduction_splitter_test", - srcs = ["reduction_splitter_test.cc"], - deps = [ - ":reduction_splitter", - "//xla:shape_util", - "//xla:test", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", + backends = [ + "gpu_v100", + "gpu_a100", + ], + tags = [ + "no_rocm", + "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. ], -) - -cc_library( - name = "reduction_layout_normalizer", - srcs = ["reduction_layout_normalizer.cc"], - hdrs = ["reduction_layout_normalizer.h"], deps = [ - "//xla:shape_util", - "//xla:status_macros", + ":gpu_constants", + ":gpu_hlo_schedule", + ":gpu_latency_hiding_scheduler", + ":nvptx_compiler_impl", "//xla:util", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", + "//xla/hlo/utils:hlo_query", + "//xla/service:backend", + "//xla/service:buffer_assignment", + "//xla/service:buffer_value", + "//xla/service:hlo_ordering", + "//xla/service:logical_buffer", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) -cc_library( - name = "tree_reduction_rewriter", - srcs = ["tree_reduction_rewriter.cc"], - hdrs = ["tree_reduction_rewriter.h"], +xla_test( + name = "ptx_compilation_test", + srcs = [ + "ptx_compilation_test.cc", + ], + backends = [ + "gpu", + ], + tags = [ + "no_rocm", + "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. + ], deps = [ - ":reduction_utils", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", + ":gpu_executable", + ":nvptx_compiler_impl", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:collective_ops_utils", + "//xla/service:executable", "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/numeric:bits", + "//xla/stream_executor/cuda:nvjitlink_support", + "//xla/stream_executor/cuda:ptx_compilation_method", + "//xla/stream_executor/cuda:ptx_compiler_support", + "//xla/stream_executor/cuda:ptx_linking_method", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Object", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], ) -cc_library( - name = "gemm_broadcast_folding_rewriter", - srcs = ["gemm_broadcast_folding_rewriter.cc"], - hdrs = ["gemm_broadcast_folding_rewriter.h"], - deps = [ - ":backend_configs_cc", - ":cublas_cudnn", +xla_cc_test( + name = "gpu_aot_compilation_test", + srcs = if_gpu_is_configured([ + "gpu_aot_compilation_test.cc", + ]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + tags = [ + "gpu", + "ignore_for_dep=third_party/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h", + "no_oss", + "nomsan", # Pulls in precompiled NVIDIA libraries which cause false positives in msan. + "requires-gpu-nvidia", + ], + deps = if_cuda_is_configured([ + ":nvptx_compiler_impl", + ]) + if_rocm_is_configured([ + ":amdgpu_compiler_impl", + ]) + [ + ":gpu_transfer_manager", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", + "//xla/hlo/ir:hlo_module_group", + "//xla/service:compiler", + "//xla/service:executable", + "//xla/service:gpu_plugin", + "//xla/service:platform_util", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/stream_executor", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "metrics", - srcs = ["metrics.cc"], - hdrs = ["metrics.h"], + name = "amdgpu_compiler", + srcs = [ + "amdgpu_compiler_registration.cc", + ], + local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), + tags = ["manual"], deps = [ - "@local_tsl//tsl/lib/monitoring:counter", - "@local_tsl//tsl/lib/monitoring:gauge", - "@local_tsl//tsl/lib/monitoring:sampler", + ":amdgpu_compiler_impl", + "//xla/service:compiler", + "//xla/stream_executor/rocm:rocm_platform_id", ], + alwayslink = True, # Contains compiler registration ) cc_library( - name = "dot_operand_converter", - srcs = ["dot_operand_converter.cc"], - hdrs = ["dot_operand_converter.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:op_expander_pass", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", + name = "amdgpu_compiler_impl", + srcs = [ + "amdgpu_compiler.cc", ], -) - -xla_test( - name = "dot_operand_converter_test", - srcs = if_gpu_is_configured(["dot_operand_converter_test.cc"]), - backends = [ - "gpu_a100", - "gpu_p100", - "gpu_v100", - "gpu_amd_any", + hdrs = [ + "amdgpu_compiler.h", ], - deps = if_gpu_is_configured( - [ - ":dot_operand_converter", - "@com_google_googletest//:gtest", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_matchers", - "//xla/service:pattern_matcher", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/platform:statusor", - ], - ["@local_tsl//tsl/platform:test_main"], # b/317293391 - ) + ["//xla:xla_data_proto_cc"], -) - -cc_library( - name = "make_batch_pointers", - srcs = if_gpu_is_configured(["make_batch_pointers.cc"]), - hdrs = if_gpu_is_configured(["make_batch_pointers.h"]), + tags = ["manual"], deps = [ - "//xla:types", + ":cublas_padding_requirements", + ":gpu_compiler", + ":target_constants", "//xla:util", - "//xla/stream_executor", - "//xla/stream_executor:device_memory", - "//xla/stream_executor:typed_kernel_factory", - "//xla/stream_executor/gpu:gpu_stream_header", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:algebraic_simplifier", + "//xla/service:call_inliner", + "//xla/service:convert_mover", + "//xla/service:dot_dimension_merger", + "//xla/service:float_normalization", + "//xla/service:float_support", + "//xla/service:hlo_constant_folding", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:hlo_pass_pipeline", + "//xla/service:hlo_verifier", + "//xla/service:reshape_mover", + "//xla/service:tuple_simplifier", + "//xla/service/gpu/autotuning:autotuner_util", + "//xla/service/gpu/autotuning:conv_algorithm_picker", + "//xla/service/gpu/autotuning:gemm_algorithm_picker", + "//xla/service/gpu/llvm_gpu_backend", + "//xla/service/gpu/transforms:algebraic_simplifier", + "//xla/service/gpu/transforms:conv_padding_legalization", + "//xla/service/gpu/transforms:conv_rewriter", + "//xla/service/gpu/transforms:cublas_pad_for_gemms", + "//xla/service/gpu/transforms:cudnn_fused_conv_rewriter", + "//xla/service/gpu/transforms:gpusolver_rewriter", + "//xla/service/gpu/transforms:sort_rewriter", + "//xla/service/gpu/transforms:triangular_solve_rewriter", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:dnn", + "//xla/stream_executor/rocm:rocm_platform_id", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ] + if_cuda_is_configured([ - ":make_batch_pointers_kernel", - ]) + if_rocm_is_configured([ - "//xla/stream_executor/rocm:rocm_helpers", + "@llvm-project//llvm:ir_headers", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ] + if_rocm_is_configured([ + # keep sorted + "@local_config_rocm//rocm:rocm_headers", ]), ) -cuda_library( - name = "make_batch_pointers_kernel", - srcs = if_cuda_is_configured(["make_batch_pointers.cu.cc"]), +cc_library( + name = "xfeed_queue", + hdrs = ["xfeed_queue.h"], deps = [ - "@local_config_cuda//cuda:cuda_headers", # build_cleaner: keep + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/synchronization", + "@local_tsl//tsl/platform:logging", ], ) cc_library( - name = "triangular_solve_rewriter", - srcs = ["triangular_solve_rewriter.cc"], - hdrs = ["triangular_solve_rewriter.h"], + name = "io_feed_manager", + srcs = [ + "infeed_manager.cc", + "outfeed_manager.cc", + "xla_executor_state.h", + ], + hdrs = [ + "infeed_manager.h", + "outfeed_manager.h", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ - ":cublas_cudnn", + ":xfeed_queue", + "//xla:literal", + "//xla:shape_tree", "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", + "//xla:util", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/gpu:gpu_executor_header", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:notification", "@local_tsl//tsl/platform:statusor", ], ) -tsl_gpu_library( - name = "runtime_intrinsics", - srcs = ["runtime_intrinsics.cc"], - hdrs = ["runtime_intrinsics.h"], +cc_library( + name = "gpu_hlo_schedule", + srcs = ["gpu_hlo_schedule.cc"], + hdrs = ["gpu_hlo_schedule.h"], deps = [ + ":backend_configs_cc", + ":gpu_latency_hiding_scheduler", "//xla:shape_util", "//xla:util", - "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:buffer_value", "//xla/service:collective_ops_utils", - "//xla/service:custom_call_status", - "//xla/service:custom_call_target_registry", - "//xla/service:platform_util", - "//xla/stream_executor", + "//xla/service:hlo_memory_scheduler", + "//xla/service:hlo_pass_pipeline", + "//xla/service:latency_hiding_scheduler", + "//xla/service:p2p_schedule_preparation", + "//xla/service:profile_guided_latency_estimator", + "//xla/service/gpu/model:analytical_latency_estimator", + "//xla/service/gpu/transforms:pgle_accuracy_checker", + "//xla/service/gpu/transforms:schedule_postprocessing", + "//xla/service/gpu/transforms:scheduling_instruction_annotator", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:traceme", + "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", ], - alwayslink = 1, ) xla_test( - name = "runtime_intrinsics_test", - srcs = ["runtime_intrinsics_test.cc"], + name = "gpu_hlo_schedule_test", + srcs = [ + "gpu_hlo_schedule_test.cc", + ], backends = ["gpu"], deps = [ - ":runtime_intrinsics", + ":gpu_hlo_schedule", + "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:backend", + "//xla/service:hlo_module_config", + "//xla/service:hlo_ordering", + "//xla/stream_executor:device_description", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", + "//xla/tests:test_utils", + "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", ], ) cc_library( - name = "hlo_fusion_stats", - srcs = ["hlo_fusion_stats.cc"], - hdrs = ["hlo_fusion_stats.h"], + name = "gpu_p2p_pipeliner", + srcs = ["gpu_p2p_pipeliner.cc"], + hdrs = ["gpu_p2p_pipeliner.h"], deps = [ + "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:collective_pipeliner", + "//xla/service:hlo_parser", + "//xla/service:hlo_pass_pipeline", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", ], ) xla_cc_test( - name = "hlo_fusion_stats_test", - srcs = ["hlo_fusion_stats_test.cc"], - tags = [ - "nomsan", + name = "gpu_p2p_pipeliner_test", + srcs = [ + "gpu_p2p_pipeliner_test.cc", ], deps = [ - ":hlo_fusion_stats", + ":gpu_p2p_pipeliner", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", "//xla/service:hlo_parser", + "//xla/service:hlo_pass_pipeline", + "//xla/service:hlo_verifier", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", ], ) cc_library( - name = "scatter_slice_simplifier", - srcs = ["scatter_slice_simplifier.cc"], - hdrs = ["scatter_slice_simplifier.h"], + name = "gpu_spmd_pipeline", + srcs = ["gpu_spmd_pipeline.cc"], + hdrs = ["gpu_spmd_pipeline.h"], deps = [ - "//xla:shape_util", - "//xla:util", + ":runtime_intrinsics", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", + "//xla/hlo/transforms:hlo_constant_splitter", + "//xla/service:algebraic_simplifier", + "//xla/service:conditional_simplifier", + "//xla/service:gather_expander", + "//xla/service:hlo_constant_folding", + "//xla/service:hlo_dce", + "//xla/service:hlo_module_config", "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "//xla/service:hlo_pass_pipeline", + "//xla/service:reshape_mover", + "//xla/service:scatter_expander", + "//xla/service:sharding_propagation", + "//xla/service:sort_simplifier", + "//xla/service:tuple_simplifier", + "//xla/service:while_loop_constant_sinking", + "//xla/service:while_loop_simplifier", + "//xla/service/gpu/transforms:algebraic_simplifier", + "//xla/service/spmd:collective_permute_motion", + "//xla/service/spmd:stateful_rng_spmd_partitioner", + "//xla/service/spmd/shardy:shardy_xla_pass", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/log:check", ], ) xla_cc_test( - name = "scatter_slice_simplifier_test", - srcs = ["scatter_slice_simplifier_test.cc"], - deps = [ - ":scatter_slice_simplifier", - "//xla:shape_util", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest_main", + name = "gpu_spmd_pipeline_test", + srcs = [ + "gpu_spmd_pipeline_test.cc", ], -) - -cc_library( - name = "conv_layout_normalization", - srcs = ["conv_layout_normalization.cc"], - hdrs = ["conv_layout_normalization.h"], deps = [ - ":cublas_cudnn", + ":gpu_spmd_pipeline", "//xla:shape_util", - "//xla:status_macros", "//xla:util", + "//xla/client:executable_build_options", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:protobuf", + "//xla/service:algebraic_simplifier", + "//xla/service:hlo_module_config", + "//xla/service:hlo_parser", + "//xla/service:hlo_pass_pipeline", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) -cc_library( - name = "topk_specializer", - srcs = ["topk_specializer.cc"], - hdrs = ["topk_specializer.h"], +xla_cc_test( + name = "while_transformer_test", + srcs = ["while_transformer_test.cc"], + tags = [ + "nomsan", + ], deps = [ + "//xla:comparison_util", + "//xla:literal_util", "//xla:shape_util", - "//xla:status_macros", - "//xla:util", + "//xla:test", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:hlo_proto_cc", - "//xla/service:tuple_util", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", + "//xla/service:while_loop_analysis", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", ], ) +cuda_library( + name = "stream_executor_util_kernel", + srcs = ["stream_executor_util_kernel.cu.cc"], + tags = ["no_rocm"], + deps = ["@local_config_cuda//cuda:cuda_headers"], +) + cc_library( - name = "topk_splitter", - srcs = ["topk_splitter.cc"], - hdrs = ["topk_splitter.h"], + name = "stream_executor_util", + srcs = ["stream_executor_util.cc"], + hdrs = ["stream_executor_util.h"], + copts = tsl_copts(), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ + ":cublas_cudnn", + ":launch_dimensions", + "//xla:autotuning_proto_cc", "//xla:shape_util", + "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/service:hlo_module_config", + "//xla/stream_executor", + "//xla/stream_executor:data_type", + "//xla/stream_executor:dnn", + "//xla/stream_executor:kernel_spec", + "//xla/stream_executor:launch_dim", + "//xla/stream_executor:typed_kernel_factory", + "//xla/tsl/util:env_var", + "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", - "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", "@com_google_absl//absl/types:span", + "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", - ], + ] + if_cuda_is_configured([ + ":stream_executor_util_kernel", + ]), ) xla_cc_test( - name = "topk_splitter_test", - srcs = ["topk_splitter_test.cc"], + name = "stream_executor_util_test", + srcs = ["stream_executor_util_test.cc"], deps = [ - ":topk_splitter", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_dce", - "//xla/service:pattern_matcher", - "//xla/service:topk_rewriter", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", + ":stream_executor_util", + "//xla:autotuning_proto_cc", + "//xla/service:hlo_module_config", + "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/time", + "@com_google_googletest//:gtest_main", ], ) -xla_test( - name = "topk_test", - srcs = ["topk_test.cc"], - backends = ["gpu"], +cc_library( + name = "gpu_asm_opts_util", + srcs = ["gpu_asm_opts_util.cc"], + hdrs = ["gpu_asm_opts_util.h"], + compatible_with = get_compatible_with_portable(), + copts = tsl_copts(), deps = [ - ":topk_specializer", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service:platform_util", - "//xla/service:topk_rewriter", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", + "//xla:xla_proto_cc", + "//xla/stream_executor/gpu:gpu_asm_opts", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", ], ) cc_library( - name = "copy_fusion", - srcs = ["copy_fusion.cc"], - hdrs = ["copy_fusion.h"], + name = "hlo_fusion_analysis", + srcs = ["hlo_fusion_analysis.cc"], + hdrs = ["hlo_fusion_analysis.h"], + compatible_with = get_compatible_with_portable(), deps = [ - ":gpu_fusible", + ":backend_configs_cc", ":hlo_traversal", ":ir_emission_utils", ":reduction_utils", + "//xla:shape_util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - ], -) - -cc_library( - name = "algorithm_checker", - srcs = ["algorithm_checker.cc"], - hdrs = ["algorithm_checker.h"], - deps = [ - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:algorithm_util", - "//xla/service:hlo_pass", "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", - "@com_google_absl//absl/status", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", ], ) -xla_test( - name = "dot_algorithm_support_test", - srcs = if_gpu_is_configured(["dot_algorithm_support_test.cc"]), - backends = [ - "gpu_v100", - "gpu_a100", - "gpu_amd_any", - ], - tags = [ - "nomac", - ], +xla_cc_test( + name = "hlo_fusion_analysis_test", + srcs = ["hlo_fusion_analysis_test.cc"], deps = [ - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", + ":backend_configs_cc", + ":gpu_device_info_for_tests", + ":hlo_fusion_analysis", + ":hlo_traversal", + ":ir_emission_utils", + "//xla:protobuf_util", "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", + "//xla/tests:xla_internal_test_main", "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "kernel_reuse_cache", - srcs = ["kernel_reuse_cache.cc"], - hdrs = ["kernel_reuse_cache.h"], - deps = [ - ":executable_proto_cc", - ":kernel_arguments", + name = "buffer_comparator", + srcs = if_gpu_is_configured(["buffer_comparator.cc"]), + hdrs = if_gpu_is_configured(["buffer_comparator.h"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = if_gpu_is_configured([ + # keep sorted + ":buffer_comparator_kernel", + ":gpu_asm_opts_util", ":launch_dimensions", + "//xla:shape_util", "//xla:status_macros", "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/stream_executor:launch_dim", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", + "//xla/service:hlo_module_config", + "//xla/stream_executor", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:asm_compiler", + "@com_google_absl//absl/base", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:env", + "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - ], + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", + ]) + if_rocm_is_configured([ + # keep sorted + "@local_config_rocm//rocm:rocm_headers", + ]), ) -xla_cc_test( - name = "kernel_reuse_cache_test", - srcs = ["kernel_reuse_cache_test.cc"], - deps = [ - ":executable_proto_cc", - ":kernel_reuse_cache", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/log:check", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:test", - ], +gpu_kernel_library( + name = "buffer_comparator_kernel", + srcs = if_gpu_is_configured(["buffer_comparator.cu.cc"]), + copts = rocm_copts(), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), ) -cc_library( - name = "kernel_arguments", - srcs = ["kernel_arguments.cc"], - hdrs = ["kernel_arguments.h"], +xla_test( + name = "buffer_comparator_test", + srcs = if_gpu_is_configured(["buffer_comparator_test.cc"]), + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), deps = [ - ":gpu_constants", + ":stream_executor_util", "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], + "//xla:types", + "//xla/service:hlo_module_config", + "//xla/stream_executor", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:device_memory_handle", + "//xla/stream_executor:platform_manager", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ] + if_gpu_is_configured([ + ":buffer_comparator", + "//xla/stream_executor:device_memory", + ]), ) cc_library( - name = "hlo_traversal", - srcs = ["hlo_traversal.cc"], - hdrs = ["hlo_traversal.h"], - compatible_with = get_compatible_with_portable(), + name = "buffer_sharing", + srcs = ["buffer_sharing.cc"], + hdrs = ["buffer_sharing.h"], deps = [ + ":backend_configs_cc", + ":cublas_cudnn", + ":hlo_fusion_analysis", + ":ir_emission_utils", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", ], ) -xla_cc_test( - name = "hlo_traversal_test", - srcs = ["hlo_traversal_test.cc"], +cc_library( + name = "gpu_fusible", + srcs = ["gpu_fusible.cc"], + hdrs = ["gpu_fusible.h"], + compatible_with = get_compatible_with_portable(), deps = [ + ":backend_configs_cc", + ":hlo_fusion_analysis", ":hlo_traversal", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:hlo_test_base", + ":ir_emission_utils", + ":launch_dimensions", + ":reduction_utils", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_dataflow_analysis", + "//xla/service:instruction_fusion", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/synchronization", ], ) -cc_library( - name = "fusion_wrapper", - srcs = ["fusion_wrapper.cc"], - hdrs = ["fusion_wrapper.h"], +xla_cc_test( + name = "gpu_fusible_test", + srcs = ["gpu_fusible_test.cc"], + tags = [ + "nomsan", + ], deps = [ ":gpu_fusible", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + "//xla/service:hlo_parser", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", ], ) -xla_cc_test( - name = "fusion_wrapper_test", - srcs = ["fusion_wrapper_test.cc"], +xla_test( + name = "float_support_test", + srcs = ["float_support_test.cc"], + backend_tags = {"gpu": [ + "requires-gpu-sm80", + "no_rocm" + ]}, + backends = [ + "gpu", + ], deps = [ - ":fusion_wrapper", + ":variant_visitor", + "//xla:error_spec", + "//xla:xla_proto_cc", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "@com_google_googletest//:gtest_main", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", ], ) -xla_cc_test( - name = "copy_fusion_test", - srcs = ["copy_fusion_test.cc"], +xla_test( + name = "conv_layout_normalization_test", + srcs = ["conv_layout_normalization_test.cc"], + backends = ["gpu"], deps = [ - ":copy_fusion", + "//xla:error_spec", "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu/tests:gpu_codegen_test", # fixdeps: keep "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", + "//xla/tests:test_macros_header", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) -xla_cc_test( - name = "autotuner_util_test", - srcs = if_cuda_is_configured(["autotuner_util_test.cc"]), - data = [ - "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb", - "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb", - "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", - ], - deps = if_cuda_is_configured([ - # keep sorted - ":autotuner_util", - "//xla:autotune_results_proto_cc", +cc_library( + name = "hlo_algorithm_denylist", + srcs = ["hlo_algorithm_denylist.cc"], + hdrs = ["hlo_algorithm_denylist.h"], + deps = [ + ":backend_configs_cc", "//xla:autotuning_proto_cc", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_description_proto_cc", - "//xla/stream_executor:platform", - "//xla/stream_executor:platform_manager", - "//xla/stream_executor/host:host_platform", - "//xla/tests:hlo_test_base", - "//xla/tests:verified_hlo_module", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/container:flat_hash_set", + "//xla:debug_options_flags", + "//xla/hlo/ir:backend_config", + "//xla/service/gpu/autotuning:gpu_autotuning_proto_cc", + "//xla/stream_executor:dnn", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/log:scoped_mock_log", - "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ]) + [ - "//xla/tests:xla_internal_test_main", # Keep outside GPU guard ], ) -cc_library( - name = "double_buffer_loop_unrolling", - srcs = ["double_buffer_loop_unrolling.cc"], - hdrs = ["double_buffer_loop_unrolling.h"], +xla_cc_test( + name = "hlo_algorithm_denylist_test", + srcs = ["hlo_algorithm_denylist_test.cc"], + data = ["data/hlo_algorithm_denylist.pbtxt"], deps = [ - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/ir:hlo_instruction_utils", - "//xla/hlo/utils:hlo_query", - "//xla/service:collective_ops_utils", - "//xla/service:flatten_call_graph", - "//xla/service:hlo_pass", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", + ":hlo_algorithm_denylist", + "//xla/stream_executor:dnn", + "//xla/tests:test_utils", "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", ], ) xla_cc_test( - name = "double_buffer_loop_unrolling_test", - srcs = ["double_buffer_loop_unrolling_test.cc"], + name = "gpu_float_support_test", + srcs = ["gpu_float_support_test.cc"], deps = [ - ":double_buffer_loop_unrolling", - "//xla:test", + ":backend_configs_cc", + ":gpu_float_support", + ":ir_emission_utils", + "//xla:shape_util", + "//xla:test_helpers", "//xla:xla_data_proto_cc", - "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:tuple_simplifier", - "//xla/tests:filecheck", + "//xla/service:float_normalization", + "//xla/service:hlo_verifier", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/container:flat_hash_set", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_test( - name = "determinism_test", - srcs = if_gpu_is_configured(["determinism_test.cc"]), - backends = [ - "gpu_a100", - "gpu_amd_any", - ], - tags = [ - "no_rocm", #TODO(rocm): TEMP, sync 24-06-24 + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = if_gpu_is_configured( - [ - ":autotuner_util", - "@com_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "//xla:literal", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service/gpu/tests:gpu_codegen_test", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:literal_test_util", - "//xla/stream_executor/gpu:gpu_timer", - "//xla/tests:test_utils", - "@local_tsl//tsl/platform:statusor", - ], - ["@local_tsl//tsl/platform:test_main"], # b/317293391 - ), ) cc_library( - name = "gpu_symbol_repository", - hdrs = ["gpu_symbol_repository.h"], + name = "metrics", + srcs = ["metrics.cc"], + hdrs = ["metrics.h"], deps = [ - "//xla:autotune_results_proto_cc", - "//xla:xla_proto_cc", - "//xla/service:symbol_repository", + "@local_tsl//tsl/lib/monitoring:counter", + "@local_tsl//tsl/lib/monitoring:gauge", + "@local_tsl//tsl/lib/monitoring:sampler", ], ) cc_library( - name = "collective_permute_cycle_decomposer", - srcs = ["collective_permute_cycle_decomposer.cc"], - hdrs = ["collective_permute_cycle_decomposer.h"], + name = "make_batch_pointers", + srcs = if_gpu_is_configured(["make_batch_pointers.cc"]), + hdrs = if_gpu_is_configured(["make_batch_pointers.h"]), deps = [ - ":backend_configs_cc", - "//xla:comparison_util", - "//xla:literal_util", - "//xla:shape_util", + "//xla:types", "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_parser", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", + "//xla/stream_executor", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/gpu:gpu_stream_header", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", - ], + "@local_tsl//tsl/platform:statusor", + ] + if_cuda_is_configured([ + ":make_batch_pointers_kernel", + ]) + if_rocm_is_configured([ + "//xla/stream_executor/rocm:rocm_helpers", + ]), ) -xla_cc_test( - name = "collective_permute_cycle_decomposer_test", - srcs = ["collective_permute_cycle_decomposer_test.cc"], +cuda_library( + name = "make_batch_pointers_kernel", + srcs = if_cuda_is_configured(["make_batch_pointers.cu.cc"]), deps = [ - ":collective_permute_cycle_decomposer", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_parser", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", + "@local_config_cuda//cuda:cuda_headers", # build_cleaner: keep ], ) -cc_library( - name = "collective_permute_valid_iteration_annotator", - srcs = ["collective_permute_valid_iteration_annotator.cc"], - hdrs = ["collective_permute_valid_iteration_annotator.h"], +tsl_gpu_library( + name = "runtime_intrinsics", + srcs = ["runtime_intrinsics.cc"], + hdrs = ["runtime_intrinsics.h"], deps = [ - "//xla:literal_util", - "//xla/hlo/ir:hlo", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", "//xla/service:collective_ops_utils", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "//xla/service:while_loop_analysis", + "//xla/service:custom_call_status", + "//xla/service:custom_call_target_registry", + "//xla/service:platform_util", + "//xla/stream_executor", + "//xla/stream_executor:stream_finder", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], + alwayslink = 1, ) -xla_cc_test( - name = "collective_permute_valid_iteration_annotator_test", - srcs = ["collective_permute_valid_iteration_annotator_test.cc"], +xla_test( + name = "runtime_intrinsics_test", + srcs = ["runtime_intrinsics_test.cc"], + backends = ["gpu"], deps = [ - ":collective_permute_valid_iteration_annotator", + ":runtime_intrinsics", "//xla/hlo/ir:hlo", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_pass_pipeline", - "//xla/service:while_loop_trip_count_annotator", "//xla/tests:hlo_test_base", - "@local_tsl//tsl/platform:test_main", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "stream_attribute_annotator", - srcs = ["stream_attribute_annotator.cc"], - hdrs = ["stream_attribute_annotator.h"], + name = "hlo_fusion_stats", + srcs = ["hlo_fusion_stats.cc"], + hdrs = ["hlo_fusion_stats.h"], deps = [ - ":backend_configs_cc", - ":gpu_fusible", - "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_pass", - "//xla/service/gpu/runtime:thunk", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "stream_attribute_annotator_test", - srcs = ["stream_attribute_annotator_test.cc"], + name = "hlo_fusion_stats_test", + srcs = ["hlo_fusion_stats_test.cc"], + tags = [ + "nomsan", + ], deps = [ - ":backend_configs_cc", - ":stream_attribute_annotator", - "//xla/hlo/ir:hlo", + ":hlo_fusion_stats", + "//xla/service:hlo_parser", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/strings:string_view", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", ], ) cc_library( - name = "stream_attribute_async_wrapper", - srcs = ["stream_attribute_async_wrapper.cc"], - hdrs = ["stream_attribute_async_wrapper.h"], + name = "conv_layout_normalization", + srcs = ["conv_layout_normalization.cc"], + hdrs = ["conv_layout_normalization.h"], deps = [ - ":backend_configs_cc", + ":cublas_cudnn", + "//xla:shape_util", + "//xla:status_macros", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "//xla/service/gpu/runtime:thunk", - "@com_google_absl//absl/container:flat_hash_set", + "//xla/service:hlo_creation_utils", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:statusor", ], ) -xla_cc_test( - name = "stream_attribute_async_wrapper_test", - srcs = ["stream_attribute_async_wrapper_test.cc"], +xla_test( + name = "dot_algorithm_support_test", + srcs = if_gpu_is_configured(["dot_algorithm_support_test.cc"]), + backends = [ + "gpu_v100", + "gpu_a100", + "gpu_h100", + "gpu_amd_any", + ], + tags = [ + "nomac", + ], deps = [ - ":backend_configs_cc", - ":stream_attribute_async_wrapper", + "//xla:shape_util", + "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest", ], ) cc_library( - name = "gpu_windowed_einsum_handler", - srcs = ["gpu_windowed_einsum_handler.cc"], - hdrs = ["gpu_windowed_einsum_handler.h"], + name = "kernel_reuse_cache", + srcs = ["kernel_reuse_cache.cc"], + hdrs = ["kernel_reuse_cache.h"], deps = [ - ":backend_configs_cc", - "//xla:literal_util", + ":executable_proto_cc", + ":kernel_arguments", + ":launch_dimensions", + "//xla:status_macros", "//xla:util", - "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:hlo_creation_utils", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "//xla/service:shape_inference", - "@com_google_absl//absl/algorithm:container", + "//xla/stream_executor:launch_dim", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "gpu_windowed_einsum_handler_test", - srcs = ["gpu_windowed_einsum_handler_test.cc"], + name = "kernel_reuse_cache_test", + srcs = ["kernel_reuse_cache_test.cc"], deps = [ - ":backend_configs_cc", - ":gpu_windowed_einsum_handler", - "//xla/hlo/ir:hlo", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", + ":executable_proto_cc", + ":kernel_reuse_cache", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:test", ], ) cc_library( - name = "triton_fusion_numerics_verifier", - srcs = if_gpu_is_configured(["triton_fusion_numerics_verifier.cc"]), - hdrs = if_gpu_is_configured(["triton_fusion_numerics_verifier.h"]), - deps = if_gpu_is_configured([ - ":autotuner_compile_util", - ":autotuner_util", - ":backend_configs_cc", - ":buffer_comparator", - ":ir_emission_utils", + name = "kernel_arguments", + srcs = ["kernel_arguments.cc"], + hdrs = ["kernel_arguments.h"], + deps = [ + ":gpu_constants", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:executable", - "//xla/service:hlo_pass", - "//xla/service:shaped_buffer", - "//xla/service:hlo_module_config", - "//xla/stream_executor:stream", - "//xla/tools:hlo_decomposer_lib", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", - ]), -) - -xla_test( - name = "triton_fusion_numerics_verifier_test", - srcs = if_gpu_is_configured(["triton_fusion_numerics_verifier_test.cc"]), - backend_tags = {"gpu": [ - "requires-gpu-sm80", - ]}, - tags = ["no_rocm"], - backends = ["gpu"], - deps = [ - ":autotuner_compile_util", - ":autotuner_util", - ":triton_fusion_numerics_verifier", - "//xla:shape_util", - "//xla:test_helpers", - "//xla/hlo/ir:hlo", - "//xla/service:platform_util", - "//xla/stream_executor:platform", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", ], ) cc_library( - name = "pipelined_p2p_rewriter", - srcs = ["pipelined_p2p_rewriter.cc"], - hdrs = ["pipelined_p2p_rewriter.h"], + name = "hlo_traversal", + srcs = ["hlo_traversal.cc"], + hdrs = ["hlo_traversal.h"], + compatible_with = get_compatible_with_portable(), deps = [ "//xla:shape_util", - "//xla:util", "//xla/hlo/ir:hlo", - "//xla/hlo/utils:hlo_query", - "//xla/service:collective_ops_utils", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", ], ) xla_cc_test( - name = "pipelined_p2p_rewriter_test", - srcs = ["pipelined_p2p_rewriter_test.cc"], + name = "hlo_traversal_test", + srcs = ["hlo_traversal_test.cc"], deps = [ - ":pipelined_p2p_rewriter", + ":hlo_traversal", "//xla/hlo/ir:hlo", - "//xla/tests:filecheck", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", + "@com_google_googletest//:gtest_main", + ], +) + +xla_test( + name = "determinism_test", + srcs = if_gpu_is_configured(["determinism_test.cc"]), + backends = [ + "gpu_a100", + "gpu_amd_any", + ], + tags = [ + "no_rocm", #TODO(rocm): TEMP, weekly sync + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = if_gpu_is_configured( + [ + "//xla/service/gpu/autotuning:autotuner_util", + "@com_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "//xla:literal", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", + "//xla/stream_executor/gpu:gpu_timer", + "//xla/tests:test_utils", + "@local_tsl//tsl/platform:statusor", + ], + ["@local_tsl//tsl/platform:test_main"], # b/317293391 + ), +) + +cc_library( + name = "gpu_symbol_repository", + hdrs = ["gpu_symbol_repository.h"], + deps = [ + "//xla:autotune_results_proto_cc", + "//xla:xla_proto_cc", + "//xla/service:symbol_repository", ], ) @@ -6142,41 +2958,12 @@ xla_cc_test( "//xla/service:profile_guided_latency_estimator", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], ) - -cc_library( - name = "scheduling_instruction_annotator", - srcs = ["scheduling_instruction_annotator.cc"], - hdrs = ["scheduling_instruction_annotator.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service:hlo_pass", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "scheduling_instruction_annotator_test", - srcs = ["scheduling_instruction_annotator_test.cc"], - deps = [ - ":scheduling_instruction_annotator", - "//xla/hlo/ir:hlo", - "//xla/tests:filecheck", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc index 04483ce86c78c9..ae541ba167f582 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.cc @@ -32,21 +32,21 @@ limitations under the License. #include "xla/service/dot_dimension_merger.h" #include "xla/service/float_normalization.h" #include "xla/service/float_support.h" -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/conv_algorithm_picker.h" -#include "xla/service/gpu/cublas_pad_for_gemms.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" +#include "xla/service/gpu/autotuning/conv_algorithm_picker.h" +#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h" #include "xla/service/gpu/cublas_padding_requirements.h" -#include "xla/service/gpu/cudnn_fused_conv_rewriter.h" -#include "xla/service/gpu/cusolver_rewriter.h" -#include "xla/service/gpu/gemm_algorithm_picker.h" -#include "xla/service/gpu/gpu_algebraic_simplifier.h" #include "xla/service/gpu/gpu_compiler.h" -#include "xla/service/gpu/gpu_conv_padding_legalization.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" -#include "xla/service/gpu/gpu_sort_rewriter.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/target_constants.h" -#include "xla/service/gpu/triangular_solve_rewriter.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" +#include "xla/service/gpu/transforms/conv_padding_legalization.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" +#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h" +#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h" +#include "xla/service/gpu/transforms/gpusolver_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" +#include "xla/service/gpu/transforms/triangular_solve_rewriter.h" #include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_fix.h" @@ -123,8 +123,8 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(&conv_bf16_support); pipeline.AddPass(); - pipeline.AddPass(gpu_version); - pipeline.AddPass(); + pipeline.AddPass(gpu_version); + pipeline.AddPass(); auto rcc = std::get(gpu_version); pipeline.AddPass(rcc, dnn_version, GetToolkitVersion()); @@ -135,7 +135,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(); pipeline.AddPass(); - // tf2xla bridge, DepthwiseConvolutionConverter and GpuConvRewriter + // tf2xla bridge, DepthwiseConvolutionConverter and ConvRewriter // introduces reshapes and transposes that can be eliminated using // AlgebraicSimplifier We run algsimp to a fixed point. AlgebraicSimplifierOptions options = @@ -144,7 +144,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( options.set_enable_unconditional_reduce_of_concat_replacement(false); pipeline.AddPass>(options, gpu_version); - // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and + // tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover // to a fixed point. Include algsimp because ReshapeMover relies on it. [&, &pipeline = pipeline.AddPass>( @@ -166,7 +166,7 @@ absl::Status AMDGPUCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(options, gpu_version); }(); - // GpuConvRewriter, GpuConvPaddingLegalization and + // ConvRewriter, ConvPaddingLegalization and // CudnnConvPadForTensorCores may add instructions which can be simplified // by constant folding. pipeline.AddPass(); @@ -240,7 +240,7 @@ absl::Status AMDGPUCompiler::AddConvAndGemmAutotuningPasses( absl::Status AMDGPUCompiler::AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) { if (debug_options.xla_gpu_enable_cub_radix_sort()) { - pipeline->AddPass(); + pipeline->AddPass(); } return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/gpu/amdgpu_compiler.h b/third_party/xla/xla/service/gpu/amdgpu_compiler.h index 483647bbdfdadd..062a0ef6ca8363 100644 --- a/third_party/xla/xla/service/gpu/amdgpu_compiler.h +++ b/third_party/xla/xla/service/gpu/amdgpu_compiler.h @@ -20,7 +20,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "llvm/IR/Module.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_pipeline.h" diff --git a/third_party/xla/xla/service/gpu/autotuning/BUILD b/third_party/xla/xla/service/gpu/autotuning/BUILD new file mode 100644 index 00000000000000..aa82b8678bdb1d --- /dev/null +++ b/third_party/xla/xla/service/gpu/autotuning/BUILD @@ -0,0 +1,542 @@ +# Description: +# Components that implement GPU autotuning. + +load( + "@local_tsl//tsl/platform:build_config.bzl", + "tf_proto_library", +) +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load("//xla:xla.bzl", "xla_cc_test") +load("//xla/tests:build_defs.bzl", "xla_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +cc_library( + name = "gemm_fusion_autotuner", + srcs = ["gemm_fusion_autotuner.cc"], + hdrs = ["gemm_fusion_autotuner.h"], + tags = [ + "gpu", + "no_rocm", + ], + deps = [ + ":autotuner_compile_util", + ":autotuner_util", + "//xla:autotuning_proto_cc", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:algorithm_util", + "//xla/service:dump", + "//xla/service:executable", + "//xla/service:float_normalization", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:shaped_buffer", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:buffer_comparator", + "//xla/service/gpu:gpu_float_support", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:split_k_gemm_rewriter", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/transforms:cudnn_fusion_compiler", + "//xla/service/gpu/transforms:fusion_wrapper", + "//xla/service/gpu/transforms:gemm_rewriter", + "//xla/service/gpu/transforms:instruction_fusion", + "//xla/service/gpu/transforms:priority_fusion", + "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor/gpu:redzone_allocator", + "//xla/tools:hlo_decomposer_lib", + "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/lib/core:bits", + "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + ], +) + +xla_test( + name = "gemm_fusion_autotuner_test", + timeout = "long", + srcs = ["gemm_fusion_autotuner_test.cc"], + backend_tags = {"gpu": [ + "requires-gpu-sm80", + ]}, + backends = [ + "gpu", + ], + tags = [ + "no_rocm", + "nomac", + ], + deps = [ + ":autotuner_compile_util", + ":autotuner_util", + ":gemm_fusion_autotuner", + "//xla:autotuning_proto_cc", + "//xla:error_spec", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/pjrt/distributed:key_value_store_interface", + "//xla/service:call_inliner", + "//xla/service:dump", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass_pipeline", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/transforms:gemm_fusion", + "//xla/service/gpu/transforms:gemm_rewriter", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:test_utils", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tools:hlo_decomposer_lib", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest", + "@local_config_cuda//cuda:cuda_headers", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "gemm_algorithm_picker", + srcs = ["gemm_algorithm_picker.cc"], + hdrs = ["gemm_algorithm_picker.h"], + tags = ["gpu"], + deps = [ + ":autotuner_compile_util", + ":autotuner_util", + "//xla:autotune_results_proto_cc", + "//xla:autotuning_proto_cc", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:buffer_comparator", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu:variant_visitor", + "//xla/stream_executor", + "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor/gpu:redzone_allocator", + "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/profiler/lib:scoped_annotation", + ], +) + +cc_library( + name = "autotuner_util", + srcs = ["autotuner_util.cc"], + hdrs = ["autotuner_util.h"], + tags = ["gpu"], + deps = [ + "//xla:autotune_results_proto_cc", + "//xla:autotuning_proto_cc", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:dump", + "//xla/service/gpu:gpu_asm_opts_util", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor/gpu:redzone_allocator", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:base64", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:statusor", + ], +) + +# We need a separate target, as runtime executable cannot depend on compilation +# pipeline. +cc_library( + name = "autotuner_compile_util", + srcs = ["autotuner_compile_util.cc"], + hdrs = ["autotuner_compile_util.h"], + tags = ["gpu"], + deps = [ + ":autotuner_util", + "//xla:executable_run_options", + "//xla:shape_util", + "//xla:util", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:compiler", + "//xla/service:executable", + "//xla/service:maybe_owning_device_memory", + "//xla/service:shaped_buffer", + "//xla/service/gpu:gpu_executable_run_options", + "//xla/service/gpu:ir_emission_utils", + "//xla/stream_executor", + "//xla/stream_executor/gpu:redzone_allocator", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "autotuner_compile_util_test", + srcs = ["autotuner_compile_util_test.cc"], + backends = ["gpu"], + deps = [ + ":autotuner_compile_util", + ":autotuner_util", + "//xla/hlo/ir:hlo", + "//xla/service:platform_util", + "//xla/stream_executor:platform", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "gemm_algorithm_picker_test", + srcs = ["gemm_algorithm_picker_test.cc"], + backends = [ + "gpu_v100", + "gpu_amd_any", + ], + deps = [ + ":autotuner_util", + ":gemm_algorithm_picker", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:platform_util", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:variant_visitor", + "//xla/service/gpu/transforms:gemm_rewriter", + "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/protobuf:dnn_proto_cc", + ], +) + +cc_library( + name = "conv_algorithm_picker", + srcs = ["conv_algorithm_picker.cc"], + hdrs = ["conv_algorithm_picker.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + tags = ["gpu"], + deps = [ + ":autotuner_compile_util", + ":autotuner_util", + ":gpu_autotuning_proto_cc", + "//xla:autotune_results_proto_cc", + "//xla:autotuning_proto_cc", + "//xla:debug_options_flags", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:slow_operation_alarm", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:gpu_asm_opts_util", + "//xla/service/gpu:gpu_conv_runner", + "//xla/service/gpu:hlo_algorithm_denylist", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:dnn", + "//xla/stream_executor:lazy_op_runner", + "//xla/stream_executor:numeric_options", + "//xla/stream_executor:scratch_allocator", + "//xla/stream_executor/cuda:cuda_platform_id", + "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/tsl/util:env_var", + "//xla/tsl/util/proto:proto_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:numbers", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ] + if_cuda_is_configured([ + # keep sorted + "//xla/service/gpu:buffer_comparator", + "//xla/stream_executor/gpu:redzone_allocator", + "@local_config_cuda//cuda:cudnn_header", + ]), +) + +xla_test( + name = "conv_algorithm_picker_test", + srcs = ["conv_algorithm_picker_test.cc"], + backends = [ + "gpu_v100", + "gpu_amd_any", + ], + tags = [ + "noasan", + "nomsan", + ], + deps = [ + ":autotuner_util", + ":conv_algorithm_picker", + "//xla:debug_options_flags", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:platform_util", + "//xla/service:tuple_simplifier", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/transforms:conv_rewriter", + "//xla/stream_executor:device_description", + "//xla/stream_executor:platform", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "custom_kernel_fusion_autotuner", + srcs = ["custom_kernel_fusion_autotuner.cc"], + hdrs = ["custom_kernel_fusion_autotuner.h"], + tags = ["gpu"], + deps = [ + ":autotuner_compile_util", + ":autotuner_util", + "//xla:autotuning_proto_cc", + "//xla:status_macros", + "//xla:util", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "//xla/service:hlo_pass", + "//xla/service:shaped_buffer", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:buffer_comparator", + "//xla/service/gpu:gpu_float_support", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:split_k_gemm_rewriter", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/kernels:custom_kernel", + "//xla/service/gpu/kernels:custom_kernel_fusion", + "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/stream_executor/gpu:redzone_allocator", + "//xla/tools:hlo_decomposer_lib", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/time", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "custom_kernel_fusion_autotuner_test", + srcs = ["custom_kernel_fusion_autotuner_test.cc"], + backends = [ + "gpu", + ], + tags = ["no_rocm"], + deps = [ + ":autotuner_util", + ":custom_kernel_fusion_autotuner", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass_pipeline", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:test", + ], +) + +tf_proto_library( + name = "gpu_autotuning_proto", + srcs = ["gpu_autotuning.proto"], + cc_api_version = 2, + protodeps = [ + "//xla/service/gpu:backend_configs", + "//xla:xla_data_proto", + "//xla/service:hlo_proto", + "//xla:autotuning_proto", + ], +) + +xla_cc_test( + name = "autotuner_util_test", + srcs = ["autotuner_util_test.cc"], + data = [ + "//xla/tools/hlo_opt:gpu_specs/a100_sxm_40.txtpb", + "//xla/tools/hlo_opt:gpu_specs/a100_sxm_80.txtpb", + "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", + ], + tags = [ + "gpu", + "no_rocm", + ], + deps = [ + ":autotuner_util", + "//xla:autotune_results_proto_cc", + "//xla:autotuning_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:dump", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "//xla/stream_executor/host:host_platform", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/log:scoped_mock_log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc similarity index 98% rename from third_party/xla/xla/service/gpu/autotuner_compile_util.cc rename to third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc index b3e880b08ac01f..9922ea247a4e44 100644 --- a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include #include @@ -34,7 +34,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_executable_run_options.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/maybe_owning_device_memory.h" diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h similarity index 96% rename from third_party/xla/xla/service/gpu/autotuner_compile_util.h rename to third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h index 5137fcf95b43a0..02b1bfa301fdbe 100644 --- a/third_party/xla/xla/service/gpu/autotuner_compile_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ -#define XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ +#ifndef XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_ +#define XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_ #include #include @@ -33,7 +33,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/stream_executor/device_memory_allocator.h" @@ -174,4 +174,4 @@ class RedzoneBuffers { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_AUTOTUNER_COMPILE_UTIL_H_ +#endif // XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_COMPILE_UTIL_H_ diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc rename to third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc index 1db5afb8988222..a8b959482ebba0 100644 --- a/third_party/xla/xla/service/gpu/autotuner_compile_util_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_compile_util_test.cc @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" #include #include #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/platform.h" #include "xla/tests/hlo_test_base.h" diff --git a/third_party/xla/xla/service/gpu/autotuner_util.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc similarity index 98% rename from third_party/xla/xla/service/gpu/autotuner_util.cc rename to third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc index 93c946f57b6462..79bb7441ea636e 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include #include @@ -43,6 +43,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/dump.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/shape.h" @@ -130,6 +131,9 @@ absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key, return absl::OkStatus(); } + tsl::Env* default_env = tsl::Env::Default(); + TF_RETURN_IF_ERROR(CreateDirIfNeeded(std::string(cache_dir), default_env)); + TF_ASSIGN_OR_RETURN(const std::string file_path, GetCacheFilePath(cache_dir, key)); @@ -145,7 +149,6 @@ absl::Status AddResultToFileBasedCacheIfEnabled(const AutotuneCacheKey& key, // file. Also avoids reading incomplete files. (This may not work on all file // systems.) std::string temp_file_path = tsl::io::GetTempFilename(".textproto"); - tsl::Env* default_env = tsl::Env::Default(); TF_RETURN_IF_ERROR( tsl::WriteStringToFile(default_env, temp_file_path, result_str)); return default_env->RenameFile(temp_file_path, file_path); @@ -293,6 +296,7 @@ std::string ToCanonicalString(const HloInstruction* instr) { auto options = HloPrintOptions::Canonical(); if (instr->opcode() != HloOpcode::kFusion) { options.set_print_backend_config(true); + options.set_sort_backend_config(true); return instr->ToString(options); } options.set_print_subcomputation_mode( diff --git a/third_party/xla/xla/service/gpu/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h similarity index 98% rename from third_party/xla/xla/service/gpu/autotuner_util.h rename to third_party/xla/xla/service/gpu/autotuning/autotuner_util.h index 4634fc21b44fa4..6d5c32182410b1 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_ -#define XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_ +#ifndef XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_ +#define XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_ #include #include @@ -331,4 +331,4 @@ absl::StatusOr GetBase64EncodedSha256Hash(absl::string_view s); } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_AUTOTUNER_UTIL_H_ +#endif // XLA_SERVICE_GPU_AUTOTUNING_AUTOTUNER_UTIL_H_ diff --git a/third_party/xla/xla/service/gpu/autotuner_util_test.cc b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc similarity index 94% rename from third_party/xla/xla/service/gpu/autotuner_util_test.cc rename to third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc index 37fb56ed67fb83..974f4d4d2816c2 100644 --- a/third_party/xla/xla/service/gpu/autotuner_util_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include #include @@ -33,11 +33,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/dump.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep @@ -110,11 +111,10 @@ results { return str; } - static std::unique_ptr NewStreamExecutor() { + static stream_executor::StreamExecutor* NewStreamExecutor() { stream_executor::Platform* platform = stream_executor::PlatformManager::PlatformWithName("Host").value(); - stream_executor::StreamExecutorConfig config(/*ordinal=*/0); - return platform->GetUncachedExecutor(config).value(); + return platform->ExecutorForDevice(/*ordinal=*/0).value(); } absl::Status PopulateResultCache() { @@ -211,11 +211,10 @@ TEST_F(AutotunerUtilTest, FailIfRequireCompleteAotAutotuning) { ->MakeNonfusionComputations(absl::flat_hash_set()); EXPECT_THAT(computations, Not(IsEmpty())); const HloInstruction* instruction = *computations[0]->instructions().begin(); - std::unique_ptr executor = - NewStreamExecutor(); + stream_executor::StreamExecutor* executor = NewStreamExecutor(); auto options = DebugOptions(); options.set_xla_gpu_require_complete_aot_autotune_results(true); - AutotuneConfig config(DeviceConfig{executor.get()}, options); + AutotuneConfig config(DeviceConfig{executor}, options); EXPECT_THAT( AutotunerUtil::Autotune(instruction, config, [&] { return AutotuneResult(); }), @@ -234,12 +233,11 @@ TEST_F(AutotunerUtilTest, OkIfJitAutotuningDisabledButAlreadyLoadedAOT) { ->MakeNonfusionComputations(absl::flat_hash_set()); EXPECT_THAT(computations, Not(IsEmpty())); const HloInstruction* instruction = *computations[0]->instructions().begin(); - std::unique_ptr executor = - NewStreamExecutor(); + stream_executor::StreamExecutor* executor = NewStreamExecutor(); { // By default, JIT autotuning is OK. - AutotuneConfig config(DeviceConfig{executor.get()}, DebugOptions()); + AutotuneConfig config(DeviceConfig{executor}, DebugOptions()); TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] { return AutotuneResult(); }).status()); @@ -249,7 +247,7 @@ TEST_F(AutotunerUtilTest, OkIfJitAutotuningDisabledButAlreadyLoadedAOT) { auto options = DebugOptions(); options.set_xla_gpu_require_complete_aot_autotune_results(true); - AutotuneConfig config(DeviceConfig{executor.get()}, options); + AutotuneConfig config(DeviceConfig{executor}, options); // Even though JIT autotuning is disabled, there is no cache miss when running // autotuning for the same entry, so no error should be raised either. TF_EXPECT_OK(AutotunerUtil::Autotune(instruction, config, [&] { @@ -280,14 +278,14 @@ class FileBasedCacheTest : public AutotunerUtilTest { return file_content; } - static void Write(const absl::string_view filepath, - const absl::string_view content) { + void Write(const absl::string_view filepath, + const absl::string_view content) { + TF_CHECK_OK(CreateDirIfNeeded(cache_dir_, tsl::Env::Default())); TF_CHECK_OK(tsl::WriteStringToFile(tsl::Env::Default(), std::string(filepath), content)); } - std::unique_ptr executor_ = - NewStreamExecutor(); + stream_executor::StreamExecutor* executor_ = NewStreamExecutor(); std::unique_ptr module_ = ParseAndReturnVerifiedModule(kHloText).value(); const HloInstruction* dot_ = hlo_query::GetFirstInstructionWithOpcode( @@ -296,10 +294,9 @@ class FileBasedCacheTest : public AutotunerUtilTest { tsl::Env* default_env = tsl::Env::Default(); std::string cache_dir; CHECK(default_env->LocalTempFilename(&cache_dir)); - CHECK_OK(default_env->CreateDir(cache_dir)); return cache_dir; }(); - AutotuneConfig config_ = AutotuneConfig(DeviceConfig{executor_.get()}, [&] { + AutotuneConfig config_ = AutotuneConfig(DeviceConfig{executor_}, [&] { DebugOptions options; options.set_xla_gpu_per_fusion_autotune_cache_dir(cache_dir_); return options; diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc similarity index 99% rename from third_party/xla/xla/service/gpu/conv_algorithm_picker.cc rename to third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc index d80b57da40d74e..26774cf9de208c 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/conv_algorithm_picker.h" +#include "xla/service/gpu/autotuning/conv_algorithm_picker.h" #include #include @@ -43,11 +43,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" +#include "xla/service/gpu/autotuning/gpu_autotuning.pb.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gpu_autotuning.pb.h" #include "xla/service/gpu/gpu_conv_runner.h" #include "xla/service/gpu/hlo_algorithm_denylist.h" #include "xla/service/gpu/stream_executor_util.h" @@ -57,6 +57,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/lazy_op_runner.h" diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker.h b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h similarity index 92% rename from third_party/xla/xla/service/gpu/conv_algorithm_picker.h rename to third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h index e6dea8b2f20b0c..173a0c61481e57 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker.h +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ -#define XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ +#ifndef XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_ +#define XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_ #include #include @@ -30,22 +30,16 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/gpu_conv_runner.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_interface.h" -#include "xla/shape.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" -#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) -#include "xla/stream_executor/gpu/redzone_allocator.h" -#endif - namespace xla { namespace gpu { @@ -156,4 +150,4 @@ class GpuConvAlgorithmPicker : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_ +#endif // XLA_SERVICE_GPU_AUTOTUNING_CONV_ALGORITHM_PICKER_H_ diff --git a/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc similarity index 92% rename from third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc rename to third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc index d9a3a691da0565..96520143e0fe4c 100644 --- a/third_party/xla/xla/service/gpu/conv_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/conv_algorithm_picker.h" +#include "xla/service/gpu/autotuning/conv_algorithm_picker.h" #include #include @@ -22,9 +22,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/platform_util.h" @@ -32,7 +32,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -68,7 +68,7 @@ ENTRY main { ->GetDeviceDescription() .gpu_compute_capability(); bool changed = false; - TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(cc), m.get())); + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(ConvRewriter(cc), m.get())); changed = false; DebugOptions opts = DefaultDebugOptionsIgnoringFlags(); @@ -92,7 +92,7 @@ ENTRY main { // should have the new scratch bytes. TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(kHlo)); changed = false; - TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvRewriter(cc), m.get())); + TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(ConvRewriter(cc), m.get())); changed = false; TF_ASSERT_OK_AND_ASSIGN(changed, RunHloPass(GpuConvAlgorithmPicker(cfg), m.get())); diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc similarity index 97% rename from third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc rename to third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc index d5114bc6e6edbc..a920b9919c0aca 100644 --- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/custom_kernel_fusion_autotuner.h" +#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h" #include #include @@ -32,8 +32,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/kernels/custom_kernel_fusion.h" diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h similarity index 85% rename from third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h rename to third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h index f6cd0c0fa5b6d1..07aad07aebd9ba 100644 --- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner.h +++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ -#define XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ +#ifndef XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ +#define XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -21,7 +21,7 @@ limitations under the License. #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/hlo_pass_interface.h" #include "xla/xla.pb.h" @@ -50,4 +50,4 @@ class CustomKernelFusionAutotuner : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ +#endif // XLA_SERVICE_GPU_AUTOTUNING_CUSTOM_KERNEL_FUSION_AUTOTUNER_H_ diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc similarity index 96% rename from third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc rename to third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc index aa6c1d2ffa46c3..8defca998de755 100644 --- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/custom_kernel_fusion_autotuner_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/custom_kernel_fusion_autotuner.h" +#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h" #include #include @@ -21,7 +21,7 @@ limitations under the License. #include #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc rename to third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc index a2de14cb2e4cf7..8a870f88993c1e 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemm_algorithm_picker.h" +#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h" #include #include @@ -34,8 +34,8 @@ limitations under the License. #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/cublas_cudnn.h" diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h similarity index 91% rename from third_party/xla/xla/service/gpu/gemm_algorithm_picker.h rename to third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h index be2686ddc93e86..237358388b16eb 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker.h +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ -#define XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ +#ifndef XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_ +#define XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_ #include #include @@ -26,7 +26,7 @@ limitations under the License. #include "xla/autotune_results.pb.h" #include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_interface.h" #include "xla/shape.h" @@ -67,4 +67,4 @@ class GemmAlgorithmPicker : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GEMM_ALGORITHM_PICKER_H_ +#endif // XLA_SERVICE_GPU_AUTOTUNING_GEMM_ALGORITHM_PICKER_H_ diff --git a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc rename to third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc index e387aad44ef341..f1bd1876f4b6e2 100644 --- a/third_party/xla/xla/service/gpu/gemm_algorithm_picker_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_algorithm_picker_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemm_algorithm_picker.h" +#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h" #include #include @@ -21,9 +21,9 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gemm_rewriter.h" +#include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/gpu/variant_visitor.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" @@ -31,7 +31,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/protobuf/dnn.pb.h" diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc similarity index 93% rename from third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc rename to third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 0a6188495febf2..269fd815363c23 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemm_fusion_autotuner.h" +#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h" #include #include @@ -57,22 +57,22 @@ limitations under the License. #include "xla/service/dump.h" #include "xla/service/executable.h" #include "xla/service/float_normalization.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" -#include "xla/service/gpu/cudnn_fusion_compiler.h" -#include "xla/service/gpu/fusion_wrapper.h" -#include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/gpu/priority_fusion.h" #include "xla/service/gpu/split_k_gemm_rewriter.h" #include "xla/service/gpu/stream_executor_util.h" +#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" +#include "xla/service/gpu/transforms/fusion_wrapper.h" +#include "xla/service/gpu/transforms/gemm_rewriter.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" +#include "xla/service/gpu/transforms/priority_fusion.h" #include "xla/service/hlo_module_config.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" @@ -321,6 +321,21 @@ absl::StatusOr GetLimits(const HloDotInstruction& dot) { int GetLogEveryN() { return VLOG_IS_ON(3) ? 100 : 1000; } +int64_t PriorityFusionShapeSize(const Shape& shape) { + // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the + // pointer size is used only to determine the size of tuple types. We + // shouldn't have any tuples in the autotuned module, so it's safe to use + // a constant here, instead of piping the real value. + constexpr int64_t kPointerSize = 8; + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +HloCostAnalysis::Options PriorityFusionOptions() { + return {/*shape_size=*/PriorityFusionShapeSize, + /*per_second_rates=*/{}, + /*count_multiple_input_accesses=*/true}; +} + absl::StatusOr> TritonGemmAutotuneExtractor( const TritonGemmConfig& config, const se::DeviceDescription& gpu_device_info, @@ -347,24 +362,16 @@ absl::StatusOr> TritonGemmAutotuneExtractor( if (config.split_k > 1) { TF_RETURN_IF_ERROR(MakeDotSplitKBatch(cloned_dot_fusion, config)); - GpuFloatSupport bf16_support(gpu_device_info.cuda_compute_capability(), - BF16); - FloatNormalization float_normalization(&bf16_support); - TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status()); - - auto shape_size_function = [&](const Shape& shape) { - // The real pointer size is set in GpuCompiler. In HloCostAnalysis, the - // pointer size is used only to determine the size of tuple types. We - // shouldn't have any tuples in the autotuned module, so it's safe to use - // a constant here, instead of piping the real value. - constexpr int64_t kPointerSize = 8; - return ShapeUtil::ByteSizeOf(shape, kPointerSize); - }; - GpuPriorityFusion priority_fusion( - /*thread_pool=*/nullptr, gpu_device_info, - GpuHloCostAnalysis::Options{/*shape_size=*/shape_size_function, - /*per_second_rates=*/{}, - /*count_multiple_input_accesses=*/true}); + for (PrimitiveType type : + {BF16, F8E5M2, F8E4M3FN, F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ}) { + GpuFloatSupport float_support(gpu_device_info.cuda_compute_capability(), + type); + FloatNormalization float_normalization(&float_support); + TF_RETURN_IF_ERROR(float_normalization.Run(new_module.get()).status()); + } + + PriorityFusion priority_fusion( + /*thread_pool=*/nullptr, gpu_device_info, PriorityFusionOptions()); TF_RETURN_IF_ERROR(priority_fusion.Run(new_module.get()).status()); // If the priority fusion pass above skipped some instructions, turn them @@ -376,8 +383,9 @@ absl::StatusOr> TritonGemmAutotuneExtractor( } absl::StatusOr> CublasGemmAutotuneExtractor( - const AutotuneConfig& config, const int32_t toolkit_version, - const HloFusionInstruction* fusion, const DebugOptions& debug_opts) { + const AutotuneConfig& config, const se::DeviceDescription& gpu_device_info, + const int32_t toolkit_version, const HloFusionInstruction* fusion, + const DebugOptions& debug_opts) { const HloComputation* fusion_computation = fusion->called_computations().at(0); std::unique_ptr new_module = @@ -397,11 +405,13 @@ absl::StatusOr> CublasGemmAutotuneExtractor( PrecisionConfig::ALG_DOT_F32_F32_F32); } - for (bool fp8 : {true, false}) { + for (GemmRewriterOptions::DType dtype : + {GemmRewriterOptions::DType::kFp8Only, + GemmRewriterOptions::DType::kNonFp8Only}) { GemmRewriter rewriter(config.GetGpuComputeCapability(), toolkit_version, - fp8); - GpuInstructionFusion fusion_pass( - /*may_duplicate=*/false, config.GetExecutor()->GetDeviceDescription()); + GemmRewriterOptions{dtype}); + PriorityFusion fusion_pass( + /*thread_pool=*/nullptr, gpu_device_info, PriorityFusionOptions()); TF_RETURN_IF_ERROR(rewriter.Run(new_module.get()).status()); TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); } @@ -484,6 +494,15 @@ absl::Status DumpOriginalFusion(AutotunerCompileUtil& util, // Using the original module for its debug info and name in the first // parameter. It's better to include the name of both the original module // and the extracted module, to avoid name clashes. + std::string rendered_graph_name = + absl::StrCat("gemm_fusion_", fusion_id, ".", module->name(), ".dot"); + std::string rendered_graph = RenderGraph(rendered_graph_name, *module, + RenderedGraphFormat::kDot, true); + DumpToFileInDir( + /*module=*/*fusion.GetModule(), + /*file_prefix=*/"", + /*file_suffix=*/rendered_graph_name, + /*contents=*/rendered_graph); DumpToFileInDirOrStdout( /*module=*/*fusion.GetModule(), /*file_prefix=*/"", @@ -517,8 +536,9 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, triton_gemm_config, device_desc, fusion, debug_opts, /*allow_filtering_kernels_spilling_registers=*/true); } else if (result.has_gemm()) { - return CublasGemmAutotuneExtractor(autotune_config, toolkit_version, - fusion, debug_opts); + return CublasGemmAutotuneExtractor(autotune_config, device_desc, + toolkit_version, fusion, + debug_opts); } else { LOG(FATAL) << "Unknown result type: " << result.DebugString(); } @@ -771,11 +791,12 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, }) .value_or(nullptr); } else if (std::holds_alternative(config)) { - TF_ASSIGN_OR_RETURN(executable, - compile_util.Compile([&](const DebugOptions& opts) { - return CublasGemmAutotuneExtractor( - config_, toolkit_version_, fusion, opts); - })); + TF_ASSIGN_OR_RETURN( + executable, compile_util.Compile([&](const DebugOptions& opts) { + return CublasGemmAutotuneExtractor( + config_, config_.GetExecutor()->GetDeviceDescription(), + toolkit_version_, fusion, opts); + })); } else { LOG(FATAL) << "Unsupported config type: " << config.index(); } @@ -1176,15 +1197,21 @@ absl::Status ExchangeResults(KeyValueStoreInterface& key_value_store, TF_RETURN_IF_ERROR(key_value_store.Set( absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, shard_index), results_str)); + VLOG(2) << "Rank " << shard_index << ": published results"; for (int i = 0; i < shard_count; ++i) { if (i == shard_index) { continue; } + VLOG(2) << "Rank " << shard_index << ": waiting for results from rank " << i + << " / " << shard_count; TF_ASSIGN_OR_RETURN( std::string autotune_results_str, key_value_store.Get( absl::StrFormat("%s_%d_%d", kKeyPrefix, module_id, i), - absl::InfiniteDuration())); + // TODO(b/361009609): reset to infinite duration once solved. + // Using an infinite duration here leads to issues with MPI, see + // https://github.com/google/jax/issues/22995. + absl::Hours(24))); TF_RETURN_IF_ERROR( AutotunerUtil::LoadAutotuneResults(autotune_results_str, true)); } diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h similarity index 94% rename from third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h rename to third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h index 281579226a74c1..b12fe00c9ede2a 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner.h +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_ -#define XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_ +#ifndef XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_ +#define XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_ #include #include @@ -34,8 +34,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/distributed/key_value_store_interface.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/device_description.h" @@ -144,4 +144,4 @@ class GemmFusionAutotunerImpl { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GEMM_FUSION_AUTOTUNER_H_ +#endif // XLA_SERVICE_GPU_AUTOTUNING_GEMM_FUSION_AUTOTUNER_H_ diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc similarity index 88% rename from third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc rename to third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index 8cb7e8dc87e229..7af1805ecf7010 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemm_fusion_autotuner.h" +#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h" #include #include @@ -39,12 +39,13 @@ limitations under the License. #include "xla/service/call_inliner.h" #include "xla/service/dump.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gemm_fusion.h" -#include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/transforms/gemm_fusion.h" +#include "xla/service/gpu/transforms/gemm_rewriter.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/pattern_matcher.h" @@ -56,9 +57,9 @@ limitations under the License. #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" #include "xla/tools/hlo_decomposer.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -467,23 +468,23 @@ ENTRY %e { })"; auto module = ParseAndReturnVerifiedModule(kHloText).value(); - EXPECT_THAT( - backend().compiler()->RunBackend(std::move(module), - backend().default_stream_executor(), - {/*device_allocator=*/nullptr, - /*thread_pool=*/nullptr, - /*layout_canonicalization_callback=*/{}, - /*is_autotuning_compilation=*/true}), - ::testing::AnyOf( - tsl::testing::StatusIs( - tsl::error::CANCELLED, - absl::StrFormat( - "Compilation result discarded due to register spilling")), - // Hopper can't spill registers since wgmma instructions are - // asynchronous, instead it just runs out of them. - tsl::testing::StatusIs( - tsl::error::RESOURCE_EXHAUSTED, - absl::StrFormat("Register allocation failed")))); + EXPECT_THAT(backend().compiler()->RunBackend( + std::move(module), backend().default_stream_executor(), + {/*device_allocator=*/nullptr, + /*thread_pool=*/nullptr, + /*layout_canonicalization_callback=*/{}, + /*is_autotuning_compilation=*/true}), + ::testing::AnyOf( + tsl::testing::StatusIs( + tsl::error::CANCELLED, + "Compilation result discarded due to register spilling"), + // Hopper can't spill registers since wgmma instructions are + // asynchronous, instead it just runs out of them. + tsl::testing::StatusIs(tsl::error::RESOURCE_EXHAUSTED, + "Register allocation failed"), + tsl::testing::StatusIs( + tsl::error::INTERNAL, + ::testing::HasSubstr("Insufficient registers")))); } // Modify block_k back to 16 once b/337839570 is fixed. @@ -617,9 +618,12 @@ ENTRY main { pipeline.AddPass(autotune_config, GetToolkitVersion(), &thread_pool, key_value_store); pipeline.AddPass(); - for (bool fp8_rewrite : {true, false}) { + for (GemmRewriterOptions::DType dtype : + {GemmRewriterOptions::DType::kFp8Only, + GemmRewriterOptions::DType::kNonFp8Only}) { pipeline.AddPass(autotune_config.GetGpuComputeCapability(), - GetToolkitVersion(), fp8_rewrite); + GetToolkitVersion(), + GemmRewriterOptions{dtype}); } TF_EXPECT_OK(HloTestBase::RunHloPass(&pipeline, module.get())); @@ -654,18 +658,18 @@ TEST_F(GemmFusionAutotunerDumpTest, DumpingWorks) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( fusion1 { - p0 = f32[3333,3333] parameter(0) - s = f32[3333,3333] sine(p0) - p1 = f32[3333,3333] parameter(1) - c = f32[3333,3333] cosine(p1) - ROOT dot = f32[3333,3333] dot(s, c), + p0 = f32[333,333] parameter(0) + s = f32[333,333] sine(p0) + p1 = f32[333,333] parameter(1) + c = f32[333,333] cosine(p1) + ROOT dot = f32[333,333] dot(s, c), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY e { - p0 = f32[3333,3333] parameter(0) - p1 = f32[3333,3333] parameter(1) - ROOT rr = f32[3333,3333] fusion(p0, p1), kind=kCustom, calls=fusion1, + p0 = f32[333,333] parameter(0) + p1 = f32[333,333] parameter(1) + ROOT rr = f32[333,333] fusion(p0, p1), kind=kCustom, calls=fusion1, backend_config={"fusion_backend_config": {kind: "__triton_gemm"}} })", config)); @@ -941,6 +945,55 @@ ENTRY wais { INSTANTIATE_TEST_SUITE_P(GemmFusionAutotunerConfigSweep, GemmFusionAutotunerConfigTest, ::testing::Bool()); +TEST_F(GemmFusionAutotunerTest, SplitKFLoatNormalization) { + if (!GetCudaComputeCapability().IsAtLeastHopper()) { + GTEST_SKIP() << "f8 types are only supported from Hopper onwards."; + } + const se::CudaComputeCapability compute_capability = + GetCudaComputeCapability(); + se::GpuDeviceInfoProto deviceless_proto; + auto ccc = deviceless_proto.mutable_cuda_compute_capability(); + ccc->set_major(compute_capability.major); + ccc->set_minor(compute_capability.minor); + DeviceConfig test_config{backend().default_stream_executor(), + backend().memory_allocator()}; + AutotuneConfig autotune_config{test_config, GetDebugOptionsForTest()}; + GemmFusionAutotunerImpl autotuner(autotune_config, GetToolkitVersion(), + GetDebugOptionsForTest(), nullptr); + TF_ASSERT_OK_AND_ASSIGN( + auto compile_util, + AutotunerCompileUtil::Create(autotune_config, GetDebugOptionsForTest())) + + std::unique_ptr module = ParseAndReturnVerifiedModule(R"( +HloModule module + +%gemm_fusion_dot_computation (parameter_0: f8e5m2[256,256], parameter_1: f8e4m3fn[128,256]) -> f8e5m2[256,128] { + %parameter_0 = f8e5m2[256,256]{1,0} parameter(0) + %parameter_1 = f8e4m3fn[128,256]{1,0} parameter(1) + %dot.1 = f32[256,128]{1,0} dot(f8e5m2[256,256]{1,0} %parameter_0, f8e4m3fn[128,256]{1,0} %parameter_1), lhs_contracting_dims={0}, rhs_contracting_dims={1} + ROOT %convert.2 = f8e5m2[256,128]{1,0} convert(f32[256,128]{1,0} %dot.1) +} +ENTRY entry { + %p0 = f8e5m2[256,256]{1,0} parameter(0) + %p1 = f8e4m3fn[128,256]{1,0} parameter(1) + ROOT r = f8e5m2[256,128]{1,0} fusion(f8e5m2[256,256]{1,0} %p0, f8e4m3fn[128,256]{1,0} %p1), kind=kCustom, calls=%gemm_fusion_dot_computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false} +})") + .value(); + GemmFusionAutotunerImpl::TilingConfigs configs; + configs.emplace_back(DynCast( + module->entry_computation()->root_instruction()), + std::vector{ + GemmFusionAutotunerImpl::Config(TritonGemmConfig( + /*block_m=*/32, + /*block_n=*/64, + /*block_k=*/64, + /*split_k=*/4, + /*num_stages=*/1, + /*num_warps=*/4, + /*num_ctas=*/1))}); + CHECK_OK(autotuner.CompileAll(*compile_util, configs)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_autotuning.proto b/third_party/xla/xla/service/gpu/autotuning/gpu_autotuning.proto similarity index 100% rename from third_party/xla/xla/service/gpu/gpu_autotuning.proto rename to third_party/xla/xla/service/gpu/autotuning/gpu_autotuning.proto diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc index 624d324e739e92..0ffb8e3fe63de9 100644 --- a/third_party/xla/xla/service/gpu/buffer_sharing.cc +++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc @@ -79,7 +79,7 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, // first, i.e. before processing other outputs (that may overwrite the input). stream_executor::GpuDeviceInfoProto device_info; stream_executor::DeviceDescription device_description(device_info); - auto analysis = HloFusionAnalysis::Create(fusion, &device_description); + auto analysis = HloFusionAnalysis::Create(*user, device_description); bool is_reduction_emitter = analysis.GetEmitterFusionKind() == HloFusionAnalysis::EmitterFusionKind::kReduction; const HloInstruction* reduction_hero = diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index c21784b1b3dda8..39f00c826fdc7c 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -28,6 +27,8 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" @@ -53,25 +54,22 @@ limitations under the License. #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/ir_emitter_unnested.h" #include "xla/service/gpu/metrics.h" -#include "xla/service/gpu/runtime/conditional_thunk.h" -#include "xla/service/gpu/runtime/sequential_thunk.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/service/gpu/runtime/while_thunk.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_ordering.h" #include "xla/service/logical_buffer.h" #include "xla/shape.h" +#include "xla/status_macros.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/casts.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/scoped_annotation.h" namespace xla::gpu { @@ -102,8 +100,10 @@ void RemoveUnusedAndUninitializedGlobals( } } -static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, - absl::string_view cache_file_path) { +} // namespace + +absl::Status LoadCache(IrEmitterContext& ir_emitter_context, + absl::string_view cache_file_path) { std::string resolved_path; if (!tsl::io::ResolveTestPrefixes(cache_file_path, resolved_path)) { return FailedPrecondition("File path can not be resolved: %s", @@ -114,7 +114,7 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, TF_RETURN_IF_ERROR( tsl::ReadFileToString(tsl::Env::Default(), resolved_path, &serialized)); CompilationCacheProto proto; - if (!proto.ParseFromString(std::string(serialized))) { + if (!proto.ParseFromString(serialized)) { return Internal("Failed to parse serialized CompilationCacheProto."); } // Register all cached kernel names with the name uniquer to avoid @@ -131,8 +131,6 @@ static absl::Status LoadCache(IrEmitterContext& ir_emitter_context, return absl::OkStatus(); } -} // namespace - absl::StatusOr CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h index d7005f879c3994..a451af5a149fad 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.h @@ -31,6 +31,7 @@ limitations under the License. #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/execution_stream_assignment.h" #include "xla/service/gpu/gpu_executable.h" +#include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/runtime/sequential_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/hlo.pb.h" @@ -66,6 +67,9 @@ struct CompileModuleResults { void ForAllThunks(const std::function& fn, ThunkSequence* thunk_sequence); +absl::Status LoadCache(IrEmitterContext& ir_emitter_context, + absl::string_view cache_file_path); + absl::StatusOr CompileModuleToLlvmIr( HloModule* hlo_module, llvm::LLVMContext* llvm_context, const std::string& target_triple, const std::string& data_layout, diff --git a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc deleted file mode 100644 index 387a3f4d3aec8e..00000000000000 --- a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.cc +++ /dev/null @@ -1,272 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/cudnn_workspace_rewriter.h" - -#include -#include -#include - -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/string_view.h" -#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_clone_context.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/cuda/cuda_dnn.h" -#include "xla/stream_executor/dnn.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace { - -// create cuDNN graphs from HloCustomCall -absl::StatusOr HloCustomCallToCuDnnGraph( - se::dnn::DnnSupport& dnn_support, HloCustomCallInstruction* custom_call) { - if (IsFwdCustomCallTofMHA(*custom_call)) { - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(custom_call)); - std::optional mask_shape, bias_shape; - { - bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || - kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; - - if (has_bias) { - const HloInstruction* bias = custom_call->operand(3); - bias_shape = bias->shape(); - } - } - - TF_ASSIGN_OR_RETURN( - const auto gpu_config, - custom_call->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - absl::InlinedVector output_shapes = { - ShapeUtil::GetSubshape(custom_call->shape(), {0})}; - - bool has_activation = - xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; - if (has_activation) { - output_shapes.push_back( - ShapeUtil::GetSubshape(custom_call->shape(), {1})); - } - - Shape q_shape = custom_call->operand(0)->shape(); - Shape k_shape = custom_call->operand(1)->shape(); - Shape v_shape = custom_call->operand(2)->shape(); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - GpufMHADescriptor descriptor = {kind, - config, - cudnn_mask_type, - q_shape, - k_shape, - v_shape, - intermediate_tensor_shape, - output_shapes, - config.bmm1_dot_dimension_numbers(), - config.bmm2_dot_dimension_numbers(), - mask_shape, - bias_shape}; - - TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, - GpufMHAConfig::For(descriptor)); - TF_ASSIGN_OR_RETURN( - se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionOperationGraph( - dnn_support, fmha_config.lhs_bmm1, fmha_config.rhs_bmm1, - fmha_config.rhs_bmm2, fmha_config.output, fmha_config.bias, - fmha_config.activation, static_cast(*fmha_config.fmha_scale), - fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.dropout_rate, dnn_mask_type)); - return std::move(graph); - } else { - TF_ASSIGN_OR_RETURN( - auto gpu_config, - custom_call->backend_config()); - xla::gpu::CudnnfMHABackendConfig& config = - *gpu_config.mutable_cudnn_fmha_backend_config(); - - int input_index = 0; - Shape bmm1_grad_gemm1_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm1_grad_gemm2_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm2_grad_gemm2_rhs_shape = - custom_call->operand(input_index++)->shape(); - Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); - input_index++; - Shape d_output_shape = custom_call->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, - GetCudnnfMHAKind(custom_call)); - std::optional mask_shape; - - bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || - kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); - std::optional bias_shape; - if (has_bias) { - bias_shape = custom_call->operand(input_index++)->shape(); - } - - std::optional fwd_output_shape = - custom_call->operand(input_index++)->shape(); - if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || - config.mask_type() == - xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { - // skip q_seqlen and kv_seqlen - input_index += 2; - } - TF_RET_CHECK(input_index == custom_call->operand_count()); - - int output_index = 0; - Shape d_bmm1_lhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - Shape d_bmm1_rhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - Shape d_bmm2_rhs_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - std::optional d_s_shape; - std::optional d_bias_shape; - bool has_dbias = custom_call->shape().tuple_shapes().size() == 5; - if (has_dbias) { - d_bias_shape = - ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); - } - // The last one is the workspace. - TF_RET_CHECK(output_index == - custom_call->shape().tuple_shapes().size() - 1); - TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - - const bool force_deterministic = - RequireDeterminism(custom_call->GetModule()->config()); - // set the correct force_deterministic attribute here - config.set_force_deterministic(force_deterministic); - TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); - - GpufMHABackwardDescriptor descriptor = { - kind, - config, - cudnn_mask_type, - bmm1_grad_gemm1_rhs_shape, - bmm1_grad_gemm2_rhs_shape, - bmm2_grad_gemm1_lhs_shape, - bmm2_grad_gemm2_rhs_shape, - d_output_shape, - d_bmm1_lhs_shape, - d_bmm1_rhs_shape, - d_bmm2_rhs_shape, - config.bmm1_grad_gemm1_dot_dimension_numbers(), - config.bmm1_grad_gemm2_dot_dimension_numbers(), - config.bmm2_grad_gemm1_dot_dimension_numbers(), - config.bmm2_grad_gemm2_dot_dimension_numbers(), - d_s_shape, - fwd_output_shape, - mask_shape, - d_bias_shape, - bias_shape, - force_deterministic}; - - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_config, - GpufMHABackwardConfig::For(descriptor)); - TF_ASSIGN_OR_RETURN( - se::dnn::FMHAMaskKind dnn_mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(fmha_config.mask_type)); - - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( - dnn_support, fmha_config.bmm1_grad_gemm1_rhs, - fmha_config.bmm1_grad_gemm2_rhs, fmha_config.bmm2_grad_gemm1_lhs, - fmha_config.bmm2_grad_gemm2_rhs, fmha_config.d_output, - fmha_config.d_bmm1_lhs, fmha_config.d_bmm1_rhs, - fmha_config.d_bmm2_rhs, fmha_config.bias, fmha_config.dropout_rate, - fmha_config.seed, *fmha_config.fmha_scale, - fmha_config.dropout_rate && *fmha_config.dropout_rate > 0.0, - fmha_config.bias != std::nullopt, dnn_mask_type, - force_deterministic)); - return std::move(graph); - } -} - -class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { - public: - explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport& dnn_support) - : dnn_support_(dnn_support) {} - - absl::Status HandleCustomCall(HloInstruction* hlo) override { - if (!IsCustomCallTofMHA(*hlo)) { - // don't do anything about other cuDNN custom calls - return absl::OkStatus(); - } - TF_ASSIGN_OR_RETURN(auto gpu_config, - hlo->backend_config()); - - TF_ASSIGN_OR_RETURN( - se::gpu::CudnnGraph graph, - HloCustomCallToCuDnnGraph(dnn_support_, - DynCast(hlo))); - auto workspace = graph.Graph().get_workspace_size(); - if (workspace != 0) { - // rewrite custom call to have correct workspace size - VLOG(4) << "Rewriting: " << hlo->ToString(); - Shape* shape = hlo->mutable_shape(); - shape->mutable_tuple_shapes(shape->tuple_shapes_size() - 1) - ->set_dimensions(0, workspace); - MarkAsChanged(); - } - return absl::OkStatus(); - } - - private: - se::dnn::DnnSupport& dnn_support_; -}; - -} // namespace - -absl::StatusOr CuDnnWorkspaceRewriter::Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) { - XLA_SCOPED_LOGGING_TIMER("cuDNN workspace rewriter"); - return CuDnnCustomCallVisitor(dnn_support_) - .RunOnModule(module, execution_threads); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/custom_call_test.cc b/third_party/xla/xla/service/gpu/custom_call_test.cc index 87050db0129d56..4331d6efba429c 100644 --- a/third_party/xla/xla/service/gpu/custom_call_test.cc +++ b/third_party/xla/xla/service/gpu/custom_call_test.cc @@ -57,7 +57,7 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/test_helpers.h" #include "xla/tests/client_library_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA diff --git a/third_party/xla/xla/service/gpu/determinism_test.cc b/third_party/xla/xla/service/gpu/determinism_test.cc index 93c5b1591f110c..2d4a7a94ce087f 100644 --- a/third_party/xla/xla/service/gpu/determinism_test.cc +++ b/third_party/xla/xla/service/gpu/determinism_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_timer.h" @@ -97,6 +97,7 @@ ENTRY e { if (!rocm.has_hipblaslt()) { GTEST_SKIP() << "No hipblas-lt support on this architecture!"; } + debug_options_.set_xla_gpu_enable_triton_gemm(false); #endif // TENSORFLOW_USE_ROCM debug_options_.set_xla_gpu_triton_fusion_level(0); diff --git a/third_party/xla/xla/service/gpu/execution_stream_assignment.cc b/third_party/xla/xla/service/gpu/execution_stream_assignment.cc index 6ee0f2bfbc1f7b..8a55dc555e0550 100644 --- a/third_party/xla/xla/service/gpu/execution_stream_assignment.cc +++ b/third_party/xla/xla/service/gpu/execution_stream_assignment.cc @@ -34,7 +34,8 @@ limitations under the License. namespace xla::gpu { -ExecutionStreamAssignment::ExecutionStreamAssignment(const HloModule* module) { +ExecutionStreamAssignment::ExecutionStreamAssignment( + const HloModule* module, ExecutionStreamAssignmentOptions options) { std::unique_ptr call_graph = CallGraph::Build(module); // We'll walk the `CallGraph` starting from the entrypoint. The instructions @@ -88,14 +89,18 @@ ExecutionStreamAssignment::ExecutionStreamAssignment(const HloModule* module) { // Asynchronous calls will result in a new `ExecutionStreamId` being // dispensed for the called computations. CHECK_EQ(callsite.instruction()->opcode(), HloOpcode::kAsyncStart); - const ExecutionStreamId async_stream_id = next_stream_id++; - enqueue_called_computations(callsite, async_stream_id); + enqueue_called_computations(callsite, next_stream_id); AsyncExecutionStreamIds streams; streams.source_stream_id = pending.stream_id; - streams.destination_stream_id = async_stream_id; + streams.destination_stream_id = next_stream_id; CHECK(async_instructions_.try_emplace(callsite.instruction(), streams) .second); + + next_stream_id++; + if (next_stream_id.value() > options.number_of_execution_streams) { + next_stream_id = ExecutionStreamId(1); + } } else { // Synchronous calls will result in the called computations being // invoked using the same `ExecutionStreamId`. diff --git a/third_party/xla/xla/service/gpu/execution_stream_assignment.h b/third_party/xla/xla/service/gpu/execution_stream_assignment.h index adbd7f04ace5ec..cb0e87ae0e44f2 100644 --- a/third_party/xla/xla/service/gpu/execution_stream_assignment.h +++ b/third_party/xla/xla/service/gpu/execution_stream_assignment.h @@ -26,6 +26,12 @@ limitations under the License. namespace xla::gpu { +struct ExecutionStreamAssignmentOptions { + // The `ExecutionStreamAssignment` will round-robin across this many + // `ExecutionStreams`. + int number_of_execution_streams = 4; +}; + // `ExecutionStreamAssignments` represent a mapping from `HloInstructions` to // `ExecutionStreamIds`. Asynchronous calls (`async-start`, `async-update`, and // `async-done`) result in the target computations being assigned new @@ -37,7 +43,8 @@ class ExecutionStreamAssignment { // pass the module through the `FlattenCallGraph` pass. // // The ExecutionStreamAssignment does not take ownership of the `HloModule`. - explicit ExecutionStreamAssignment(const HloModule* module); + explicit ExecutionStreamAssignment( + const HloModule* module, ExecutionStreamAssignmentOptions options = {}); // Returns the `ExecutionStreamId` for the given instruction, which *must* be // synchronous. Returns an error if the instruction is either not reachable diff --git a/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc b/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc index cf7ec32ab62757..e6abd3e3f5e101 100644 --- a/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/execution_stream_assignment_test.cc @@ -69,6 +69,10 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) { p0 = f32[2,2] parameter(0) ROOT add = f32[2,2] add(p0, p0) } + leaf3 { + p0 = f32[2,2] parameter(0) + ROOT add = f32[2,2] add(p0, p0) + } // Entry computation that calls each of the leaves asynchronously. ENTRY entry { @@ -77,21 +81,30 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) { kind=kLoop, calls=leaf1 start2 = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0), kind=kLoop, calls=leaf2 + start3 = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0), + kind=kLoop, calls=leaf3 update1 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start1) update2 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start2) + update3 = ((f32[2,2]), f32[2,2], s32[]) fusion-update(start3) done1 = f32[2,2] fusion-done(update1) done2 = f32[2,2] fusion-done(update2) - ROOT done = f32[2,2] add(done1, done2) + done3 = f32[2,2] fusion-done(update3) + ROOT done = f32[2,2] custom-call(done1, done2, done3), + custom_call_target="target" } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr)); - ExecutionStreamAssignment assignment(module.get()); + ExecutionStreamAssignment assignment( + module.get(), + ExecutionStreamAssignmentOptions{/*number_of_execution_streams=*/2}); // The outermost computation should run on `ExecutionStreamId(0)`. The two // asynchronous branches should be launched on `ExecutionStreamId(1)` and - // `ExecutionStreamId(2)`, respectively. + // `ExecutionStreamId(2)`, respectively. The third asynchronous branch should + // reuse `ExecutionStreamId(1)` because we set `number_of_execution_streams` + // to `2`. ExpectExecutionStreamForSyncInstructions( assignment, FindComputation(module.get(), "entry"), ExecutionStreamId(0)); for (std::string_view instruction : {"start1", "update1", "done1"}) { @@ -108,6 +121,13 @@ TEST_F(ExecutionStreamAssignmentTest, AsyncFusion) { /*source_stream_id=*/ExecutionStreamId(0), /*destination_stream_id=*/ExecutionStreamId(2)})); } + for (std::string_view instruction : {"start3", "update3", "done3"}) { + EXPECT_THAT(assignment.GetAsyncExecutionStreamIds(Cast( + FindInstruction(module.get(), instruction))), + IsOkAndHolds(AsyncExecutionStreamIds{ + /*source_stream_id=*/ExecutionStreamId(0), + /*destination_stream_id=*/ExecutionStreamId(1)})); + } // Leaf computations should run on the respective asynchronous // `ExecutionStreamIds`. diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc index 2a184c04a6594c..4fc4af0a4cfafa 100644 --- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc +++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc @@ -19,14 +19,14 @@ limitations under the License. #include #include "xla/service/cpu_gpu_shape_verifier.h" -#include "xla/service/gpu/fusion_merger.h" -#include "xla/service/gpu/horizontal_input_fusion.h" -#include "xla/service/gpu/horizontal_loop_fusion.h" -#include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/gpu/multi_output_fusion.h" -#include "xla/service/gpu/priority_fusion.h" -#include "xla/service/gpu/variadic_op_splitter.h" +#include "xla/service/gpu/transforms/fusion_merger.h" +#include "xla/service/gpu/transforms/horizontal_input_fusion.h" +#include "xla/service/gpu/transforms/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" +#include "xla/service/gpu/transforms/multi_output_fusion.h" +#include "xla/service/gpu/transforms/priority_fusion.h" +#include "xla/service/gpu/transforms/variadic_op_splitter.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" @@ -63,8 +63,8 @@ HloPassPipeline FusionPipeline( shape_size_bytes_function, /*per_second_rates=*/{}, /*count_multiple_input_accesses=*/true}; - fusion.AddPass(thread_pool, gpu_device_info, - std::move(cost_analysis_options)); + fusion.AddPass(thread_pool, gpu_device_info, + std::move(cost_analysis_options)); } else { fusion.AddPass(/*may_duplicate=*/false, gpu_device_info); @@ -77,8 +77,7 @@ HloPassPipeline FusionPipeline( fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); fusion.AddPass(); - fusion.AddPass(gpu_device_info, - shape_size_bytes_function); + fusion.AddPass(gpu_device_info, shape_size_bytes_function); fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); fusion.AddPass(); @@ -88,8 +87,8 @@ HloPassPipeline FusionPipeline( HloPassPipeline HorizontalFusionPipeline( const se::DeviceDescription& gpu_device_info) { HloPassFix horizontal_fusion("horizontal fusion"); - horizontal_fusion.AddPass(); - horizontal_fusion.AddPass(gpu_device_info); + horizontal_fusion.AddPass(); + horizontal_fusion.AddPass(gpu_device_info); horizontal_fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); horizontal_fusion.AddPass(); diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 1d2be88970161d..889c31cbec9246 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -8,49 +8,6 @@ package( licenses = ["notice"], ) -cc_library( - name = "in_place_dynamic_update_slice", - srcs = ["in_place_dynamic_update_slice.cc"], - hdrs = ["in_place_dynamic_update_slice.h"], - deps = [ - ":fusion_emitter", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:dynamic_update_slice_util", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - ], -) - -xla_cc_test( - name = "in_place_dynamic_update_slice_test", - srcs = ["in_place_dynamic_update_slice_test.cc"], - deps = [ - ":fusions", - ":in_place_dynamic_update_slice", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "in_place_dynamic_update_slice_mlir", srcs = ["in_place_dynamic_update_slice_mlir.cc"], @@ -87,8 +44,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -122,6 +79,7 @@ cc_library( deps = [ ":fusion_emitter", "//xla:shape_util", + "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", "//xla/ffi:attribute_map", @@ -147,6 +105,10 @@ cc_library( "//xla/service/gpu/runtime:dynamic_slice_thunk", "//xla/service/gpu/runtime:gemm_thunk", "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/gpu/runtime:nccl_all_reduce_thunk", + "//xla/service/gpu/runtime:nccl_api", + "//xla/service/gpu/runtime:nccl_clique_key", + "//xla/service/gpu/runtime:nccl_collective_thunk", "//xla/service/gpu/runtime:thunk", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -167,6 +129,12 @@ cc_library( xla_test( name = "dynamic_slice_fusion_test", srcs = if_cuda_is_configured(["dynamic_slice_fusion_test.cc"]), + backend_tags = { + "gpu": [ + "multi_gpu", + "no_oss", + ], + }, backends = ["gpu"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ @@ -184,11 +152,14 @@ xla_test( "//xla/service:executable", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", - "//xla/service/gpu:dynamic_slice_fusion_rewriter", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter", "//xla/stream_executor", "//xla/stream_executor:device_description", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", + "//xla/tests:test_utils", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@local_tsl//tsl/platform:errors", @@ -244,23 +215,16 @@ cc_library( hdrs = ["fusions.h"], visibility = ["//xla/service/gpu:__subpackages__"], deps = [ - ":concatenate", ":concatenate_mlir", ":copy", ":cudnn", ":custom", ":fusion_emitter", - ":in_place_dynamic_update_slice", ":in_place_dynamic_update_slice_mlir", - ":input_slices", ":input_slices_mlir", - ":loop", ":loop_mlir", - ":reduction", ":reduction_mlir", - ":scatter", ":scatter_mlir", - ":transpose", ":transpose_mlir", ":triton", "//xla:shape_util", @@ -270,6 +234,13 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/legacy:concatenate", + "//xla/service/gpu/fusions/legacy:in_place_dynamic_update_slice", + "//xla/service/gpu/fusions/legacy:input_slices", + "//xla/service/gpu/fusions/legacy:loop", + "//xla/service/gpu/fusions/legacy:reduction", + "//xla/service/gpu/fusions/legacy:scatter", + "//xla/service/gpu/fusions/legacy:transpose", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -292,8 +263,8 @@ cc_library( "//xla/service:gpu_plugin", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:affine_map_printer", "//xla/stream_executor:device_description", "//xla/tests:filecheck", @@ -320,54 +291,23 @@ cc_library( ], ) -cc_library( - name = "loop", - srcs = ["loop.cc"], - hdrs = ["loop.h"], - deps = [ - ":fusion_emitter", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_fusible", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/numeric:bits", - "@com_google_absl//absl/status", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "loop_mlir", srcs = ["loop_mlir.cc"], hdrs = ["loop_mlir.h"], deps = [ - ":loop", "//xla:shape_util", "//xla:status_macros", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", @@ -392,8 +332,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -403,17 +343,17 @@ cc_library( srcs = ["scatter_mlir.cc"], hdrs = ["scatter_mlir.h"], deps = [ - ":loop", "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/service:scatter_simplifier", + "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -441,8 +381,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) @@ -458,15 +398,14 @@ cc_library( "//xla:util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/mlir/utils:type_util", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir:type_util", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -486,6 +425,7 @@ xla_test( name = "transpose_mlir_test", srcs = ["transpose_mlir_test.cc"], backends = ["gpu"], + tags = ["no_rocm"], # TODO(rocm): weekly sync 24-08-20 deps = [ ":mlir_emitter_test_base", ":transpose_mlir", @@ -493,108 +433,13 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "loop_test", - srcs = ["loop_test.cc"], - deps = [ - ":fusion_emitter", - ":fusions", - "//xla:status_macros", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", ], ) -cc_library( - name = "scatter", - srcs = ["scatter.cc"], - hdrs = ["scatter.h"], - deps = [ - ":fusion_emitter", - ":loop", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "scatter_test", - srcs = ["scatter_test.cc"], - deps = [ - ":fusions", - ":scatter", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "tiling_util", - srcs = ["tiling_util.cc"], - hdrs = ["tiling_util.h"], - visibility = ["//xla/service/gpu:__subpackages__"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:target_util", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:kernel_support_library", - "//xla/service/llvm_ir:llvm_loop", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "triton", srcs = ["triton.cc"], @@ -696,23 +541,22 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service:dump", "//xla/service:executable", - "//xla/service:hlo_module_config", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu:cudnn_fusion_compiler", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/runtime:thunk", "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/service/gpu/transforms:cudnn_fusion_compiler", "//xla/stream_executor:dnn", "//xla/stream_executor:stream_executor_h", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", @@ -740,88 +584,13 @@ cc_library( ], ) -cc_library( - name = "reduction", - srcs = ["reduction.cc"], - hdrs = ["reduction.h"], - deps = [ - ":fusion_emitter", - ":reduction_base", - ":thunk_util", - ":tiling_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:kernel_arguments", - "//xla/service/gpu:kernel_reuse_cache", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu:reduction_utils", - "//xla/service/gpu:target_util", - "//xla/service/gpu/runtime:kernel_thunk", - "//xla/service/gpu/runtime:thunk", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:kernel_support_library", - "//xla/service/llvm_ir:llvm_loop", - "//xla/service/llvm_ir:llvm_util", - "//xla/service/llvm_ir:loop_emitter", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "reduction_test", - srcs = ["reduction_test.cc"], - deps = [ - ":fusion_emitter", - ":reduction", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - ], -) - cc_library( name = "reduction_base", srcs = ["reduction_base.cc"], hdrs = ["reduction_base.h"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], deps = [ ":fusion_emitter", - ":tiling_util", "//xla:shape_util", "//xla:union_find", "//xla:util", @@ -865,11 +634,11 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:reduction_utils", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/mlir:type_util", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -897,60 +666,15 @@ xla_test( ":mlir_emitter_test_base", ":reduction_mlir", "//xla:error_spec", + "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - ], -) - -cc_library( - name = "concatenate", - srcs = ["concatenate.cc"], - hdrs = ["concatenate.h"], - deps = [ - ":fusion_emitter", - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:loop_emitter", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "concatenate_test", - srcs = ["concatenate_test.cc"], - deps = [ - ":concatenate", - ":fusions", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", ], ) @@ -959,9 +683,8 @@ cc_library( srcs = ["concatenate_mlir.cc"], hdrs = ["concatenate_mlir.h"], deps = [ - ":concatenate", - ":loop", "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_fusible", "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu/fusions/mlir:computation_partitioner", @@ -990,96 +713,8 @@ xla_test( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu/model:indexing_test_utils", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "transpose", - srcs = ["transpose.cc"], - hdrs = ["transpose.h"], - deps = [ - ":fusion_emitter", - ":tiling_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:target_util", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:llvm_util", - "//xla/service/llvm_ir:loop_emitter", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "transpose_test", - srcs = ["transpose_test.cc"], - deps = [ - ":fusions", - ":transpose", - "//xla:status_macros", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "input_slices", - srcs = ["input_slices.cc"], - hdrs = ["input_slices.h"], - deps = [ - ":fusion_emitter", - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:elemental_ir_emitter", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:kernel_support_library", - "//xla/service/llvm_ir:llvm_loop", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", "@local_tsl//tsl/platform:statusor", ], ) @@ -1095,10 +730,10 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:computation_partitioner", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", @@ -1126,21 +761,3 @@ xla_test( "@com_google_googletest//:gtest", ], ) - -xla_cc_test( - name = "input_slices_test", - srcs = ["input_slices_test.cc"], - deps = [ - ":fusions", - ":input_slices", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - ], -) diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc index f2cecc5d6cac80..d7ebabf2a7f723 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir.cc @@ -35,10 +35,9 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/concatenate.h" -#include "xla/service/gpu/fusions/loop.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_analysis.h" @@ -52,6 +51,16 @@ using llvm::SmallVector; using mlir::Value; using mlir::ValueRange; +const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) { + const HloInstruction& concat = analysis.fusion_hero(0).instruction(); + int64_t dim = concat.concatenate_dimension(); + auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) { + return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim); + }; + HloInstruction* operand = *absl::c_max_element(concat.operands(), less); + return operand->shape(); +} + // Computes the unroll factor that divides concat dimension of all operands. int ComputeUnrollFactor(const HloFusionAnalysis& analysis, int unroll_factor_for_the_largest_shape) { diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc index c0637cbe12dc74..30ca0f6c0cda05 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/concatenate_mlir_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -52,7 +52,7 @@ TEST_F(MlirConcatenateFusionTest, ThreadIdIndexing) { thread_id_printer_.SetSymbolName(1, "unroll_id"); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirConcatenateFusion fusion(analysis); constexpr auto kIndexing = R"( @@ -102,9 +102,9 @@ TEST_F(MlirConcatenateFusionTest, StandAloneConcatenate) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0)> - // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 200)> - // CHECK-DAG: #[[MAP_3:.*]] = affine_map<(d0, d1) -> (d1 * 128 + d0 + 600)> + // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0) + // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0 + 200) + // CHECK-DAG: #[[MAP_3:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0 + 600) // CHECK-LABEL: fused_computation // CHECK-SAME: %[[ARG_0:[a-zA-Z0-9]*]]: {{[^,]*}}, @@ -152,7 +152,7 @@ TEST_F(MlirConcatenateFusionTest, PrologueEpilogue) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 + 64)> + // CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 64) // CHECK-LABEL: fused_computation // CHECK-DAG: %[[C_63:.*]] = arith.constant 63 @@ -254,9 +254,9 @@ TEST_F(MlirConcatenateFusionTest, Vectorization) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: affine_map<(d0, d1) -> (d1 * 128 + d0)> - // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0)> - // CHECK-DAG: affine_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002)> + // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1) -> (d1 * 128 + d0) + // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0) + // CHECK-DAG: #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 2 + d1 * 256 + s0 + 640002) // CHECK-LABEL: fused_computation // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc index 9e9e1ce7560500..172c53ac49b65e 100644 --- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc @@ -31,11 +31,10 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/dump.h" #include "xla/service/executable.h" -#include "xla/service/gpu/cudnn_fusion_compiler.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/service/hlo_module_config.h" +#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/stream_executor/dnn.h" @@ -43,9 +42,9 @@ limitations under the License. #include "xla/stream_executor/stream_executor_memory_allocator.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" @@ -88,17 +87,49 @@ class CuDnnFusionTest : public GpuCodegenTest { } }; -TEST_F(CuDnnFusionTest, DumpingWorks) { - HloModuleConfig config; - DebugOptions options = GetDebugOptionsForTest(); - std::string output_directory; - if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory)) { - output_directory = tsl::testing::TmpDir(); +class CuDnnFusionFileCheckTest : public CuDnnFusionTest { + public: + CuDnnFusionFileCheckTest() { + if (!tsl::io::GetTestUndeclaredOutputsDir(&output_directory_)) { + output_directory_ = tsl::testing::TmpDir(); + } } - options.set_xla_dump_to(output_directory); - config.set_debug_options(options); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(R"( + + DebugOptions GetDebugOptionsForTest() override { + DebugOptions options = CuDnnFusionTest::GetDebugOptionsForTest(); + options.set_xla_dump_to(output_directory_); + return options; + } + + absl::StatusOr RunCuDnnFileCheck(absl::string_view hlo, + absl::string_view pattern) { + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + const std::string root_name( + module->entry_computation()->root_instruction()->name()); + BinaryMap dnn_compiled_graphs; + CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(), + dnn_compiled_graphs); + // Run filecheck even if CuDnnFusionCompiler failed. + cudnn_compiler.Run(module.get()).IgnoreError(); + std::string dump; + TF_RETURN_IF_ERROR(tsl::ReadFileToString( + tsl::Env::Default(), + tsl::io::JoinPath( + output_directory_, + FilenameFor(*module, /*prefix=*/"", + /*suffix=*/ + absl::StrCat("cudnn_fusion_", root_name, ".json"))), + &dump)); + return RunFileCheck(dump, pattern); + } + + private: + std::string output_directory_; +}; + +TEST_F(CuDnnFusionFileCheckTest, F32DotGraphIsConvertedCorrectly) { + EXPECT_TRUE(*RunCuDnnFileCheck(R"( fd0 { p0 = f32[64,64] parameter(0) p1 = f32[64,64] parameter(1) @@ -111,20 +142,7 @@ ENTRY e { ROOT d0 = f32[64,64] fusion(p0, p1), kind=kCustom, calls=fd0, backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion","cudnn_fusion_config":{"plan_id":"0"}}} })", - config)); - BinaryMap dnn_compiled_graphs; - CuDnnFusionCompiler cudnn_compiler(*backend().default_stream_executor(), - dnn_compiled_graphs); - TF_ASSERT_OK_AND_ASSIGN(bool changed, cudnn_compiler.Run(module.get())); - EXPECT_TRUE(changed); - std::string dump; - TF_EXPECT_OK(tsl::ReadFileToString( - tsl::Env::Default(), - tsl::io::JoinPath(output_directory, - FilenameFor(*module, /*prefix=*/"", - /*suffix=*/"cudnn_fusion_d0.json")), - &dump)); - EXPECT_TRUE(*RunFileCheck(dump, R"( + R"( CHECK: "nodes": [ CHECK: "inputs": { CHECK: "A": "p0", diff --git a/third_party/xla/xla/service/gpu/fusions/custom.cc b/third_party/xla/xla/service/gpu/fusions/custom.cc index 3a95abfa402021..31cd030449d705 100644 --- a/third_party/xla/xla/service/gpu/fusions/custom.cc +++ b/third_party/xla/xla/service/gpu/fusions/custom.cc @@ -58,12 +58,16 @@ limitations under the License. #include "xla/service/gpu/runtime/dynamic_slice_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/kernel_thunk.h" +#include "xla/service/gpu/runtime/nccl_all_reduce_thunk.h" +#include "xla/service/gpu/runtime/nccl_api.h" +#include "xla/service/gpu/runtime/nccl_collective_thunk.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/stream_executor_util.h" #include "xla/service/hlo.pb.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/status_macros.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -755,6 +759,119 @@ absl::StatusOr EmitCustomCall( return result; } +absl::StatusOr EmitCollective( + IrEmitterContext& ir_emitter_context, const HloFusionAdaptor& adaptor, + const HloFusionInstruction& fusion_instr, const HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kReduceScatter) { + return absl::UnimplementedError( + "Dynamic slice fusion with collectives only works for reduce-scatter " + "instruction"); + } + + const BufferAssignment& buffer_assignment = + ir_emitter_context.buffer_assignment(); + + std::vector>> + offset_buffer_indices(2, std::nullopt); + std::vector> orig_shapes(2, std::nullopt); + std::vector> sliced_shapes(2, std::nullopt); + std::vector> offset_byte_sizes(2, std::nullopt); + + std::vector slice_instrs(2, nullptr); + std::vector> arguments; + + // Collect slice information for inputs. + unsigned arg_idx = 0; + TF_ASSIGN_OR_RETURN(arguments.emplace_back(), + GetOperandSlice(buffer_assignment, adaptor, fusion_instr, + *instr->operand(arg_idx), slice_instrs, + /*shape_idx=*/{}, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion_instr, + absl::Span(slice_instrs), offset_buffer_indices, + orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx++)); + + // Collect slice information for outputs. + TF_ASSIGN_OR_RETURN( + arguments.emplace_back(), + GetResultSlice(buffer_assignment, adaptor, fusion_instr, *instr, + slice_instrs, /*shape_idx=*/{}, arg_idx)); + TF_RETURN_IF_ERROR(CollectSliceInfo( + buffer_assignment, fusion_instr, + absl::Span(slice_instrs), offset_buffer_indices, + orig_shapes, sliced_shapes, offset_byte_sizes, arg_idx)); + + if (absl::c_all_of(slice_instrs, [&](HloInstruction* slice_instr) { + return slice_instr && + slice_instr->opcode() != HloOpcode::kDynamicUpdateSlice; + })) { + return absl::InternalError( + "DynamicSliceFusion with reduce-scatter expects a dynamic-update-slice " + "operation."); + } + + // Provide fake allocations for inputs and outputs. + std::vector> fake_allocations(2); + unsigned fake_arg_idx = 0; + int64_t operand_byte_size = + ShapeUtil::ByteSizeOf(instr->operand(fake_arg_idx)->shape()); + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, operand_byte_size, /*color=*/0); + BufferAllocation::Slice slice_operand(fake_allocations[fake_arg_idx].get(), 0, + operand_byte_size); + fake_arg_idx++; + TF_RET_CHECK(instr->shape().IsArray() && + "The output is not expected to be a tuple."); + int64_t out_fake_byte_size = + ShapeUtil::ByteSizeOf(instr->shape()); // TODO: we don't need this + fake_allocations[fake_arg_idx] = std::make_unique( + /*index=*/fake_arg_idx, out_fake_byte_size, /*color=*/0); + BufferAllocation::Slice slice_out_fake(fake_allocations[fake_arg_idx].get(), + 0, out_fake_byte_size); + + // Generate the hero thunk and wrap it in a dynamic-slice thunk. + ThunkSequence seq; + auto thunk_info = Thunk::ThunkInfo::WithProfileAnnotation(instr); + std::vector buffers; + const Shape& src_shape = instr->operand(0)->shape(); + const Shape& dst_shape = instr->shape(); + buffers.push_back(NcclCollectiveThunk::Buffer{ + ShapeUtil::ElementsIn(src_shape), slice_operand, slice_out_fake, + src_shape.layout().memory_space(), dst_shape.layout().memory_space(), + nullptr, nullptr}); + + if (instr->opcode() == HloOpcode::kReduceScatter) { + int64_t replica_count = instr->GetModule()->config().replica_count(); + int64_t partition_count = instr->GetModule()->config().num_partitions(); + auto rs = static_cast(instr); + TF_RETURN_IF_ERROR(NcclReduceScatterStartThunk::CheckImplementable( + rs, replica_count, partition_count)); + + // TODO: add special handling for degenerate case - where no communication + // is needed. Just copy. + auto rs_start_thunk = std::make_unique( + thunk_info, NcclApi::Default(), rs, buffers); + auto rs_done = std::make_unique( + /*kind=*/Thunk::kNcclReduceScatterDone, + /*thunk_info=*/Thunk::ThunkInfo::WithProfileAnnotation(rs), + /*async_events=*/rs_start_thunk->async_events(), + /*async_stream_kind=*/AsyncStreamKind::kCollective); + seq.emplace_back(std::move(rs_start_thunk)); + seq.emplace_back(std::move(rs_done)); + } else { + return absl::InternalError("Expected reduce-scatter hero instruction"); + } + + std::unique_ptr thunk = std::make_unique( + thunk_info, std::make_unique(std::move(seq)), + std::move(arguments), std::move(fake_allocations), + std::move(offset_buffer_indices), std::move(orig_shapes), + std::move(sliced_shapes), std::move(offset_byte_sizes)); + FusionEmissionResult result; + result.thunks.push_back(std::move(thunk)); + return result; +} + } // namespace absl::StatusOr CustomFusion::Emit( @@ -807,6 +924,16 @@ absl::StatusOr DynamicSliceFusion::Emit( IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { const HloFusionAdaptor& adaptor = analysis_.fusion(); + // Only reduce-scatter is supported for now. + auto maybe_collective = + HloBfsFindIf(/*roots=*/adaptor.GetRoots(), /*fusion=*/adaptor, + /*visit=*/[](HloInstructionAdaptor node) -> bool { + return node.opcode() == HloOpcode::kReduceScatter; + }); + if (maybe_collective != std::nullopt) { + return EmitCollective(ir_emitter_context, adaptor, fusion, + &maybe_collective->instruction()); + } auto maybe_custom_call_adaptor = HloBfsFindIf( adaptor.GetRoots(), adaptor, [](auto node) { return node.opcode() == HloOpcode::kCustomCall; }); diff --git a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc index 954d4b656acb0c..f97a90dc80b5ef 100644 --- a/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/dynamic_slice_fusion_test.cc @@ -16,7 +16,9 @@ limitations under the License. #include #include #include +#include #include +#include #include "absl/status/status.h" #include "xla/client/lib/constants.h" @@ -24,9 +26,11 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" +#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/custom_call_target_registry.h" -#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/shape.h" @@ -34,6 +38,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_types.h" #include "xla/stream_executor/stream.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" @@ -80,6 +85,36 @@ class DynamicSliceFusionTest : public HloTestBase { config.set_debug_options(debug_options); return config; } + + HloModuleConfig GetModuleConfigWithDeterministicOps() { + DebugOptions debug_options = GetDebugOptionsForTest(); + debug_options.set_xla_gpu_deterministic_ops(true); + HloModuleConfig config; + config.set_debug_options(debug_options); + return config; + } + + std::vector GetAddressComputations(const HloModule& module) { + std::vector computations; + for (auto computation : module.computations()) { + if (!computation->IsFusionComputation()) { + continue; + } + auto backend_config = computation->FusionInstruction() + ->backend_config(); + if (backend_config.ok()) { + const FusionBackendConfig& fusion_backend_config = + backend_config.value().fusion_backend_config(); + const std::string name = + fusion_backend_config.custom_fusion_config().name(); + if (name == "dynamic_address_computation" || + name == "address_computation") { + computations.push_back(computation); + } + } + } + return computations; + } }; TEST_F(DynamicSliceFusionTest, CublasGemmSimple) { @@ -237,8 +272,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, ContiguousSlice) { @@ -867,7 +904,7 @@ TEST_F(DynamicSliceFusionTest, CustomCallSimple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1010,12 +1047,12 @@ TEST_F(DynamicSliceFusionTest, CustomCallWithTuple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/true); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1059,12 +1096,12 @@ TEST_F(DynamicSliceFusionTest, NilTuple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1103,12 +1140,12 @@ TEST_F(DynamicSliceFusionTest, CustomCallLegacyAPI) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1141,12 +1178,12 @@ TEST_F(DynamicSliceFusionTest, NilTupleLegacyAPI) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1327,8 +1364,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmDynamicWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, DynamicContiguousSlice) { @@ -2156,8 +2195,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmDUSWithWorkspace) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, CublasGemmDUSWorkspaceIgnored) { @@ -2241,8 +2282,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmDUSWorkspaceIgnored) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, CublasGemmDUSOffsetS32NotConstant) { @@ -2435,8 +2478,10 @@ TEST_F(DynamicSliceFusionTest, CublasGemmDUSOffsetOOB) { backend_config={"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}}} })"; - EXPECT_TRUE(RunAndCompareTwoModules(hlo_ref, hlo_opt, error_spec, - /*run_hlo_passes=*/false)); + EXPECT_TRUE(RunAndCompareTwoModules( + hlo_ref, hlo_opt, GetModuleConfigWithDeterministicOps(), + GetModuleConfigWithDeterministicOps(), error_spec, + /*run_hlo_passes=*/false)); } TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) { @@ -2445,9 +2490,7 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) { &b, "__xla_test$$memcpy", /*operands=*/ {DynamicSlice(Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), "p0"), - {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "start0"), - Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), "start1")}, - {2, 128})}, + {ConstantR0(&b, 2), ConstantR0(&b, 0)}, {2, 128})}, ShapeUtil::MakeShape(F32, {2, 128}), /*opaque=*/"", /*has_side_effect=*/false, /*output_operand_aliasing=*/{}, /*literal=*/nullptr, @@ -2460,11 +2503,10 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallSimple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); DynamicSliceFusionRewriter pass(PLATFORM); @@ -2502,11 +2544,7 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallWithTuple) { DynamicSlice( Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4, 128}), "p0"), - {Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), - "start0"), - Parameter(&b, 2, ShapeUtil::MakeShape(S32, {}), - "start1")}, - {3, 128}), + {ConstantR0(&b, 20), ConstantR0(&b, 0)}, {3, 128}), }), }, ShapeUtil::MakeTupleShape({ @@ -2532,12 +2570,12 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallWithTuple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/true); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -2545,6 +2583,15 @@ TEST_F(DynamicSliceFusionTest, DynamicCustomCallWithTuple) { DynamicSliceFusionRewriter pass(PLATFORM); TF_ASSERT_OK_AND_ASSIGN(auto changed, this->RunHloPass(&pass, hlo_opt.get())); EXPECT_TRUE(changed); + EXPECT_TRUE(*RunFileCheck(hlo_opt->ToString(), R"( + // CHECK: %address-computation{{.+}} { + // CHECK: {{.+}} = {{.+}} slice + // CHECK: {{.+}} = {{.+}} dynamic-slice + // CHECK: {{.+}} = {{.+}} custom-call + // CHECK: ENTRY {{.+}} { + // CHECK-NOT: {{.+}} = {{.+}} slice + // CHECK-NOT: {{.+}} = {{.+}} dynamic-slice + )")); EXPECT_TRUE(RunAndCompareTwoModules(std::move(hlo_ref), std::move(hlo_opt), error_spec, /*run_hlo_passes=*/false)); @@ -2639,12 +2686,12 @@ TEST_F(DynamicSliceFusionTest, CustomCallDUS) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -2735,12 +2782,12 @@ TEST_F(DynamicSliceFusionTest, CustomCallDUSTuple) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_ref, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); - debug_options.set_xla_gpu_enable_address_computation_fusion(true); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(true); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo_opt, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -2754,6 +2801,201 @@ TEST_F(DynamicSliceFusionTest, CustomCallDUSTuple) { /*run_hlo_passes=*/false)); } +TEST_F(DynamicSliceFusionTest, ReduceScatterDUSConstant) { + // DUS offset is a constant + const char* hlo_ref = R"( + HloModule test, replica_count=2 + + add.clone { + x.1 = f16[] parameter(0) + y.1 = f16[] parameter(1) + ROOT add.462 = f16[] add(x.1, y.1) + } + + ENTRY %main.9 { + param_0 = f16[128,128]{1,0} parameter(0) + param_1 = f16[128,128]{1,0} parameter(1) + constant_20 = u32[] constant(20) + constant_0 = u32[] constant(0) + reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone + ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, constant_20, constant_0) + })"; + + const char* hlo_opt = R"( + HloModule test, replica_count=2 + + %add { + %param_0 = f16[] parameter(0) + %param_1 = f16[] parameter(1) + ROOT %add.1 = f16[] add(%param_0, %param_1) + } + + %address-computation { + %p1 = f16[128,128]{1,0} parameter(1) + %p0 = f16[128,128]{1,0} parameter(0) + %reduce-scatter.1 = f16[64,128]{1,0} reduce-scatter(%p0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add + %p2 = u32[] parameter(2) + %p3 = u32[] parameter(3) + ROOT %loop_dynamic_update_slice_fusion.1 = f16[128,128]{1,0} dynamic-update-slice(%p1, %reduce-scatter.1, %p2, %p3) + } + + ENTRY %main.9 { + %param_0.1 = f16[128,128]{1,0} parameter(0) + %param_1.1 = f16[128,128]{1,0} parameter(1) + %constant_20 = u32[] constant(20) + %constant_0 = u32[] constant(0) + ROOT %address_computation = f16[128,128]{1,0} fusion(%param_0.1, %param_1.1, %constant_20, %constant_0), kind=kCustom, calls=%address-computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}},"force_earliest_schedule":false} + })"; + + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + EXPECT_TRUE(RunAndCompareTwoModulesReplicated(hlo_ref, hlo_opt, true, true, + error_spec)); +} + +TEST_F(DynamicSliceFusionTest, ReduceScatterDUSParameterOffset) { + // DUS offset is a parameter. This enforces a d2h copy. + const char* hlo_ref = R"( + HloModule test, replica_count=2 + + add.clone { + x.1 = f16[] parameter(0) + y.1 = f16[] parameter(1) + ROOT add.462 = f16[] add(x.1, y.1) + } + + ENTRY %main.9 { + param_0 = f16[128,128]{1,0} parameter(0) + param_1 = f16[128,128]{1,0} parameter(1) + param_2 = u32[] parameter(2) + constant_0 = u32[] constant(0) + reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone + ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, param_2, constant_0) + })"; + + const char* hlo_opt = R"( + HloModule test, replica_count=2 + + %add { + %param_0 = f16[] parameter(0) + %param_1 = f16[] parameter(1) + ROOT %add.1 = f16[] add(f16[] %param_0, f16[] %param_1) + } + + %address-computation { + %p1 = f16[128,128]{1,0} parameter(1) + %p0 = f16[128,128]{1,0} parameter(0) + %reduce-scatter.1 = f16[64,128]{1,0} reduce-scatter(%p0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add + %p2 = u32[] parameter(2) + %p3 = u32[] parameter(3) + ROOT %loop_dynamic_update_slice_fusion.1 = f16[128,128]{1,0} dynamic-update-slice(%p1, %reduce-scatter.1, %p2, %p3) + } + + ENTRY %main.9 { + %param_0 = f16[128,128]{1,0} parameter(0) + %param_1 = f16[128,128]{1,0} parameter(1) + %param_2 = u32[] parameter(2) + %constant_0 = u32[] constant(0) + ROOT %address_computation = f16[128,128]{1,0} fusion(%param_0, %param_1, %param_2, %constant_0), kind=kCustom, calls=%address-computation, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__custom_fusion","custom_fusion_config":{"name":"dynamic_address_computation"}},"force_earliest_schedule":false} + })"; + + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + EXPECT_TRUE(RunAndCompareTwoModulesReplicated(hlo_ref, hlo_opt, true, true, + error_spec)); +} + +TEST_F(DynamicSliceFusionTest, ReduceScatterDUSLoopIterationOffset) { + const char* hlo_ref = R"( + HloModule jit_scan, replica_count=2 + + %add { + %param_0 = f32[] parameter(0) + %param_1 = f32[] parameter(1) + ROOT %add.1 = f32[] add(%param_0, %param_1) + } + + %region_0.14 { + %arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + %get-tuple-element.16 = s32[] get-tuple-element(%arg_tuple.15), index=0 + %constant.21 = s32[] constant(1) + %add.37 = s32[] add(%get-tuple-element.16, %constant.21) + %get-tuple-element.20 = f32[128,128]{1,0} get-tuple-element(%arg_tuple.15), index=4 + %get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(%arg_tuple.15), index=2 + %reduce-scatter.1 = f32[64,128]{1,0} reduce-scatter(%get-tuple-element.20), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=%add + %reshape.32 = f32[1,64,128]{2,1,0} reshape(%reduce-scatter.1) + %constant.23 = s32[] constant(0) + %compare.33 = pred[] compare(%get-tuple-element.16, %constant.23), direction=LT + %constant.22 = s32[] constant(128) + %add.34 = s32[] add(%get-tuple-element.16, %constant.22) + %select.35 = s32[] select(%compare.33, %add.34, %get-tuple-element.16) + %dynamic-update-slice.36 = f32[128,128,128]{2,1,0} dynamic-update-slice(%get-tuple-element.18, %reshape.32, %select.35, %constant.23, %constant.23) + %get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(%arg_tuple.15), index=3 + ROOT %tuple.38 = tuple(%add.37, %get-tuple-element.20, %dynamic-update-slice.36, %get-tuple-element.19, %get-tuple-element.20) + } + + %region_1.39 { + %arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + %get-tuple-element.41 = s32[] get-tuple-element(%arg_tuple.40), index=0 + %constant.46 = s32[] constant(128) + ROOT %compare.47 = pred[] compare(%get-tuple-element.41, %constant.46), direction=LT + } + + ENTRY %main.55 { + %constant.4 = s32[] constant(0) + %Arg_1.2 = f32[128,128]{1,0} parameter(1) + %constant.5 = f32[] constant(0) + %broadcast.6 = f32[128,128,128]{2,1,0} broadcast(%constant.5), dimensions={} + %Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2) + %Arg_0.1 = f32[128,128]{1,0} parameter(0) + %tuple.7 = tuple(%constant.4, %Arg_1.2, %broadcast.6, %Arg_2.3, %Arg_0.1) + %while.48 = while(%tuple.7), condition=%region_1.39, body=%region_0.14 + %get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(%while.48), index=1 + %get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(%while.48), index=2 + ROOT %tuple.54 = tuple(%get-tuple-element.50, %get-tuple-element.51) + })"; + DebugOptions debugoptions = GetDebugOptionsForTest(); + + HloModuleConfig ref_config; + debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(false); + debugoptions.set_xla_gpu_enable_pipelined_reduce_scatter(false); + ref_config.set_debug_options(debugoptions); + TF_ASSERT_OK_AND_ASSIGN(auto ref_module, + ParseAndReturnVerifiedModule(hlo_ref, ref_config)); + TF_ASSERT_OK_AND_ASSIGN(auto ref_module_opt, + GetOptimizedModule(std::move(ref_module))); + + HloModuleConfig opt_config; + debugoptions.set_xla_gpu_enable_dynamic_slice_fusion(true); + opt_config.set_debug_options(debugoptions); + debugoptions.set_xla_gpu_enable_pipelined_reduce_scatter(false); + TF_ASSERT_OK_AND_ASSIGN(auto module_with_adddress_computation_flag, + ParseAndReturnVerifiedModule(hlo_ref, opt_config)); + TF_ASSERT_OK_AND_ASSIGN( + auto module_with_adddress_computation, + GetOptimizedModule(std::move(module_with_adddress_computation_flag))); + + std::vector address_computations_opt = + GetAddressComputations(*module_with_adddress_computation); + std::vector address_computations_ref = + GetAddressComputations(*ref_module_opt); + EXPECT_EQ(address_computations_ref.size(), 0); + ASSERT_EQ(address_computations_opt.size(), 1); + + // Check that reduce scatter happens in the fusion in optimized module and not + // outside the fusion. + EXPECT_TRUE(*RunFileCheck(address_computations_opt[0]->ToString(), R"( + // CHECK: {{.+}} = {{.*}}reduce-scatter({{.+}}) + // CHECK: {{.+}} = {{.*}}dynamic-update-slice({{.+}}) + )")); + EXPECT_TRUE(*RunFileCheck( + address_computations_opt[0]->FusionInstruction()->parent()->ToString(), + "// CHECK-NOT: {{.+}} = {{.*}}reduce-scatter")); + + ErrorSpec error{/*aabs=*/1e-3, /*arel=*/1e-3}; + EXPECT_TRUE(RunAndCompareTwoModulesReplicated( + std::move(ref_module_opt), std::move(module_with_adddress_computation), + false, true, error)); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 00835aaaf7fd84..200f06f8461db5 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -27,24 +27,24 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/fusions/concatenate.h" #include "xla/service/gpu/fusions/concatenate_mlir.h" #include "xla/service/gpu/fusions/copy.h" #include "xla/service/gpu/fusions/cudnn.h" #include "xla/service/gpu/fusions/custom.h" #include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h" #include "xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h" -#include "xla/service/gpu/fusions/input_slices.h" #include "xla/service/gpu/fusions/input_slices_mlir.h" -#include "xla/service/gpu/fusions/loop.h" +#include "xla/service/gpu/fusions/legacy/concatenate.h" +#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h" +#include "xla/service/gpu/fusions/legacy/input_slices.h" +#include "xla/service/gpu/fusions/legacy/loop.h" +#include "xla/service/gpu/fusions/legacy/reduction.h" +#include "xla/service/gpu/fusions/legacy/scatter.h" +#include "xla/service/gpu/fusions/legacy/transpose.h" #include "xla/service/gpu/fusions/loop_mlir.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/reduction.h" #include "xla/service/gpu/fusions/reduction_mlir.h" -#include "xla/service/gpu/fusions/scatter.h" #include "xla/service/gpu/fusions/scatter_mlir.h" -#include "xla/service/gpu/fusions/transpose.h" #include "xla/service/gpu/fusions/transpose_mlir.h" #include "xla/service/gpu/fusions/triton.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -95,11 +95,11 @@ std::optional> HloFusionInfo::GetCopyFusion() bool HloFusionInfo::CanEmitDynamicUpdateSliceInPlace() const { auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - instr_, + analysis().fusion(), [this](const HloInstruction* instruction, const ShapeIndex& index) { return GetAllocationSlice(*buffer_assignment_, instruction, index); }, - analysis().fusion_roots()); + instr_); return ret.ok() && *ret; } @@ -113,19 +113,8 @@ std::unique_ptr GetFusionEmitter( .GetModule() ->config() .debug_options(); - auto check_mlir_emitters = [&](int64_t required_level, bool check = true) { - if (opts.xla_gpu_mlir_emitter_level() < required_level) { - return false; - } - CHECK(!check || - mlir_converter::IsHloConversionSupported( - analysis.fusion(), - fusion_info.analysis().device_info().gpu_compute_capability())) - << "Unsupported fusion: " - << analysis.fusion_root(0).instruction().parent()->ToString(); - - VLOG(5) << "Emitting with MLIR."; - return true; + auto check_mlir_emitters = [&](int64_t required_level) { + return opts.xla_gpu_mlir_emitter_level() >= required_level; }; switch (analysis.GetEmitterFusionKind()) { @@ -166,7 +155,7 @@ std::unique_ptr GetFusionEmitter( } return std::make_unique(analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: { - if (check_mlir_emitters(/*required_level=*/2, false)) { + if (check_mlir_emitters(/*required_level=*/2)) { return std::make_unique(analysis); } return std::make_unique(analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.h b/third_party/xla/xla/service/gpu/fusions/fusions.h index 9011c80d7f9f43..f7406b463b9117 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.h +++ b/third_party/xla/xla/service/gpu/fusions/fusions.h @@ -73,8 +73,9 @@ class PreBufferAssignmentFusionInfo : public FusionInfo { : FusionInfo(analysis) {} bool CanEmitDynamicUpdateSliceInPlace() const override { - // Optimistically assume all DUS fusions are in-place. - return true; + auto ret = CanEmitFusedDynamicUpdateSliceInPlaceForGpu( + analysis().fusion(), /*get_allocation_slice=*/{}); + return ret.value_or(false); } std::optional> GetCopyFusion() diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc index cc9d10ec7decaa..885d745b9a7978 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.cc @@ -68,7 +68,7 @@ constexpr int kDUSUpdateIndex = 1; LaunchDimensions MlirInPlaceDynamicUpdateSliceFusion::launch_dimensions() const { const auto& update_shape = - dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); return CalculateLaunchDimensions(update_shape, analysis_.device_info()); } @@ -83,7 +83,7 @@ MlirInPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( auto launch_dims = launch_dimensions(); // It is guaranteed that all DUS ops have the same output shape at this point. const auto& update_shape = - dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, update_shape, indexing_context); } @@ -98,7 +98,7 @@ MlirInPlaceDynamicUpdateSliceFusion::GetEpilogues( llvm::zip(dus_ops_, analysis_.fusion_roots())) { epilogues.push_back( mlir_converter::EpilogueSpecification::FromIdentityIndexing( - dus_op, &root.instruction(), mlir_context)); + &dus_op.instruction(), &root.instruction(), mlir_context)); } return epilogues; } @@ -133,7 +133,8 @@ absl::Status MlirInPlaceDynamicUpdateSliceFusion::EmitEntryFunction( llvm::SmallVector results; for (auto [instr, root, output] : llvm::zip(dus_ops_, analysis_.fusion_roots(), output_tensors)) { - const auto* dus_instr = Cast(instr); + const auto* dus_instr = + Cast(&instr.instruction()); const auto& update_shape = dus_instr->update()->shape(); SmallVector update_indices; auto start_indices = ProvideParameterRange( diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h index e1a5bc5310e88a..2ed84a06522b16 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h @@ -76,7 +76,7 @@ class MlirInPlaceDynamicUpdateSliceFusion : public MlirFusionEmitterBase { private: const HloFusionAnalysis& analysis_; - std::vector dus_ops_; + std::vector dus_ops_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc index b68a95e9516bfd..f18173ba0f54dc 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -56,7 +56,7 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirInPlaceDynamicUpdateSliceFusion fusion(analysis); auto thread_id_update_indexing = fusion.ComputeThreadIdToInputIndexing( @@ -100,8 +100,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, SimpleDUS) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0) -> (d0 floordiv 6)> - // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0) -> (d0 mod 6)> + // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 6), domain: d0 in [0, 29] + // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 6), domain: d0 in [0, 29] // CHECK: func.func @fused_computation // CHECK-SAME: %arg0: tensor<20x30xf32> // CHECK-SAME: %arg1: tensor<5x6xf32> @@ -112,8 +112,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, SimpleDUS) { // CHECK-DAG: %[[C_15:.*]] = arith.constant 15 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x - // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 29]) - // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 29]) + // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]]) + // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]]) // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] @@ -151,8 +151,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OutOfBoundDUS) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP_1:.*]] = affine_map<(d0) -> (d0 floordiv 3)> - // CHECK-DAG: #[[MAP_2:.*]] = affine_map<(d0) -> (d0 mod 3)> + // CHECK-DAG: #[[MAP_1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 3), domain: d0 in [0, 5] + // CHECK-DAG: #[[MAP_2:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 3), domain: d0 in [0, 5] // CHECK: func.func @fused_computation // CHECK-SAME: %arg0: tensor<7x8xf32> // CHECK-SAME: %arg1: tensor<2x3xf32> @@ -162,8 +162,8 @@ TEST_F(MlirInPlaceDynamicUpdateSliceFusionTest, OutOfBoundDUS) { // CHECK-DAG: %[[C_5:.*]] = arith.constant 5 // CHECK-DAG: %[[C_0:.*]] = arith.constant 0 // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x - // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]] in [0, 5]) - // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]] in [0, 5]) + // CHECK: %[[INPUT_INDEX_0:.*]] = xla_gpu.apply_indexing #[[MAP_1]](%[[THREAD_ID]]) + // CHECK: %[[INPUT_INDEX_1:.*]] = xla_gpu.apply_indexing #[[MAP_2]](%[[THREAD_ID]]) // CHECK: %[[I0:.*]] = xla_gpu.pure_call @fused_computation_i0 // CHECK: %[[I1:.*]] = xla_gpu.pure_call @fused_computation_i1 // CHECK: %[[IDX0:.*]] = arith.index_cast %[[I0]] diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc index 5297bf7526de9d..d2739e0a8c3765 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir.cc @@ -38,9 +38,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_analysis.h" diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc index abeb57accdabdf..4aec42e1c25aac 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/input_slices_mlir_test.cc @@ -45,7 +45,7 @@ TEST_F(MlirInputSlicesFusionTest, ThreadIndexing) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); auto emitter = GetEmitter(analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/BUILD similarity index 66% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD rename to third_party/xla/xla/service/gpu/fusions/ir/BUILD index d618413f13e817..8250669ca75d2b 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/ir/BUILD @@ -1,4 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//xla/tests:build_defs.bzl", "xla_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -18,7 +19,9 @@ td_library( srcs = glob(["*.td"]), includes = ["."], deps = [ + "@llvm-project//mlir:BuiltinDialectTdFiles", "@llvm-project//mlir:CallInterfacesTdFiles", + "@llvm-project//mlir:ControlFlowInterfacesTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", "@llvm-project//mlir:OpBaseTdFiles", "@llvm-project//mlir:SideEffectInterfacesTdFiles", @@ -66,11 +69,15 @@ gentbl_cc_library( strip_include_prefix = ".", tbl_outs = [ ( - ["-gen-attrdef-decls"], + [ + "-gen-attrdef-decls", + ], "xla_gpu_attrs.h.inc", ), ( - ["-gen-attrdef-defs"], + [ + "-gen-attrdef-defs", + ], "xla_gpu_attrs.cc.inc", ), ], @@ -79,21 +86,46 @@ gentbl_cc_library( deps = [":xla_gpu_td_files"], ) +gentbl_cc_library( + name = "xla_gpu_types_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + [ + "-gen-typedef-decls", + "-typedefs-dialect=xla_gpu", + ], + "xla_gpu_types.h.inc", + ), + ( + [ + "-gen-typedef-defs", + "-typedefs-dialect=xla_gpu", + ], + "xla_gpu_types.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "xla_gpu_types.td", + deps = [":xla_gpu_td_files"], +) + cc_library( name = "xla_gpu", srcs = [ "xla_gpu_attrs.cc", "xla_gpu_dialect.cc", "xla_gpu_ops.cc", + "xla_gpu_types.cc", ], hdrs = [ - "xla_gpu_attrs.h", "xla_gpu_ops.h", ], deps = [ ":xla_gpu_attrs_inc_gen", ":xla_gpu_dialect_inc_gen", ":xla_gpu_ops_inc_gen", + ":xla_gpu_types_inc_gen", "//xla/service/gpu/model:indexing_analysis", "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", @@ -108,3 +140,18 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +xla_test( + name = "xla_gpu_ops_test", + srcs = ["xla_gpu_ops_test.cc"], + backends = ["gpu"], + deps = [ + ":xla_gpu", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD new file mode 100644 index 00000000000000..381d5a3220b1df --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/BUILD @@ -0,0 +1,16 @@ +load("//xla:lit.bzl", "lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "tests", + srcs = glob(["*.mlir"]), + cfg = "//xla:lit.cfg.py", + tools = [ + "//xla/service/gpu/fusions/tools:mlir_fusions_opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir new file mode 100644 index 00000000000000..946e73494584be --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/canonicalize.mlir @@ -0,0 +1,224 @@ +// RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s + +#map0 = #xla_gpu.indexing_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2), + domain: s0 in [-10, 10], s1 in [0, 2]> +func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { + %0:2 = xla_gpu.apply_indexing #map0 [%s0, %s1] + func.return %0#0, %0#1 : index, index +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<()[s0] -> (s0 + 1, s0 mod 2), +// CHECK-SAME: domain: s0 in [-10, 10]> + +// CHECK-LABEL: func.func @simplify_apply_indexing +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) +// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]]] + +// ----- + +#map0 = #xla_gpu.indexing_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2), + domain: d0 in [0, 1], d1 in [0, 2], d2 in [0, 3], s0 in [-11, 11], s1 in [0, 3]> +func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, + %d2: index, %s0: index, %s1: index) -> (index, index, index) { + %0:3 = xla_gpu.apply_indexing #map0(%d0, %d1, %d2)[%s0, %s1] + func.return %0#0, %0#1, %0#2 : index, index, index +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1), +// CHECK-SAME: domain: d0 in [0, 1], d1 in [0, 3], s0 in [-11, 11]> + +// CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims +// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: index, +// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: index, +// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: index, +// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: index, +// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: index) +// CHECK: xla_gpu.apply_indexing #[[$MAP]] +// CHECK-SAME: (%[[ARG_0]], %[[ARG_2]]) +// CHECK-SAME: [%[[ARG_3]]] + +// ----- + +#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0), + domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]> +func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) + -> (index, index, index, index, index) { + %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] + func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + +// CHECK-LABEL: func.func @fold_indexing_map_results +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) + +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + +// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]] +// CHECK: return %[[NEW_RESULT]], %[[C4]], %[[ARG_1]], %[[C1]], %[[ARG_2]] + +// ----- + +#map0 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0), + domain: d0 in [-10, 10], d1 in [0, 2], s0 in [-1, 1]> +func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { + %0:5 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] + func.return %0#2 : index +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2), +// CHECK-SAME: domain: d0 in [0, 2]> + +// CHECK-LABEL: func.func @remove_unused_results +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) + +// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]]) +// CHECK: return %[[NEW_RESULT]] + +// ----- + +#map0 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3), + domain: d0 in [0, 10], d1 in [0, 5], s0 in [-10, 10], s1 in [0, 4]> +func.func @fold_operands(%d0: index) -> index { + %d1 = arith.constant 1 : index + %s0 = arith.constant 2 : index + %s1 = arith.constant 3 : index + %0 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0, %s1] + func.return %0 : index +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 3), +// CHECK-SAME: domain: d0 in [0, 10]> + +// CHECK-LABEL: func.func @fold_operands +// CHECK-SAME: %[[ARG_0:.*]]: index) +// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]]) + +// ----- + +func.func @fold_operands_and_results(%arg0: index, %arg1: index) + -> (index, index) { + %0:2 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (0, d1), + domain: d0 in [0, 4], d1 in [0, 5]>(%arg0, %arg1) + return %0#0, %0#1 : index, index +} + +// CHECK-LABEL: func.func @fold_operands_and_results +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 +// CHECK-NEXT: return %[[C0]], %[[ARG_1]] : index, index + +// ----- + +func.func @fold_sequence(%arg0: index, %arg1: index) -> index { + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), + domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 100 + 42), + domain: d0 in [0, 10000]>(%0) + func.return %1 : index +} + +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), +// CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4]> +// CHECK-LABEL: func.func @fold_sequence +// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) +// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] +// CHECK-SAME: (%[[ARG0]], %[[ARG1]]) + +// ----- + +func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), + domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<()[s0] -> (s0 mod 100 + 42), + domain: s0 in [0, 10000]>(%0) + func.return %1 : index +} + +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1 + 42), +// CHECK-SAME: domain: d0 in [0, 5], d1 in [0, 4]> +// CHECK-LABEL: func.func @fold_sequence_sym +// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) +// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] +// CHECK-SAME: (%[[ARG0]], %[[ARG1]]) + +// ----- + +#indexing_map1 = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0 + 8512), + domain: d0 in [0, 1], d1 in [0, 607]> +#indexing_map2 = #xla_gpu.indexing_map< + (d0, d1, d2) -> (((d1 floordiv 32 + 1) mod 3) * 64 + + (d1 mod 32) * 2 + (d0 floordiv 192) * 192 + d2), + domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]> + +func.func @fold_sequence_no_simplification_needed(%i: index) -> index { + %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} + %ind0 = xla_gpu.apply_indexing #indexing_map1(%i, %thread_id_x) + %ind1 = xla_gpu.apply_indexing #indexing_map2(%ind0, %thread_id_x, %i) + func.return %ind1 : index +} +// CHECK: xla_gpu.apply_indexing +// CHECK-NOT: xla_gpu.apply_indexing + +// ----- + +#indexing_map1 = #xla_gpu.indexing_map<(d0) -> (3 * d0), + domain: d0 in [0, 9407]> +#indexing_map2 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 floordiv 32 + 1), + domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]> +#indexing_map3 = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0 floordiv 32 + 2), + domain: d0 in [0, 9407], d1 in [0, 607], d2 in [0, 1]> + +func.func @no_fold_when_producer_has_two_users(%i: index) -> (index, index) { + %thread_id_x = gpu.thread_id x {xla.range = [0 : index, 607 : index]} + %ind0 = xla_gpu.apply_indexing #indexing_map1(%thread_id_x) + %ind1 = xla_gpu.apply_indexing #indexing_map2(%ind0, %thread_id_x, %i) + %ind2 = xla_gpu.apply_indexing #indexing_map3(%ind0, %thread_id_x, %i) + func.return %ind1, %ind2 : index, index +} +// CHECK-COUNT-3: xla_gpu.apply_indexing + +// ----- + +func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), + domain: d0 in [0, 5], d1 in [0, 4]>(%arg0, %arg1) + %1 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0, d1) -> (d0 + d1), + domain: d0 in [0, 4], d1 in [0, 10000]>(%arg1, %0) + func.return %1 : index +} + +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 2 + d1), +// CHECK-SAME: domain: d0 in [0, 4], d1 in [0, 5]> +// CHECK-LABEL: func.func @fold_sequence_shared_operands +// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) +// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] +// CHECK-SAME: (%[[ARG1]], %[[ARG0]]) + +// ----- + +func.func @atomic_rmw_empty(%in: tensor<2x3xf32>, %i: index, %j: index) + -> (tensor<2x3xf32>) { + %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { + ^bb0(%current : f32): + xla_gpu.yield %current : f32 + } + return %ret : tensor<2x3xf32> +} +// CHECK-LABEL: func.func @atomic_rmw_empty +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32> +// CHECK: return %[[ARG0]] + + +// ----- + +func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index) + -> (tensor<2x3xf32>) { + %cst = arith.constant 0.0 : f32 + %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { + ^bb0(%current : f32): + xla_gpu.yield %cst : f32 + } + return %ret : tensor<2x3xf32> +} +// CHECK-LABEL: func.func @atomic_rmw_cst +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32> +// CHECK-NEXT: %[[CST:.*]] = arith.constant +// CHECK-NEXT: atomic_rmw +// CHECK: xla_gpu.yield %[[CST]] diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir new file mode 100644 index 00000000000000..cd2e09f0adab72 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/indexing_map_attr.mlir @@ -0,0 +1,136 @@ +// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s + +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0, d1, d2)[s0] -> (d0), +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2], +// CHECK-SAME: d1 in [5, 8], +// CHECK-SAME: d2 in [10, 12], +// CHECK-SAME: s0 in [0, 32], +// CHECK-SAME: d0 mod 2 in [0, 1], +// CHECK-SAME: d0 + s0 in [1, 10] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0), + domain: + d0 in [1, 2], + d1 in [5, 8], + d2 in [10, 12], + s0 in [0, 32], + d0 mod 2 in [0, 1], + d0 + s0 in [1, 10] + > + +func.func private @indexing_map_attr(!xla_gpu.indexed_vector<64x64x32xf64, #map>) +// CHECK-LABEL: @indexing_map_attr +// CHECK: !xla_gpu.indexed_vector<64x64x32xf64, #[[$INDEX_MAP]]> + +// ----- + +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2] +// CHECK-SAME: d1 in [5, 8] +// CHECK-SAME: s0 in [0, 10] +// CHECK-SAME: s1 in [0, 5] +// CHECK-SAME: s2 in [0, 32] +// CHECK-SAME: d0 mod 2 in [0, 1] +// CHECK-SAME: d0 + s0 in [1, 10] +// CHECK-SAME: d1 + s1 + s2 in [1, 32] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2), + domain: + d0 in [1, 2], + d1 in [5, 8], + s0 in [0, 10], + s1 in [0, 5], + s2 in [0, 32], + d0 mod 2 in [0, 1], + d0 + s0 in [1, 10], + d1 + s1 + s2 in [1, 32] + > +func.func private @more_range_vars(!xla_gpu.indexed_vector<100x32xf64, #map>) +// CHECK-LABEL: @more_range_vars +// CHECK: !xla_gpu.indexed_vector<100x32xf64, #[[$INDEX_MAP]]> + +// ----- + +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0)[s0] -> (d0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [0, 100] +// CHECK-SAME: s0 in [-3, -1] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0), + domain: + d0 in [0, 100], + s0 in [-3, -1] + > +func.func private @indexing_map_small(!xla_gpu.indexed_vector<100xf64, #map>) +// CHECK-LABEL: @indexing_map_small +// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]> + +// ----- + +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0, d1, d2)[s0] -> (d0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2] +// CHECK-SAME: d1 in [5, 8] +// CHECK-SAME: d2 in [10, 12] +// CHECK-SAME: s0 in [0, 32] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0), + domain: + d0 in [1, 2], + d1 in [5, 8], + d2 in [10, 12], + s0 in [0, 32] + > +func.func private @no_constraints(!xla_gpu.indexed_vector<32xf64, #map>) +// CHECK-LABEL: @no_constraints +// CHECK: !xla_gpu.indexed_vector<32xf64, #[[$INDEX_MAP]]> + +// ----- + +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: ()[s0] -> (s0) +// CHECK-SAME: domain: +// CHECK-SAME: s0 in [3, 5] +// CHECK-SAME: s0 mod 2 in [0, 1] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<()[s0] -> (s0), + domain: + s0 in [3, 5], + s0 mod 2 in [0, 1] + > +func.func private @no_dimensions(!xla_gpu.indexed_vector<100xf64, #map>) +// CHECK-LABEL: @no_dimensions +// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]> + +// ----- + +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0) -> (d0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [3, 5] +// CHECK-SAME: d0 mod 2 in [0, 1] +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<(d0) -> (d0), + domain: + d0 in [3, 5], + d0 mod 2 in [0, 1] + > +func.func private @no_symbols(!xla_gpu.indexed_vector<100xf64, #map>) +// CHECK-LABEL: @no_symbols +// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]> + +// ----- + +// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: () -> () +// CHECK-SAME: > +#map = #xla_gpu.indexing_map<() -> ()> +func.func private @empty(!xla_gpu.indexed_vector<100xf64, #map>) +// CHECK-LABEL: @empty +// CHECK: !xla_gpu.indexed_vector<100xf64, #[[$INDEX_MAP]]> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/inlining.mlir rename to third_party/xla/xla/service/gpu/fusions/ir/tests/inlining.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir new file mode 100644 index 00000000000000..999a6de959328c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/invalid.mlir @@ -0,0 +1,236 @@ +// RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics + +#map0 = #xla_gpu.indexing_map< + (d0, d1)[s0] -> (d0, d1 + s0), + domain: + d0 in [1, 2], + d1 in [5, 8], + s0 in [0, 32] +> +func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { + // expected-error @+1 {{operand count must match the number of dimensions and symbols in the affine map}} + %0:2 = xla_gpu.apply_indexing #map0 (%d0) + func.return %0#0, %0#1 : index, index +} + +// ----- + +#map0 = #xla_gpu.indexing_map< + (d0, d1)[s0] -> (d0, d1 + s0), + domain: + d0 in [1, 2], + d1 in [5, 8], + s0 in [0, 32], + d0 mod 2 in [0, 1], + d0 + s0 in [1, 10] +> +func.func @cannot_have_constraints(%d0: index, %d1: index, %s0: index) -> (index, index) { + // expected-error @+1 {{apply indexing op cannot have any constraints}} + %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] + func.return %0#0, %0#1 : index, index +} + +// ----- + +#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]> +func.func @loop_result_num_mismatch(%input: tensor<1024x32xf32>, + %init: f32) -> (f32) { + // expected-error @+1 {{mismatch in number of loop-carried values and results}} + %sum:2 = "xla_gpu.loop"(%init) <{ + indexing_map_attr = #map, + operandSegmentSizes = array + }> ({ + ^bb0(%i: index, %j: index, %iter: f32): + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %iter, %t : f32 + xla_gpu.yield %add : f32 + }) : (f32) -> (f32, f32) + func.return %sum#0 : f32 +} + +// ----- + +#map = #xla_gpu.indexing_map<()[s0] -> (s0, s0), domain: s0 in [0, 1024]> +func.func @loop_iv_num_mismatch(%input: tensor<1024x32xf32>, + %init: f32) -> (f32) { + // expected-error @+1 {{mismatch in number of induction variables 2 and RangeVars}} + %sum = "xla_gpu.loop"(%init) <{ + indexing_map_attr = #map, + operandSegmentSizes = array + }> ({ + ^bb0(%i: index, %j: index, %iter: f32): + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %iter, %t : f32 + xla_gpu.yield %add : f32 + }) : (f32) -> (f32) + func.return %sum : f32 +} + +// ----- + +#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]> +func.func @loop_types_mismatch(%input: tensor<1024x32xf32>, %init: f32) -> (i32) { + // expected-error @+1 {{block iter arg type = 'f32', result type = 'i32' and init operand type = 'f32' should match}} + %sum = "xla_gpu.loop"(%init) <{ + indexing_map_attr = #map, + operandSegmentSizes = array + }> ({ + ^bb0(%i: index, %j: index, %iter: f32): + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %iter, %t : f32 + xla_gpu.yield %add : f32 + }) : (f32) -> (i32) + func.return %sum : i32 +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]> +func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { + // expected-error @+1 {{mismatch in number of dims operands 0 and DimVars in the indexing map}} + %sum = xla_gpu.loop ()[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) { + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %sum_, %t : f32 + xla_gpu.yield %add : f32 + } {xla.range = [0 : index, 42 : index]} + func.return %sum : f32 +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]> +func.func @indicies_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> { + // expected-error @+1 {{number of indices must match number of dimensions of indexing map}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map> +} + +// ----- + +#map = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]> +func.func @no_thread_id_in(%input: tensor<32x64xf32>, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{must have thread_id dimension in both indexing maps}} + %0 = xla_gpu.materialize @exp(%input) at #map() : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]> +#map1 = #xla_gpu.indexing_map<()[s0, s1] -> (s0, s1), domain: s0 in [0, 1024], s1 in [0, 32]> +func.func @no_thread_id_out(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{must have thread_id dimension in both indexing maps}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 64], s0 in [0, 1024], s1 in [0, 32]> +func.func @thread_id_bounds_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{thread_id dimension must have the same bounds in both indexing maps}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32], d0 + s0 in [0, 1024]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]> +func.func @thread_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for the thread_id dimension}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0, s0), domain: d0 in [0, 32], s0 in [0, 1024]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]> +func.func @symbol_count_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{number of symbols in both indexing_maps must match}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 32]> +func.func @symbol_domain_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{domain of symbols of indexing_maps must match}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]> +func.func @symbol_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 mod 2 in [0, 0]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]> +func.func @symbol_constraint_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 1024]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64], s0 + s1 in [0, 32]> +func.func @symbol_constraint_interval_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for all symbols}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64]> +func.func @vector_mapping_depends_on_block_id(%input: tensor<32x64xf32>, %thread_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{vector mapping indices must not depend on the block ID}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]> +#map1 = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]> +func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id, %block_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], s0 in [0, 1024], s1 in [0, 64]> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]> +func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} + +// ----- + +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 2 in [0, 0]> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, s1), domain: d0 in [0, 32], d1 in [0, 64], s0 in [0, 1024], s1 in [0, 64], d1 mod 4 in [0, 0]> +func.func @block_id_constraints_mismatch(%input: tensor<32x64xf32>, %thread_id: index, %block_id: index, %output: tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> { + // expected-error @+1 {{constraints of indexing maps must be equal for the block_id dimension}} + %0 = xla_gpu.materialize @exp(%input) at #map(%thread_id, %block_id) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + func.return %0 : !xla_gpu.indexed_vector<32x64xf32, #map1> +} \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir new file mode 100644 index 00000000000000..c4378a35fd4284 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir @@ -0,0 +1,165 @@ +// RUN: mlir_fusions_opt %s --split-input-file | FileCheck %s +// Verify the printed output can be parsed. +// RUN: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s +// Verify the generic form can be parsed. +// RUN: mlir_fusions_opt %s --split-input-file --mlir-print-op-generic | mlir_fusions_opt --split-input-file | FileCheck %s + +func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) { + %shared1 = xla_gpu.allocate_shared : tensor<2xf32> + %shared2 = xla_gpu.allocate_shared : tensor<2xf32> + %sync:2 = xla_gpu.sync_threads %shared1, %shared2 + : tensor<2xf32>, tensor<2xf32> + return %sync#0, %sync#1 : tensor<2xf32>, tensor<2xf32> +} +// CHECK-LABEL: @shared_and_sync +// CHECK-NEXT: allocate_shared +// CHECK-NEXT: allocate_shared +// CHECK-NEXT: sync_threads +// CHECK-NEXT: return + +// ----- + +func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index) + -> (tensor<2x3xf32>) { + %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { + ^bb0(%current : f32): + %c42 = arith.constant 42.0 : f32 + %add = arith.addf %current, %c42 : f32 + xla_gpu.yield %add : f32 + } + return %ret : tensor<2x3xf32> +} +// CHECK-LABEL: @atomic_rmw +// CHECK: xla_gpu.atomic_rmw + +// ----- + +func.func private @add(%a: f32, %b: f32) -> f32 { + %ret = arith.addf %a, %b : f32 + return %ret : f32 +} + +func.func @caller(%a: f32, %b: f32) -> f32 { + %c = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %d = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = arith.addf %c, %d : f32 + return %ret : f32 +} +// CHECK-LABEL: @caller +// CHECK: %[[C:.*]] = xla_gpu.pure_call @add +// CHECK: %[[D:.*]] = xla_gpu.pure_call @add +// CHECK: arith.addf %[[C]], %[[D]] + +// CHECK-CSE: @caller +// CHECK-CSE: %[[C:.*]] = xla_gpu.pure_call @add +// CHECK-CSE: arith.addf %[[C]], %[[C]] + +// ----- + +#map0 = #xla_gpu.indexing_map< +(d0, d1)[s0] -> (d0, d1 + s0), + domain: + d0 in [1, 2], + d1 in [5, 8], + s0 in [0, 32] +> +func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { + %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1)[%s0] + func.return %0#0, %0#1 : index, index +} +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0, d1)[s0] -> (d0, d1 + s0) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [1, 2] +// CHECK-SAME: d1 in [5, 8] +// CHECK-SAME: s0 in [0, 32] +// CHECK-SAME: > + +// CHECK-LABEL: @apply_indexing +// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index) +// CHECK: xla_gpu.apply_indexing #[[$MAP0]] +// CHECK-SAME: (%[[d0]], %[[d1]])[%[[s0]]] + +// ----- + +#map0 = #xla_gpu.indexing_map< +(d0, d1) -> (d0, d1), + domain: + d0 in [0, 2], + d1 in [1, 3] +> +func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { + %0:2 = xla_gpu.apply_indexing #map0 (%d0, %d1) + func.return %0#0, %0#1 : index, index +} +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: (d0, d1) -> (d0, d1) +// CHECK-SAME: domain: +// CHECK-SAME: d0 in [0, 2] +// CHECK-SAME: d1 in [1, 3] +// CHECK-SAME: > + +// CHECK-LABEL: @apply_indexing_no_symbols +// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index) +// CHECK: xla_gpu.apply_indexing #[[$MAP0]] +// CHECK-SAME: (%[[d0]], %[[d1]]) + +// ----- + +#map0 = #xla_gpu.indexing_map< + ()[s0] -> (s0, s0), + domain: + s0 in [2, 4] +> +func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { + %0:2 = xla_gpu.apply_indexing #map0 [%s0] + func.return %0#0, %0#1 : index, index +} +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map< +// CHECK-SAME: ()[s0] -> (s0, s0) +// CHECK-SAME: domain: +// CHECK-SAME: s0 in [2, 4] +// CHECK-SAME: > + +// CHECK-LABEL: @apply_indexing_no_dims +// CHECK: (%[[s0:.*]]: index) +// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]]] + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]> +func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { + %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) { + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %sum_, %t : f32 + xla_gpu.yield %add : f32 + } {xla.range = [0 : index, 42 : index]} + func.return %sum : f32 +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map +// CHECK: %0 = xla_gpu.loop (%{{.*}})[%[[I:.*]], %[[J:.*]]] in #[[$MAP]] +// CHECK-SAME: iter_args(%[[SUM_ITER:.*]] = %{{.*}}) -> (f32) { +// CHECK: %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[%[[I]], %[[J]]] +// CHECK: %[[ADD:.*]] = arith.addf %{{.*}}, %[[EXTRACTED]] : f32 +// CHECK: xla_gpu.yield %[[ADD]] : f32 +// CHECK: } {xla.range = [0 : index, 42 : index]} + +// ----- + +func.func private @exp(%p0: tensor<32x64xf32>, %i: index, %j: index) -> f32 + +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1), domain: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]> +func.func @materialize_and_insert(%input: tensor<32x64xf32>, %i: index, %j: index, %output: tensor<32x64xf32>) -> tensor<32x64xf32> { + %0 = xla_gpu.materialize @exp(%input) at #map(%i, %j) : (tensor<32x64xf32>) -> !xla_gpu.indexed_vector<32x64xf32, #map1> + %1 = xla_gpu.insert %0 into %output at #map1(%i, %j) : !xla_gpu.indexed_vector<32x64xf32, #map1> -> tensor<32x64xf32> into tensor<32x64xf32> + func.return %1 : tensor<32x64xf32> +} + +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (d0 + s0, d1 + s1) +// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]> +// CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (s0, s1) +// CHECK-SAME: d0 in [0, 32], d1 in [0, 2], s0 in [0, 1024], s1 in [0, 32]> +// CHECK-LABEL: @materialize_and_insert +// CHECK: %[[MATERIALIZED:.*]] = xla_gpu.materialize @exp(%{{.*}}) at #[[$MAP]](%{{.*}}, %{{.*}}) +// CHECK: xla_gpu.insert %[[MATERIALIZED]] into %{{.*}} at #[[$MAP1]](%{{.*}}, %{{.*}}) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc similarity index 62% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc index d3829056de5dc3..a3220b06ccf9d2 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.cc @@ -13,12 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" - #include #include #include "absl/strings/str_format.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/LogicalResult.h" #include "mlir/IR/AffineExpr.h" @@ -32,7 +31,7 @@ limitations under the License. #define GET_ATTRDEF_LIST #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" namespace xla { namespace gpu { @@ -43,8 +42,8 @@ using mlir::AffineExpr; using mlir::ArrayRef; using mlir::AsmParser; using mlir::AsmPrinter; -using mlir::failed; using mlir::failure; +using mlir::success; ParseResult ParseInterval(AsmParser& parser, Interval& interval) { // ParseResult converts to `true` if parsing failed. @@ -54,60 +53,73 @@ ParseResult ParseInterval(AsmParser& parser, Interval& interval) { } void PrintDimVars(AsmPrinter& p, ArrayRef dim_vars) { - for (int i = 0; i < dim_vars.size(); ++i) { - p << "d" << i << " in " << dim_vars[i].bounds << "\n"; - } + int index = 0; + llvm::interleaveComma(dim_vars, p, [&](const DimVar& dim_var) { + p << "d" << index++ << " in " << dim_var.bounds; + }); } -mlir::FailureOr> ParseDimVars( - AsmParser& parser, ArrayRef dim_names) { - SmallVector dim_vars; - for (const auto& dim_name : dim_names) { +ParseResult ParseDimVars(AsmParser& parser, ArrayRef dim_names, + SmallVector& dim_vars) { + dim_vars.reserve(dim_names.size()); + for (const auto& [index, dim_name] : llvm::enumerate(dim_names)) { if (parser.parseKeyword(dim_name) || parser.parseKeyword("in") || ParseInterval(parser, dim_vars.emplace_back().bounds)) { return failure(); } + if (index < dim_names.size() - 1 && parser.parseComma()) { + return failure(); + } } - return dim_vars; + return success(); } void PrintRangeVars(AsmPrinter& p, ArrayRef range_vars) { - for (int i = 0; i < range_vars.size(); ++i) { - p << "s" << i << " in " << range_vars[i].range << "\n"; - } + int index = 0; + llvm::interleaveComma(range_vars, p, [&](const RangeVar& range_var) { + p << "s" << index++ << " in " << range_var.range; + }); } -mlir::FailureOr> ParseRangeVars( - AsmParser& parser, ArrayRef range_symbol_names) { - SmallVector range_vars; - for (const auto& range_symbol_name : range_symbol_names) { +ParseResult ParseRangeVars(AsmParser& parser, + ArrayRef range_symbol_names, + SmallVector& range_vars) { + range_vars.reserve(range_symbol_names.size()); + for (const auto& [index, range_symbol_name] : + llvm::enumerate(range_symbol_names)) { if (parser.parseKeyword(range_symbol_name) || parser.parseKeyword("in") || ParseInterval(parser, range_vars.emplace_back().range)) { return failure(); } + if (index < range_symbol_names.size() - 1 && parser.parseComma()) { + return failure(); + } } - return range_vars; + return success(); } void PrintConstraints(AsmPrinter& p, ArrayRef> constraints) { - for (const auto& [constrained_expression, range] : constraints) { - p << constrained_expression << " in " << range << "\n"; - } + llvm::interleaveComma(constraints, p, [&](const auto& constraint) { + p << constraint.first << " in " << constraint.second; + }); } -mlir::FailureOr>> ParseConstraints( +ParseResult ParseConstraints( AsmParser& parser, - ArrayRef> symbolSet) { - SmallVector> constraints; - while (failed(parser.parseOptionalGreater())) { + ArrayRef> symbolSet, + SmallVector>& constraints) { + // In order for there to be any constraints, there must be at least 1 symbol + // or dimension meaning there will be commas for as long as there are + // constraints left. + while (succeeded(parser.parseOptionalComma())) { auto& constraint = constraints.emplace_back(); if (parser.parseAffineExpr(symbolSet, constraint.first) || parser.parseKeyword("in") || ParseInterval(parser, constraint.second)) { return failure(); } } - return constraints; + return success(); } mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { @@ -131,35 +143,55 @@ mlir::Attribute IndexingMapAttr::parse(mlir::AsmParser& parser, mlir::Type) { symbolSet.push_back( {symbol_strings[i], mlir::getAffineSymbolExpr(i, parser.getContext())}); } - - if (parser.parseKeyword("domain") || parser.parseColon()) { - return {}; + if (map.getNumDims() + map.getNumSymbols() > 0) { + if (parser.parseComma() || parser.parseKeyword("domain") || + parser.parseColon()) { + return {}; + } } - auto maybe_dim_vars = ParseDimVars(parser, dim_strings); - if (failed(maybe_dim_vars)) { - return {}; + + SmallVector dim_vars; + if (map.getNumDims() > 0) { + if (ParseDimVars(parser, dim_strings, dim_vars)) { + return {}; + } } - auto maybe_range_vars = ParseRangeVars(parser, symbol_strings); - if (failed(maybe_range_vars)) { - return {}; + SmallVector range_vars; + if (map.getNumSymbols() > 0) { + if (!dim_vars.empty() && parser.parseComma()) { + return {}; + } + if (ParseRangeVars(parser, symbol_strings, range_vars)) { + return {}; + } } - auto maybe_constraints = ParseConstraints(parser, symbolSet); - if (failed(maybe_constraints)) { + SmallVector> constraints; + if (ParseConstraints(parser, symbolSet, constraints) || + parser.parseGreater()) { return {}; } - // ParseConstraints consumes the > to know when to stop. - return IndexingMapAttr::get(parser.getContext(), map, *maybe_dim_vars, - *maybe_range_vars, *maybe_constraints); + return IndexingMapAttr::get(parser.getContext(), map, dim_vars, range_vars, + constraints); } void IndexingMapAttr::print(mlir::AsmPrinter& printer) const { - printer << "<\n"; + printer << "<"; printer.printStrippedAttrOrType(getMap()); - printer << "\ndomain:\n"; + if (getDimVars().size() + getRangeVars().size() + getConstraints().size() > + 0) { + printer << ", domain: "; + } PrintDimVars(printer, getDimVars()); + if (!getDimVars().empty() && + getRangeVars().size() + getConstraints().size() > 0) { + printer << ", "; + } PrintRangeVars(printer, getRangeVars()); + if (!getRangeVars().empty() && !getConstraints().empty()) { + printer << ", "; + } PrintConstraints(printer, getConstraints()); printer << ">"; } @@ -190,5 +222,10 @@ mlir::LogicalResult IndexingMapAttr::verify( return mlir::success(); } +IndexingMap IndexingMapAttr::getIndexingMap() { + return IndexingMap(getMap(), getDimVars(), getRangeVars(), /*rt_vars=*/{}, + getConstraints()); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td similarity index 89% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td index 8c8f98c05e2737..19dd24f2e67a2c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td @@ -17,7 +17,7 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS include "mlir/IR/AttrTypeBase.td" -include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td" +include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" class XLAGPU_Attr traits = []> : AttrDef { @@ -55,6 +55,10 @@ def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { AttrBuilder<(ins "const ::xla::gpu::IndexingMap&":$indexing_map)>, ]; let genVerifyDecl = 1; + let extraClassDeclaration = [{ + // Returns the indexing map constructed from IndexingMapAttr. + xla::gpu::IndexingMap getIndexingMap(); + }]; } -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_ATTRS +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_ATTRS diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc similarity index 90% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc index 3dc60c91f40779..57d2d706737089 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.cc @@ -13,15 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/ADT/TypeSwitch.h" +#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep #include "mlir/IR/DialectImplementation.h" // IWYU pragma: keep #include "mlir/IR/OpImplementation.h" // IWYU pragma: keep #include "mlir/Transforms/InliningUtils.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" #undef GET_ATTRDEF_CLASSES +#define GET_TYPEDEF_CLASSES +#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc" +#undef GET_TYPEDEF_CLASSES namespace xla { namespace gpu { @@ -114,15 +116,20 @@ struct XlaGpuOpAsmDialectInterface : public mlir::OpAsmDialectInterface { void XlaGpuDialect::initialize() { addOperations< #define GET_OP_LIST -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc" #undef GET_OP_LIST >(); addAttributes< #define GET_ATTRDEF_LIST -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.cc.inc" >(); #undef GET_ATTRDEF_LIST addInterfaces(); + addTypes< +#define GET_TYPEDEF_LIST +#include "xla/service/gpu/fusions/ir/xla_gpu_types.cc.inc" +#undef GET_TYPEDEF_LIST + >(); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td similarity index 92% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td index 4400747923cb6e..9a5c539e39e591 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_dialect.td @@ -27,6 +27,7 @@ def XlaGpuDialect : Dialect { let cppNamespace = "::xla::gpu"; let useDefaultAttributePrinterParser = 1; + let useDefaultTypePrinterParser = 1; } -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_DIALECT +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_DIALECT diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc similarity index 53% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index dfa4d056a80bda..39a0318af58896 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include #include #include #include -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" @@ -39,12 +38,13 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" // IWYU pragma: keep #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/TypeRange.h" #include "mlir/IR/TypeUtilities.h" // IWYU pragma: keep #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.cc.inc" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -54,12 +54,15 @@ namespace { using llvm::ArrayRef; using mlir::AffineExpr; using mlir::AffineMap; +using mlir::Block; using mlir::failure; using mlir::getAffineConstantExpr; using mlir::getAffineDimExpr; using mlir::getAffineSymbolExpr; using mlir::LogicalResult; using mlir::MLIRContext; +using mlir::OpAsmParser; +using mlir::OpAsmPrinter; using mlir::OpBuilder; using mlir::OperationState; using mlir::PatternRewriter; @@ -68,6 +71,7 @@ using mlir::Region; using mlir::SmallVector; using mlir::success; using mlir::Type; +using mlir::TypeRange; using mlir::Value; using mlir::ValueRange; @@ -123,85 +127,50 @@ void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, ValueRange operands, const IndexingMap& indexing_map) { - build(builder, result, operands, indexing_map.GetAffineMap(), - indexing_map.GetDimVars(), indexing_map.GetRangeVars()); + SmallVector result_types(indexing_map.GetAffineMap().getNumResults(), + builder.getIndexType()); + IndexingMapAttr indexing_map_attr = + IndexingMapAttr::get(builder.getContext(), indexing_map); + build(builder, result, result_types, operands, indexing_map_attr); } void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, ValueRange operands, AffineMap affine_map, ArrayRef dim_vars, ArrayRef range_vars) { - SmallVector lower_bounds, upper_bounds; - for (const DimVar& dim_var : dim_vars) { - lower_bounds.push_back(dim_var.bounds.lower); - upper_bounds.push_back(dim_var.bounds.upper); - } - for (const RangeVar& range_var : range_vars) { - lower_bounds.push_back(range_var.range.lower); - upper_bounds.push_back(range_var.range.upper); - } - build(builder, result, operands, affine_map, lower_bounds, upper_bounds); -} - -void ApplyIndexingOp::build(OpBuilder& builder, OperationState& result, - ValueRange operands, AffineMap affine_map, - ArrayRef lower_bounds, - ArrayRef upper_bounds) { - SmallVector result_types(affine_map.getNumResults(), - builder.getIndexType()); - build(builder, result, result_types, operands, affine_map, lower_bounds, - upper_bounds); + IndexingMap indexing_map(affine_map, dim_vars, range_vars, {}); + build(builder, result, operands, indexing_map); } -// Parser a comma-separated list of type %operand in [lower_bound, upper_bound]. -// Adds the parsed elements to the provided containers. -mlir::ParseResult parseOperandsWithBoundsList( - mlir::OpAsmParser& parser, - SmallVector* operands, - SmallVector* lower_bounds, - SmallVector* upper_bounds) { - int64_t lower_bound, upper_bound; - mlir::OpAsmParser::UnresolvedOperand operand; - if (parser.parseCommaSeparatedList([&]() { - if (parser.parseOperand(operand) || parser.parseKeyword("in") || - parser.parseLSquare() || parser.parseInteger(lower_bound) || - parser.parseComma() || parser.parseInteger(upper_bound) || - parser.parseRSquare()) { - return failure(); - } - operands->push_back(operand); - lower_bounds->push_back(lower_bound); - upper_bounds->push_back(upper_bound); - return success(); - })) { - return failure(); - } - return success(); +// Parses a comma-separated list of operands, ex: %d1, %d2. +mlir::ParseResult parseOperands( + OpAsmParser& parser, + SmallVector* operands) { + OpAsmParser::UnresolvedOperand operand; + return parser.parseCommaSeparatedList( + [&]() { return parser.parseOperand(operands->emplace_back()); }); } -mlir::ParseResult ApplyIndexingOp::parse(mlir::OpAsmParser& parser, +mlir::ParseResult ApplyIndexingOp::parse(OpAsmParser& parser, OperationState& result) { mlir::Builder& builder = parser.getBuilder(); auto index_type = builder.getIndexType(); - mlir::AffineMapAttr affine_map_attr; - if (parser.parseAttribute(affine_map_attr, "map", result.attributes)) { + IndexingMapAttr indexing_map_attr; + if (parser.parseAttribute(indexing_map_attr, "indexing_map_attr", + result.attributes)) { return failure(); } - SmallVector operands; + SmallVector operands; SmallVector lower_bounds, upper_bounds; if (succeeded(parser.parseOptionalLParen())) { - if (parseOperandsWithBoundsList(parser, &operands, &lower_bounds, - &upper_bounds) || - parser.parseRParen()) { + if (parseOperands(parser, &operands) || parser.parseRParen()) { return failure(); } } if (succeeded(parser.parseOptionalLSquare())) { - if (parseOperandsWithBoundsList(parser, &operands, &lower_bounds, - &upper_bounds) || - parser.parseRSquare()) { + if (parseOperands(parser, &operands) || parser.parseRSquare()) { return failure(); } } @@ -209,86 +178,52 @@ mlir::ParseResult ApplyIndexingOp::parse(mlir::OpAsmParser& parser, parser.parseOptionalAttrDict(result.attributes)) { return failure(); } - result.addAttribute("lower_bounds", - builder.getDenseI64ArrayAttr(lower_bounds)); - result.addAttribute("upper_bounds", - builder.getDenseI64ArrayAttr(upper_bounds)); - - auto map = affine_map_attr.getAffineMap(); + auto map = indexing_map_attr.getMap(); result.addTypes(SmallVector(map.getNumResults(), index_type)); return success(); } -void ApplyIndexingOp::print(mlir::OpAsmPrinter& p) { - mlir::AffineMapAttr affine_map_attr = getMapAttr(); - AffineMap affine_map = affine_map_attr.getAffineMap(); - p << " " << affine_map_attr; +void ApplyIndexingOp::print(OpAsmPrinter& p) { + AffineMap affine_map = getIndexingMapAttr().getMap(); + p << " " << getIndexingMapAttr(); - auto lower_bounds = getLowerBounds(); - auto upper_bounds = getUpperBounds(); auto operands = getOperands(); unsigned num_dimensions = affine_map.getNumDims(); if (num_dimensions > 0) { p << '('; - for (int dim_id = 0; dim_id < num_dimensions; ++dim_id) { - p << operands[dim_id] << " in " << '[' << lower_bounds[dim_id] << ", " - << upper_bounds[dim_id] << ']'; - if (dim_id != num_dimensions - 1) { - p << ", "; - } - } + auto dimension_operands = operands.slice(0, num_dimensions); + llvm::interleaveComma(dimension_operands, p); p << ')'; } + unsigned num_symbols = affine_map.getNumSymbols(); if (num_symbols > 0) { p << '['; - for (int symbol_id = 0; symbol_id < num_symbols; ++symbol_id) { - unsigned operand_id = num_dimensions + symbol_id; - p << operands[operand_id] << " in " << '[' << lower_bounds[operand_id] - << ", " << upper_bounds[operand_id] << ']'; - if (symbol_id != num_symbols - 1) { - p << ", "; - } - } + auto symbol_operands = operands.slice(num_dimensions, num_symbols); + llvm::interleaveComma(symbol_operands, p); p << ']'; } - p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{ - "map", "lower_bounds", "upper_bounds"}); + + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{"indexing_map_attr"}); } LogicalResult ApplyIndexingOp::verify() { - auto affine_map = getMapAttr().getAffineMap(); + auto affine_map = getIndexingMapAttr().getMap(); unsigned num_variables = affine_map.getNumDims() + affine_map.getNumSymbols(); - if (getOperands().size() != num_variables || - getLowerBounds().size() != num_variables || - getUpperBounds().size() != num_variables) { + if (getOperands().size() != num_variables) { return emitOpError( - "operand, lower_bounds, upper_bounds count and affine map dimension " - "and symbol count must match"); + "operand count must match the number of dimensions and symbols in the " + "affine map"); + } + if (!getIndexingMapAttr().getConstraints().empty()) { + return emitOpError("apply indexing op cannot have any constraints"); } return success(); } IndexingMap ApplyIndexingOp::getIndexingMap() { - auto lower_bounds = getLowerBounds(); - auto upper_bounds = getUpperBounds(); - - AffineMap affine_map = getAffineMap(); - unsigned num_dimensions = affine_map.getNumDims(); - std::vector dim_vars; - dim_vars.reserve(num_dimensions); - for (unsigned id = 0; id < num_dimensions; ++id) { - dim_vars.push_back(DimVar{Interval{lower_bounds[id], upper_bounds[id]}}); - } - unsigned num_symbols = affine_map.getNumSymbols(); - std::vector range_vars; - range_vars.reserve(num_symbols); - for (unsigned id = num_dimensions; id < num_symbols + num_dimensions; ++id) { - range_vars.push_back( - RangeVar{Interval{lower_bounds[id], upper_bounds[id]}}); - } - return IndexingMap(affine_map, std::move(dim_vars), std::move(range_vars), - /*rt_vars=*/{}); + return getIndexingMapAttr().getIndexingMap(); } namespace { @@ -334,6 +269,19 @@ struct FoldApplyIndexingSequence LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, PatternRewriter& rewriter) const override { + SmallVector, 2> apply_indexing_ops; + bool all_apply_indexing_operands_have_one_use = true; + for (auto& operand : indexing_op->getOpOperands()) { + if (auto producer = operand.get().getDefiningOp()) { + apply_indexing_ops.push_back({operand.getOperandNumber(), producer}); + all_apply_indexing_operands_have_one_use &= producer->hasOneUse(); + } + } + if (apply_indexing_ops.empty()) { + return rewriter.notifyMatchFailure(indexing_op, + "No apply_indexing sequences found"); + } + MLIRContext* ctx = indexing_op.getContext(); int num_dims = indexing_op.getAffineMap().getNumDims(); int num_syms = indexing_op.getAffineMap().getNumSymbols(); @@ -354,53 +302,44 @@ struct FoldApplyIndexingSequence auto new_sym_vars = this_map.GetRangeVars(); mlir::DenseMap replacements; - for (auto& operand : indexing_op->getOpOperands()) { - if (auto producer = operand.get().getDefiningOp()) { - auto producer_map = producer.getIndexingMap(); - int producer_result_id = - mlir::cast(operand.get()).getResultNumber(); - int num_producer_dims = producer.getAffineMap().getNumDims(); - SmallVector producer_dim_replacements; - SmallVector producer_sym_replacements; - for (auto& producer_operand : producer->getOpOperands()) { - int producer_operand_number = producer_operand.getOperandNumber(); - bool is_dim = producer_operand_number < num_producer_dims; - auto& replacement_expr = operand_exprs[producer_operand.get()]; - if (!replacement_expr) { - if (is_dim) { - int dim_num = producer_operand_number; - replacement_expr = - getAffineDimExpr(num_dims + added_dim_args.size(), ctx); - added_dim_args.push_back(producer_operand.get()); - new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); - } else { - int sym_num = producer_operand_number - - producer.getAffineMap().getNumDims(); - replacement_expr = - getAffineSymbolExpr(num_syms + added_sym_args.size(), ctx); - added_sym_args.push_back(producer_operand.get()); - new_sym_vars.push_back(producer_map.GetRangeVar(sym_num)); - } - } - + for (auto& [operand_id, producer] : apply_indexing_ops) { + auto producer_map = producer.getIndexingMap(); + mlir::OpResult producer_result = producer->getOpResult(0); + int producer_result_id = producer_result.getResultNumber(); + int num_producer_dims = producer.getAffineMap().getNumDims(); + SmallVector producer_dim_replacements; + SmallVector producer_sym_replacements; + for (auto& producer_operand : producer->getOpOperands()) { + int producer_operand_number = producer_operand.getOperandNumber(); + bool is_dim = producer_operand_number < num_producer_dims; + auto& replacement_expr = operand_exprs[producer_operand.get()]; + if (!replacement_expr) { if (is_dim) { - producer_dim_replacements.push_back(replacement_expr); + int dim_num = producer_operand_number; + replacement_expr = + getAffineDimExpr(num_dims + added_dim_args.size(), ctx); + added_dim_args.push_back(producer_operand.get()); + new_dim_vars.push_back(producer_map.GetDimVars(dim_num)); } else { - producer_sym_replacements.push_back(replacement_expr); + int sym_num = + producer_operand_number - producer.getAffineMap().getNumDims(); + replacement_expr = + getAffineSymbolExpr(num_syms + added_sym_args.size(), ctx); + added_sym_args.push_back(producer_operand.get()); + new_sym_vars.push_back(producer_map.GetRangeVar(sym_num)); } } - - replacements[operand_exprs[operand.get()]] = - producer.getAffineMap() - .getResult(producer_result_id) - .replaceDimsAndSymbols(producer_dim_replacements, - producer_sym_replacements); + if (is_dim) { + producer_dim_replacements.push_back(replacement_expr); + } else { + producer_sym_replacements.push_back(replacement_expr); + } } - } - - if (replacements.empty()) { - return rewriter.notifyMatchFailure(indexing_op, - "No apply_indexing sequences found"); + replacements[operand_exprs[producer_result]] = + producer.getAffineMap() + .getResult(producer_result_id) + .replaceDimsAndSymbols(producer_dim_replacements, + producer_sym_replacements); } int new_num_operands = indexing_op->getNumOperands() + @@ -410,10 +349,12 @@ struct FoldApplyIndexingSequence num_syms + added_sym_args.size()); IndexingMap new_indexing_map(new_affine_map, new_dim_vars, new_sym_vars, /*rt_vars=*/{}); - if (!new_indexing_map.Simplify()) { + if (!all_apply_indexing_operands_have_one_use && + !new_indexing_map.Simplify()) { return rewriter.notifyMatchFailure( indexing_op, "Folded indexing map was not simplified"); } + SmallVector new_operands; new_operands.reserve(new_num_operands); @@ -436,7 +377,8 @@ struct FoldApplyIndexingOperands LogicalResult matchAndRewrite(ApplyIndexingOp indexing_op, PatternRewriter& rewriter) const override { - AffineMap affine_map = indexing_op.getAffineMap(); + IndexingMap indexing_map = indexing_op.getIndexingMap(); + AffineMap affine_map = indexing_map.GetAffineMap(); MLIRContext* ctx = affine_map.getContext(); unsigned num_operands = indexing_op->getNumOperands(); @@ -446,8 +388,6 @@ struct FoldApplyIndexingOperands SmallVector> constant_values(num_operands, std::nullopt); int num_constants = 0; - SmallVector dim_id_map(num_dims, -1); - SmallVector symbol_id_map(num_symbols, -1); for (auto& operand : indexing_op->getOpOperands()) { if (auto constant = operand.get().getDefiningOp()) { @@ -466,15 +406,15 @@ struct FoldApplyIndexingOperands unsigned new_num_operands = indexing_op->getNumOperands() - num_constants; SmallVector new_operands; new_operands.reserve(new_num_operands); - SmallVector new_lbs, new_ubs; - new_lbs.reserve(new_num_operands); - new_ubs.reserve(new_num_operands); + SmallVector new_dim_vars; + new_dim_vars.reserve(num_dims); + SmallVector new_range_vars; + new_range_vars.reserve(num_symbols); unsigned new_num_dims = 0; unsigned new_num_symbols = 0; - for (auto [operand, constant_value, lb, ub] : llvm::zip( - indexing_op->getOpOperands(), constant_values, - indexing_op.getLowerBounds(), indexing_op.getUpperBounds())) { + for (auto [operand, constant_value] : + llvm::zip(indexing_op->getOpOperands(), constant_values)) { unsigned operand_id = operand.getOperandNumber(); if (constant_value.has_value()) { if (operand_id < num_dims) { @@ -485,22 +425,23 @@ struct FoldApplyIndexingOperands getAffineConstantExpr(*constant_value, ctx)); } } else { + new_operands.push_back(operand.get()); if (operand_id < num_dims) { dim_replacements.push_back(getAffineDimExpr(new_num_dims++, ctx)); + new_dim_vars.push_back(indexing_map.GetDimVars(operand_id)); } else { symbol_replacements.push_back( getAffineSymbolExpr(new_num_symbols++, ctx)); + new_range_vars.push_back( + indexing_map.GetRangeVar(operand_id - num_dims)); } - new_operands.push_back(operand.get()); - new_lbs.push_back(lb); - new_ubs.push_back(ub); } } rewriter.replaceOpWithNewOp( indexing_op, new_operands, affine_map.replaceDimsAndSymbols(dim_replacements, symbol_replacements, new_num_dims, new_num_symbols), - new_lbs, new_ubs); + new_dim_vars, new_range_vars); return success(); } }; @@ -648,8 +589,259 @@ void SyncThreadsOp::getAsmResultNames( } } +//===----------------------------------------------------------------------===// +// LoopOp +//===----------------------------------------------------------------------===// + +void LoopOp::build(OpBuilder& builder, OperationState& result, + IndexingMapAttr indexing_map_attr, ValueRange dims, + ValueRange inits, BodyBuilderFn bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + + int64_t num_ivs = indexing_map_attr.getRangeVars().size(); + result.addOperands(dims); + result.addOperands(inits); + result.addTypes(TypeRange(inits)); + Block* body_block = builder.createBlock(result.addRegion()); + // Add induction variables block args. + for (int i = 0; i < num_ivs; ++i) { + body_block->addArgument(builder.getIndexType(), result.location); + } + // Add iteration arguments block args. + for (auto init_type : TypeRange(inits)) { + body_block->addArguments(init_type, result.location); + } + + mlir::OperationName opname(LoopOp::getOperationName(), builder.getContext()); + result.addAttribute(LoopOp::getIndexingMapAttrAttrName(opname), + indexing_map_attr); + result.addAttribute( + LoopOp::getOperandSegmentSizesAttrName(opname), + builder.getDenseI32ArrayAttr({static_cast(dims.size()), + static_cast(inits.size())})); + if (bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(body_block); + bodyBuilder(builder, result.location, + body_block->getArguments().take_front(num_ivs), + body_block->getArguments().drop_front(num_ivs)); + } +} + +void LoopOp::build(OpBuilder& builder, OperationState& result, + const IndexingMap& indexing_map, ValueRange dims, + ValueRange inits, BodyBuilderFn bodyBuilder) { + build(builder, result, + IndexingMapAttr::get(builder.getContext(), indexing_map), dims, inits, + bodyBuilder); +} + +mlir::ParseResult LoopOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector region_args, ivs, iter_args; + SmallVector dim_operands; + + // Parse the dimension values. + OpBuilder b(parser.getContext()); + Type index_type = b.getIndexType(); + if (parser.parseOperandList(dim_operands, OpAsmParser::Delimiter::Paren) || + parser.resolveOperands(dim_operands, index_type, result.operands)) + return failure(); + // Parse the induction variables. + if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Square)) + return failure(); + for (auto iv : ivs) { + region_args.push_back(iv); + region_args.back().type = index_type; + } + + // Parse the indexing map attribute. + IndexingMapAttr indexing_map_attr; + if (parser.parseKeyword("in") || + parser.parseAttribute(indexing_map_attr, "indexing_map_attr", + result.attributes)) { + return failure(); + } + + // Parse the arguments. + SmallVector init_operands; + if (parser.parseKeyword("iter_args") || + parser.parseAssignmentList(iter_args, init_operands) || + parser.parseArrowTypeList(result.types) || + parser.resolveOperands(init_operands, result.types, parser.getNameLoc(), + result.operands)) + return failure(); + + for (auto [index, iter_arg] : llvm::enumerate(iter_args)) { + region_args.push_back(iter_arg); + region_args.back().type = result.types[index]; + } + + if (region_args.size() != result.types.size() + ivs.size()) { + return parser.emitError(parser.getNameLoc(), + "mismatch in number of induction variables + " + "loop-carried values and the number of results"); + } + + // Parse the body region. + Region* body = result.addRegion(); + if (parser.parseRegion(*body, region_args)) return failure(); + LoopOp::ensureTerminator(*body, b, result.location); + + // Parse the optional attribute list + result.addAttribute( + LoopOp::getOperandSegmentSizeAttr(), + b.getDenseI32ArrayAttr({static_cast(dim_operands.size()), + static_cast(iter_args.size())})); + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + + return success(); +} + +void LoopOp::print(OpAsmPrinter& p) { + p << " (" << getDims() << ")[" << getInductionVars() << "] in " + << getIndexingMapAttr() << " iter_args("; + llvm::interleaveComma( + llvm::zip(getRegionIterArgs(), getInits()), p, + [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); + p << ") -> (" << getInits().getTypes() << ") "; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/{ + getIndexingMapAttrAttrName(), + getOperandSegmentSizesAttrName(), + }); +} + +LogicalResult LoopOp::verify() { + if (getInits().size() != getNumResults()) { + return emitOpError("mismatch in number of loop-carried values and results"); + } + IndexingMap indexing_map = getIndexingMap(); + if (indexing_map.GetRangeVarsCount() != getNumInductionVars()) { + return emitOpError() << "mismatch in number of induction variables " + << getNumInductionVars() + << " and RangeVars in the indexing map " + << indexing_map.ToString(); + } + if (indexing_map.GetDimVarsCount() != getDims().size()) { + return emitOpError() << "mismatch in number of dims operands " + << getDims().size() + << " and DimVars in the indexing map " + << indexing_map.ToString(); + } + for (auto [bb_arg, result_type, init] : + llvm::zip(getRegionIterArgs(), getResultTypes(), getInits())) { + if (bb_arg.getType() != result_type || init.getType() != result_type) { + return emitOpError() << "block iter arg type = " << bb_arg.getType() + << ", result type = " << result_type + << " and init operand type = " << init.getType() + << " should match"; + } + } + return success(); +} + +IndexingMap LoopOp::getIndexingMap() { + return getIndexingMapAttr().getIndexingMap(); +} + +//===----------------------------------------------------------------------===// +// MaterializeOp +//===----------------------------------------------------------------------===// + +VariableConstraints GetConstraintsForVariables(const IndexingMap& map) { + VariableConstraints result; + result.constraints_for_dims.resize(map.GetDimensionCount()); + result.constraints_for_symbols.resize(map.GetSymbolCount()); + for (const auto& constraint : map.GetConstraints()) { + constraint.first.walk([&](mlir::AffineExpr leaf) { + if (auto dim = mlir::dyn_cast(leaf)) { + result.constraints_for_dims[dim.getPosition()].push_back(constraint); + } else if (auto sym = mlir::dyn_cast(leaf)) { + result.constraints_for_symbols[sym.getPosition()].push_back(constraint); + } + }); + } + return result; +} + +LogicalResult MaterializeOp::verify() { + IndexingMap map_in = getMap().getIndexingMap(); + IndexingMap map_out = + getResult().getType().getIndexingMapAttr().getIndexingMap(); + if (getIndices().size() != map_in.GetDimVarsCount()) { + return emitOpError() << "number of indices must match number of dimensions " + "of indexing map"; + } + + // The thread dimension must have the same domain (range and constraints) + if (map_in.GetDimVarsCount() == 0 || map_out.GetDimVarsCount() == 0) { + return emitOpError() + << "must have thread_id dimension in both indexing maps"; + } + if (map_in.GetDimVars(0) != map_out.GetDimVars(0)) { + return emitOpError() << "thread_id dimension must have the same bounds in " + "both indexing maps"; + } + + auto variable_constraints_in = GetConstraintsForVariables(map_in); + auto variable_constraints_out = GetConstraintsForVariables(map_out); + if (variable_constraints_in.constraints_for_dims[0] != + variable_constraints_out.constraints_for_dims[0]) { + return emitOpError() << "constraints of indexing maps must be equal for " + << "the thread_id dimension"; + } + + // The two maps must have the same symbols and they must have the same domain + if (map_in.GetRangeVarsCount() != map_out.GetRangeVarsCount()) { + return emitOpError() + << "number of symbols in both indexing_maps must match"; + } + for (auto const& [range_in, range_out] : + llvm::zip(map_in.GetRangeVars(), map_out.GetRangeVars())) { + if (range_in.range != range_out.range) { + return emitOpError() << "domain of symbols of indexing_maps must match"; + } + } + if (variable_constraints_in.constraints_for_symbols != + variable_constraints_out.constraints_for_symbols) { + return emitOpError() + << "constraints of indexing maps must be equal for all symbols"; + } + + // The vector mapping indices must not depend on the block ID + if (map_out.GetDimVarsCount() > 1) { + for (auto expr : map_out.GetAffineMap().getResults()) { + if (expr.isFunctionOfDim(1)) { + return emitOpError() << "vector mapping indices must not depend on the " + << "block ID"; + } + } + } + // If there are constraints on the block ID, they must be the same in both + // maps + if (map_in.GetDimVarsCount() > 1 && map_out.GetDimVarsCount() > 1) { + if (variable_constraints_in.constraints_for_dims[1] != + variable_constraints_out.constraints_for_dims[1]) { + return emitOpError() << "constraints of indexing maps must be equal for " + << "the block_id dimension"; + } + } else if (map_in.GetDimVarsCount() > 1 && + !variable_constraints_in.constraints_for_dims[1].empty()) { + return emitOpError() << "constraints of indexing maps must be equal for " + << "the block_id dimension"; + } else if (map_out.GetDimVarsCount() > 1 && + !variable_constraints_out.constraints_for_dims[1].empty()) { + return emitOpError() << "constraints of indexing maps must be equal for " + << "the block_id dimension"; + } + + return success(); +} + } // namespace gpu } // namespace xla #define GET_OP_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.cc.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.cc.inc" diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h similarity index 60% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h index f43786f4fde0ac..e3b8fd641f9a06 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.h @@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ -#define XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ +#define XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ +#include + +#include "llvm/ADT/SmallVector.h" #include "mlir/Bytecode/BytecodeOpInterface.h" // IWYU pragma: keep #include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep #include "mlir/IR/Attributes.h" // IWYU pragma: keep @@ -26,14 +29,28 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" // IWYU pragma: keep #include "mlir/Interfaces/InferTypeOpInterface.h" // IWYU pragma: keep #include "mlir/Interfaces/SideEffectInterfaces.h" // IWYU pragma: keep -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h" // IWYU pragma: keep - -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.h.inc" -#define GET_OP_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h.inc" -#undef GET_OP_CLASSES +#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep #define GET_ATTRDEF_CLASSES -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_attrs.h.inc" +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" #undef GET_ATTRDEF_CLASSES +#define GET_TYPEDEF_CLASSES +#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc" +#undef GET_TYPEDEF_CLASSES +#define GET_OP_CLASSES +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h.inc" +#undef GET_OP_CLASSES + +namespace xla::gpu { + +struct VariableConstraints { + llvm::SmallVector>> + constraints_for_dims; + llvm::SmallVector>> + constraints_for_symbols; +}; +VariableConstraints GetConstraintsForVariables(const IndexingMap& map); + +} // namespace xla::gpu -#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_IR_XLA_GPU_OPS_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_IR_XLA_GPU_OPS_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td similarity index 68% rename from third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td rename to third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td index c05f843465427c..9eb246f70d9d34 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.td @@ -22,7 +22,10 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "xla/service/gpu/fusions/mlir/ir/xla_gpu_dialect.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" +include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td" +include "xla/service/gpu/fusions/ir/xla_gpu_types.td" class XLAGPU_Op traits = []> : Op { @@ -104,11 +107,13 @@ def XLAGPU_AtomicRMWOp : XLAGPU_Op<"atomic_rmw", }]; } -def XLAGPU_YieldOp : XLAGPU_Op<"yield", - [HasParent<"::xla::gpu::AtomicRMWOp">, Terminator]> { +def XLAGPU_YieldOp : XLAGPU_Op<"yield", [ + ParentOneOf<["::xla::gpu::AtomicRMWOp", "::xla::gpu::LoopOp"]>, + ReturnLike, Terminator]> { let summary = "Terminator for atomic_rmw ops."; let arguments = (ins AnyType:$result); + let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; let assemblyFormat = "$result attr-dict `:` type($result)"; } @@ -232,7 +237,7 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { symbol arguments must be equal to the respective number of dimensional and symbolic inputs in the affine map. The affine mapping can be multi-dimensional, and so the `apply_indexing` operation always returns one - value. The operands and results must all have ‘index’ type. + value. The operands and results must all have ‘index’ type. Example: @@ -242,9 +247,7 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { ``` }]; let arguments = (ins Variadic:$operands, - AffineMapAttr:$map, - DenseI64ArrayAttr:$lower_bounds, - DenseI64ArrayAttr:$upper_bounds); + XLAGPU_IndexingMapAttr:$indexing_map_attr); let results = (outs Variadic); let builders = [ @@ -255,16 +258,12 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { OpBuilder<(ins "mlir::ValueRange":$operands, "mlir::AffineMap":$affine_map, "llvm::ArrayRef":$dim_vars, "llvm::ArrayRef":$range_vars)>, - OpBuilder<(ins "mlir::ValueRange":$operands, - "mlir::AffineMap":$affine_map, - "llvm::ArrayRef":$lower_bounds, - "llvm::ArrayRef":$upper_bounds)>, ]; let extraClassDeclaration = [{ - // Returns the indexing map constructed from affine_map and the bounds. + // Returns the indexing map constructed from IndexingMapAttr. xla::gpu::IndexingMap getIndexingMap(); // Extracts the affine map from the attribute. - mlir::AffineMap getAffineMap() { return getMapAttr().getAffineMap(); } + mlir::AffineMap getAffineMap() { return getIndexingMapAttr().getMap(); } }]; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; @@ -272,4 +271,97 @@ def ApplyIndexingOp : XLAGPU_Op<"apply_indexing", [Pure]> { let hasFolder = 1; } -#endif // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS +def LoopOp : XLAGPU_Op<"loop", [ + AttrSizedOperandSegments, Pure, + SingleBlockImplicitTerminator<"xla::gpu::YieldOp"> + ]> { + let summary = "Loop nest that iterates over all feasible values of RangeVars."; + let description = [{ + + ```mlir + #map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (d0 + s0, s1), + domain: + d0 in [0, 3], + s0 in [0, 1024], + s1 in [0, 32] + > + // Initial sum set to 0. + %sum_0 = arith.constant 0.0 : f32 + %dim = arith.constant 1 : index + // iter_args binds initial values to the loop's region arguments. + %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_iter = %sum_0) -> (f32) { + %t = tensor.extract %buffer[%i, %j] : tensor<1024x32xf32> + %sum_next = arith.addf %sum_iter, %t : f32 + // Yield current iteration sum to next iteration %sum_iter or to %sum + // if final iteration. + scf.yield %sum_next : f32 + } + ``` + }]; + let arguments = (ins XLAGPU_IndexingMapAttr:$indexing_map_attr, + Variadic:$dims, + Variadic:$inits); + let results = (outs Variadic); + let regions = (region SizedRegion<1>:$region); + + let builders = [ + OpBuilder<(ins "IndexingMapAttr":$indexing_map_attr, + "mlir::ValueRange":$dims, "mlir::ValueRange":$inits, + CArg<"llvm::function_ref", + "nullptr">)>, + OpBuilder<(ins "const IndexingMap&":$indexing_map, + "mlir::ValueRange":$dims, "mlir::ValueRange":$inits, + CArg<"llvm::function_ref", + "nullptr">)> + ]; + + let extraClassDeclaration = [{ + using BodyBuilderFn = + llvm::function_ref; + + // Returns the indexing map constructed from IndexingMapAttr. + xla::gpu::IndexingMap getIndexingMap(); + int64_t getNumInductionVars() { + return getBody()->getNumArguments() - getNumResults(); + } + mlir::BlockArgument getInductionVar(int64_t index) { + return getBody()->getArgument(index); + } + mlir::Block::BlockArgListType getInductionVars() { + return getBody()->getArguments().take_front(getNumInductionVars()); + } + mlir::Block::BlockArgListType getRegionIterArgs() { + return getBody()->getArguments().drop_front(getNumInductionVars()); + } + }]; + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + +def XLAGPU_MaterializeOp : XLAGPU_Op<"materialize", [AttrSizedOperandSegments]> { + let summary = "Reads a tensor into registers"; + let arguments = (ins Variadic:$input, + Variadic:$indices, + FlatSymbolRefAttr:$callee, + XLAGPU_IndexingMapAttr:$map); + let results = (outs XLAGPU_IndexedVectorType:$result); + let hasVerifier = 1; + let assemblyFormat = [{ + $callee `(` $input `)` `at` $map `(` $indices `)` attr-dict `:` functional-type($input, results) + }]; +} + +def XLAGPU_InsertOp : XLAGPU_Op<"insert", []> { + let summary = "Inserts an indexed vector into a tensor"; + let arguments = (ins XLAGPU_IndexedVectorType:$source, + Variadic:$indices, + AnyRankedTensor:$dest, + XLAGPU_IndexingMapAttr:$map); + let results = (outs AnyRankedTensor:$result); + let assemblyFormat = [{ + $source `into` $dest `at` $map `(` $indices `)` attr-dict `:` type($source) `->` type($dest) `into` type($result) + }]; +} + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_OPS diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc new file mode 100644 index 00000000000000..2d9076d7803280 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" + +#include +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla::gpu { +namespace { + +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +class XLAGPUOpsTest : public HloTestBase { + public: + mlir::MLIRContext mlir_context_; +}; + +TEST_F(XLAGPUOpsTest, GetConstraintsForVariables) { + auto map = IndexingMap( + ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_), + /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}}, + /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{}); + map.AddConstraint(ParseAffineExpr("s0 mod 4", &mlir_context_), + Interval{0, 1}); + map.AddConstraint(ParseAffineExpr("s1 mod 4", &mlir_context_), + Interval{0, 2}); + map.AddConstraint(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}); + map.AddConstraint(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}); + map.AddConstraint(ParseAffineExpr("d1 mod 32", &mlir_context_), + Interval{0, 6}); + + auto constraints_for_variables = GetConstraintsForVariables(map); + EXPECT_THAT(constraints_for_variables.constraints_for_dims[0], + UnorderedElementsAre()); + EXPECT_THAT( + constraints_for_variables.constraints_for_dims[1], + UnorderedElementsAre( + Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}), + Pair(ParseAffineExpr("d1 mod 32", &mlir_context_), Interval{0, 6}))); + EXPECT_THAT( + constraints_for_variables.constraints_for_symbols[0], + UnorderedElementsAre( + Pair(ParseAffineExpr("s0 mod 4", &mlir_context_), Interval{0, 1}), + Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}))); + EXPECT_THAT( + constraints_for_variables.constraints_for_symbols[1], + UnorderedElementsAre( + Pair(ParseAffineExpr("s1 mod 4", &mlir_context_), Interval{0, 2}), + Pair(ParseAffineExpr("s0 + s1", &mlir_context_), Interval{0, 3}), + Pair(ParseAffineExpr("s1 + d1", &mlir_context_), Interval{0, 4}))); +} + +TEST_F(XLAGPUOpsTest, GetConstraintsForVariablesEmpty) { + auto map = IndexingMap( + ParseAffineMap("(d0, d1)[s0, s1] -> (d0+s0, d1+s1)", &mlir_context_), + /*dimensions=*/{{Interval{0, 5}}, {Interval{0, 2}}}, + /*range_vars=*/{{Interval{0, 32}}, {Interval{0, 1024}}}, /*rt_vars=*/{}); + auto constraints_for_variables = GetConstraintsForVariables(map); + EXPECT_THAT(constraints_for_variables.constraints_for_dims, + ElementsAre(IsEmpty(), IsEmpty())); + EXPECT_THAT(constraints_for_variables.constraints_for_symbols, + ElementsAre(IsEmpty(), IsEmpty())); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc new file mode 100644 index 00000000000000..1c1b218db7bc19 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.cc @@ -0,0 +1,57 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/IR/Attributes.h" // IWYU pragma: keep +#include "mlir/IR/BuiltinTypes.h" // IWYU pragma: keep +#include "mlir/IR/Dialect.h" // IWYU pragma: keep +#include "mlir/IR/OpImplementation.h" // IWYU pragma: keep +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_dialect.h.inc" +#include "xla/service/gpu/model/indexing_map.h" // IWYU pragma: keep +#define GET_ATTRDEF_CLASSES +#include "xla/service/gpu/fusions/ir/xla_gpu_attrs.h.inc" +#undef GET_ATTRDEF_CLASSES +#define GET_TYPEDEF_LIST +#define GET_TYPEDEF_CLASSES +#include "xla/service/gpu/fusions/ir/xla_gpu_types.h.inc" + +namespace xla { +namespace gpu { + +mlir::Type IndexedVectorType::parse(mlir::AsmParser& parser) { + mlir::SmallVector shape; + mlir::Type type; + IndexingMapAttr indexing_map_attr; + if (parser.parseLess() || + parser.parseDimensionList(shape, /*allowDynamic=*/false) || + parser.parseType(type) || parser.parseComma() || + parser.parseAttribute(indexing_map_attr) || parser.parseGreater()) { + return {}; + } + return IndexedVectorType::get(parser.getContext(), shape, type, + indexing_map_attr); +} + +void IndexedVectorType::print(mlir::AsmPrinter& printer) const { + printer << "<"; + printer.printDimensionList(getShape()); + printer << "x" << getElementType() << ", " << getIndexingMapAttr() << ">"; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td new file mode 100644 index 00000000000000..5d73344654d1de --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_types.td @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES +#define XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypes.td" +include "mlir/IR/BuiltinTypeInterfaces.td" +include "xla/service/gpu/fusions/ir/xla_gpu_dialect.td" +include "xla/service/gpu/fusions/ir/xla_gpu_attrs.td" + +class XLAGPU_Type traits = []> + : TypeDef { + let mnemonic = typeMnemonic; +} + +def XLAGPU_IndexedVectorType : XLAGPU_Type<"IndexedVector", "indexed_vector", + [ShapedTypeInterface, ValueSemantics]> { + let summary = "Vector type with a specified layout"; + let parameters = (ins + ArrayRefParameter<"int64_t">:$shape, + "mlir::Type":$elementType, + XLAGPU_IndexingMapAttr:$indexing_map_attr + ); + let hasCustomAssemblyFormat = 1; + let extraClassDeclaration = [{ + IndexedVectorType cloneWith(std::optional> shape, + mlir::Type elementType) const { + return IndexedVectorType::get(getContext(), shape.value_or(getShape()), + elementType, getIndexingMapAttr()); + } + + bool hasRank() const { return true; } + }]; +} + +#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_TYPES \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD new file mode 100644 index 00000000000000..98d8ade7c5e5c3 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD @@ -0,0 +1,406 @@ +load("//xla:xla.bzl", "xla_cc_test") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/service/gpu/fusions:__pkg__"], + licenses = ["notice"], +) + +cc_library( + name = "in_place_dynamic_update_slice", + srcs = ["in_place_dynamic_update_slice.cc"], + hdrs = ["in_place_dynamic_update_slice.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/llvm_ir:dynamic_update_slice_util", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "in_place_dynamic_update_slice_test", + srcs = ["in_place_dynamic_update_slice_test.cc"], + deps = [ + ":in_place_dynamic_update_slice", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "loop", + srcs = ["loop.cc"], + hdrs = ["loop.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "loop_test", + srcs = ["loop_test.cc"], + deps = [ + "//xla:status_macros", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "scatter", + srcs = ["scatter.cc"], + hdrs = ["scatter.h"], + deps = [ + ":loop", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "scatter_test", + srcs = ["scatter_test.cc"], + deps = [ + ":scatter", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "tiling_util", + srcs = ["tiling_util.cc"], + hdrs = ["tiling_util.h"], + visibility = ["//xla/service/gpu:__subpackages__"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:target_util", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:kernel_support_library", + "//xla/service/llvm_ir:llvm_loop", + "//xla/service/llvm_ir:llvm_util", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "reduction", + srcs = ["reduction.cc"], + hdrs = ["reduction.h"], + deps = [ + ":tiling_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:kernel_arguments", + "//xla/service/gpu:kernel_reuse_cache", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu:reduction_utils", + "//xla/service/gpu:target_util", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/fusions:reduction_base", + "//xla/service/gpu/fusions:thunk_util", + "//xla/service/gpu/runtime:kernel_thunk", + "//xla/service/gpu/runtime:thunk", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:kernel_support_library", + "//xla/service/llvm_ir:llvm_loop", + "//xla/service/llvm_ir:llvm_util", + "//xla/service/llvm_ir:loop_emitter", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduction_test", + srcs = ["reduction_test.cc"], + deps = [ + ":reduction", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "concatenate", + srcs = ["concatenate.cc"], + hdrs = ["concatenate.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:loop_emitter", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "concatenate_test", + srcs = ["concatenate_test.cc"], + deps = [ + ":concatenate", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "transpose", + srcs = ["transpose.cc"], + hdrs = ["transpose.h"], + deps = [ + ":tiling_util", + "//xla:permutation_util", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:target_util", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:llvm_util", + "//xla/service/llvm_ir:loop_emitter", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "transpose_test", + srcs = ["transpose_test.cc"], + deps = [ + ":transpose", + "//xla:status_macros", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "input_slices", + srcs = ["input_slices.cc"], + hdrs = ["input_slices.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:elemental_ir_emitter", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:ir_emitter", + "//xla/service/gpu:ir_emitter_context", + "//xla/service/gpu:launch_dimensions", + "//xla/service/gpu:parallel_loop_emitter", + "//xla/service/gpu/fusions:fusion_emitter", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/llvm_ir:fused_ir_emitter", + "//xla/service/llvm_ir:ir_array", + "//xla/service/llvm_ir:kernel_support_library", + "//xla/service/llvm_ir:llvm_loop", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:ir_headers", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "input_slices_test", + srcs = ["input_slices_test.cc"], + deps = [ + ":input_slices", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions", + "//xla/service/gpu/model:affine_map_printer", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:IR", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/README.md b/third_party/xla/xla/service/gpu/fusions/legacy/README.md new file mode 100644 index 00000000000000..0fa6bb98f73147 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/legacy/README.md @@ -0,0 +1,8 @@ +# Deprecated emitters + +The emitters in this directory are deprecated. Please do not add any new +features. If you believe you need to add a feature, please reach out and +describe your use case. + +These emitters have more modern MLIR-based equivalents in the directory above +this one. \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/concatenate.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc index 55a45dcd2d2b85..8bb0e04cc7f337 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/concatenate.h" +#include "xla/service/gpu/fusions/legacy/concatenate.h" #include #include diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate.h b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h similarity index 93% rename from third_party/xla/xla/service/gpu/fusions/concatenate.h rename to third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h index e838b2996e57b3..be0465b421e916 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_ -#define XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_ +#define XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_ #include #include @@ -64,4 +64,4 @@ class ConcatenateFusion : public KernelFusionEmitterBase { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_CONCATENATE_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc similarity index 96% rename from third_party/xla/xla/service/gpu/fusions/concatenate_test.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc index 8192b33c37a1cb..ee63bdab38f5c8 100644 --- a/third_party/xla/xla/service/gpu/fusions/concatenate_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/concatenate.h" +#include "xla/service/gpu/fusions/legacy/concatenate.h" #include @@ -74,7 +74,7 @@ TEST_F(ConcatenateTest, ThreadIndexing) { TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = AnalyzeFusion(*root, device_info); + auto analysis_fused = HloFusionAnalysis::Create(*root, device_info); auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc similarity index 92% rename from third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc index e362398cea60b1..38a3e5b68d12f6 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h" +#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h" #include #include @@ -41,7 +41,7 @@ constexpr int kDUSUpdateIndex = 1; } // namespace LaunchDimensions InPlaceDynamicUpdateSliceFusion::launch_dimensions() const { - const auto& update_shape = dus_ops_.front()->operand(1)->shape(); + const auto& update_shape = dus_ops_.front().GetOperand(1).shape(); return CalculateLaunchDimensions(update_shape, analysis_.device_info()); } @@ -55,7 +55,7 @@ InPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( auto launch_dims = launch_dimensions(); // It is guaranteed that all DUS ops have the same output shape at this point. const auto& update_shape = - dus_ops_.front()->operand(kDUSUpdateIndex)->shape(); + dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, update_shape, mlir_context); } @@ -72,7 +72,7 @@ absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel( // This condition should be enforced explicitly in the // 'CanEmitFusedDynamicUpdateSliceInPlaceForGpu' matcher. for (auto [op, output] : llvm::zip(dus_ops_, outputs)) { - output = output.CastToShape(op->shape(), builder); + output = output.CastToShape(op.shape(), builder); } auto* fused_computation = fusion.fused_instructions_computation(); @@ -93,7 +93,7 @@ absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel( dus_and_output_array.reserve(dus_ops_.size()); for (auto [op, output] : llvm::zip(dus_ops_, outputs)) { - dus_and_output_array.push_back(std::make_pair(op, output)); + dus_and_output_array.push_back(std::make_pair(&op.instruction(), output)); } return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h similarity index 92% rename from third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h rename to third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h index 08bcef5f9c2b5c..db12c3cbbf4643 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ -#define XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ +#define XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ #include #include @@ -78,7 +78,7 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { std::optional ComputeThreadIdToInputIndexing( int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* indexing_context) const override; + mlir::MLIRContext* mlir_context) const override; protected: absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, @@ -89,10 +89,10 @@ class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { llvm::IRBuilder<>* builder) const override; const HloFusionAnalysis& analysis_; - std::vector dus_ops_; + std::vector dus_ops_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc index e48cee0cd473b5..c4fd277dc37c94 100644 --- a/third_party/xla/xla/service/gpu/fusions/in_place_dynamic_update_slice_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/in_place_dynamic_update_slice.h" +#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h" #include @@ -74,7 +74,7 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { )")); auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = AnalyzeFusion(*root, device_info_); + auto analysis_fused = HloFusionAnalysis::Create(*root, device_info_); auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); @@ -128,7 +128,7 @@ TEST_F(InPlaceDynamicUpdateSliceFusionTest, ProduceConsumerFusion) { auto* root = module->entry_computation()->root_instruction(); auto analysis_fused = - AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info_); + HloFusionAnalysis::Create(*root->operand(0), *root, device_info_); auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/input_slices.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc index 75ffe4b13246e5..d336f9226256a2 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/input_slices.h" +#include "xla/service/gpu/fusions/legacy/input_slices.h" #include #include diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices.h b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h similarity index 94% rename from third_party/xla/xla/service/gpu/fusions/input_slices.h rename to third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h index fa8043852ae803..e6532241123aee 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ -#define XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_ +#define XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_ #include #include @@ -76,4 +76,4 @@ class InputSlicesFusion : public KernelFusionEmitterBase { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_INPUT_SLICES_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc similarity index 96% rename from third_party/xla/xla/service/gpu/fusions/input_slices_test.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc index 689727aed734ec..bb9f510c59b48d 100644 --- a/third_party/xla/xla/service/gpu/fusions/input_slices_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/input_slices.h" +#include "xla/service/gpu/fusions/legacy/input_slices.h" #include @@ -71,7 +71,7 @@ TEST_F(InputSlicesTest, ThreadIndexing) { TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = AnalyzeFusion(*root, device_info); + auto analysis_fused = HloFusionAnalysis::Create(*root, device_info); auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc new file mode 100644 index 00000000000000..e6ce5f113c713b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc @@ -0,0 +1,132 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/legacy/loop.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Type.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout_util.h" +#include "xla/service/gpu/elemental_ir_emitter.h" +#include "xla/service/gpu/gpu_fusible.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/ir_emitter_context.h" +#include "xla/service/gpu/launch_dimensions.h" +#include "xla/service/gpu/model/indexing_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/parallel_loop_emitter.h" +#include "xla/service/llvm_ir/fused_ir_emitter.h" +#include "xla/service/llvm_ir/ir_array.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/macros.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { +namespace { + +const Shape& GetElementShape(const HloFusionAnalysis& analysis) { + const Shape* shape = &analysis.fusion_root(0).shape(); + while (shape->IsTuple()) { + shape = &shape->tuple_shapes(0); + } + return *shape; +} + +} // namespace + +LoopFusion::LoopFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} + +std::optional LoopFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, mlir::MLIRContext* ctx) const { + auto launch_dims = launch_dimensions(); + return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor, + GetElementShape(analysis_), ctx); +} + +std::optional LoopFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, + mlir::MLIRContext* ctx) const { + std::optional thread_id_to_output_indexing = + ComputeThreadIdToOutputIndexing(root_index, ctx); + if (!thread_id_to_output_indexing.has_value()) { + return std::nullopt; + } + const HloInstruction* fusion_root = + &analysis_.fusion_root(root_index).instruction(); + auto output_to_input_indexing = + ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); + IndexingMapSet output_to_input_indexing_set = + output_to_input_indexing.indexing_maps[hero_operand_index]; + // Since we are computing the indexing for a non-fusion op, there is only one + // indexing map per operand. + CHECK_EQ(output_to_input_indexing_set.size(), 1); + IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( + *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); + thread_id_to_input_indexing_map.Simplify(); + return thread_id_to_input_indexing_map; +} + +absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context, + const HloFusionInstruction& fusion, + const LaunchDimensions& launch_dims, + std::vector inputs, + std::vector outputs, + llvm::IRBuilder<>* builder) const { + GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); + FusedIrEmitter fused_emitter(elemental_emitter); + for (int i = 0; i < fusion.fused_parameters().size(); i++) { + fused_emitter.BindGenerator( + *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) { + return inputs[i].EmitReadArrayElement(index, builder); + }); + } + TF_ASSIGN_OR_RETURN( + auto element_generator, + fused_emitter.GetGenerator(*fusion.fused_expression_root())); + + llvm::Type* index_type = + GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); + + return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder, + config_) + .EmitLoop(fusion.name(), index_type); +} + +LaunchDimensions LoopFusion::launch_dimensions() const { + return CalculateLaunchDimensions(GetElementShape(analysis_), + analysis_.device_info(), config_); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/loop.h b/third_party/xla/xla/service/gpu/fusions/legacy/loop.h similarity index 87% rename from third_party/xla/xla/service/gpu/fusions/loop.h rename to third_party/xla/xla/service/gpu/fusions/legacy/loop.h index 2d23c302ed3100..30e5007bec658f 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LOOP_H_ -#define XLA_SERVICE_GPU_FUSIONS_LOOP_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_ +#define XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_ #include #include @@ -59,13 +59,7 @@ class LoopFusion : public KernelFusionEmitterBase { LaunchDimensionsConfig config_; }; -LaunchDimensionsConfig ComputeLoopFusionConfig( - const HloFusionAnalysis& analysis); - -LaunchDimensionsConfig ComputeLoopFusionConfig( - const HloFusionAnalysis& analysis, const Shape& shape); - } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_LOOP_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc similarity index 93% rename from third_party/xla/xla/service/gpu/fusions/loop_test.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc index 69c41ec0c932b9..5832d13701a3dc 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc @@ -78,7 +78,7 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); auto thread_id_to_output_indexing = @@ -88,20 +88,20 @@ TEST_F(LoopTest, ThreadIndexingUnrolled) { EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), MatchIndexingString(R"( (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000, - ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200, - ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id + (bl_x * 128 + th_x) floordiv 15000, + ((bl_x * 128 + th_x) floordiv 75) mod 200, + ((bl_x * 128 + th_x) mod 75) * 4 + unroll_id ) domain: th_x in [0, 127] th_y in [0, 0] th_z in [0, 0] - bl_x in [0, 1007] + bl_x in [0, 11718] bl_y in [0, 0] bl_z in [0, 0] - chunk_id in [0, 11] + chunk_id in [0, 0] unroll_id in [0, 3] - bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999] + bl_x * 128 + th_x in [0, 1499999] )")); } @@ -121,7 +121,7 @@ TEST_F(LoopTest, ThreadIndexingNotUnrolled) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); auto thread_id_to_output_indexing = @@ -174,7 +174,7 @@ TEST_F(LoopTest, Broadcast) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); auto thread_id_to_output_indexing = diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/reduction.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc index 77c1e910d7bc55..e009ea18e0b48c 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/reduction.h" +#include "xla/service/gpu/fusions/legacy/reduction.h" #include #include @@ -52,9 +52,9 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/elemental_ir_emitter.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/legacy/tiling_util.h" #include "xla/service/gpu/fusions/reduction_base.h" #include "xla/service/gpu/fusions/thunk_util.h" -#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" diff --git a/third_party/xla/xla/service/gpu/fusions/reduction.h b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h similarity index 96% rename from third_party/xla/xla/service/gpu/fusions/reduction.h rename to third_party/xla/xla/service/gpu/fusions/legacy/reduction.h index a15462fe9a2d8a..131b4ec38c7693 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ -#define XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_ +#define XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_ #include #include @@ -25,8 +25,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/legacy/tiling_util.h" #include "xla/service/gpu/fusions/reduction_base.h" -#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" @@ -187,4 +187,4 @@ class ReductionFusion : public KernelFusionEmitterBase { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_REDUCTION_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/reduction_test.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc index 81649a735e8329..cc0faf4f07d738 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/reduction.h" +#include "xla/service/gpu/fusions/legacy/reduction.h" #include @@ -69,7 +69,7 @@ TEST_F(ReductionTest, ThreadIndexingRowReduction) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); ReductionFusion fusion(analysis); EXPECT_THAT( @@ -134,7 +134,7 @@ TEST_F(ReductionTest, TwoGroups) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); ReductionFusion fusion(analysis); EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots, @@ -165,7 +165,7 @@ TEST_F(ReductionTest, OneGroup) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); ReductionFusion fusion(analysis); EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots, SizeIs(2)); diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/scatter.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc index 8f7f773638a4da..07987886a73120 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/scatter.h" +#include "xla/service/gpu/fusions/legacy/scatter.h" #include #include @@ -36,7 +36,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/fusions/loop.h" +#include "xla/service/gpu/fusions/legacy/loop.h" +#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" diff --git a/third_party/xla/xla/service/gpu/fusions/scatter.h b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h similarity index 94% rename from third_party/xla/xla/service/gpu/fusions/scatter.h rename to third_party/xla/xla/service/gpu/fusions/legacy/scatter.h index dda11c01e7803e..862d0b3543b4ad 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_SCATTER_H_ -#define XLA_SERVICE_GPU_FUSIONS_SCATTER_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_ +#define XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_ #include #include @@ -68,4 +68,4 @@ class ScatterFusion : public KernelFusionEmitterBase { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_SCATTER_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/fusions/scatter_test.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc index 284d308ad5a190..71eea76c2482f2 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/scatter.h" +#include "xla/service/gpu/fusions/legacy/scatter.h" #include @@ -85,7 +85,7 @@ TEST_F(ScatterFusionTest, ScatterFusion) { TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = AnalyzeFusion(*root, device_info); + auto analysis_fused = HloFusionAnalysis::Create(*root, device_info); auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); @@ -141,7 +141,7 @@ TEST_F(ScatterFusionTest, ThreadIdIndexing) { TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = AnalyzeFusion(*root, device_info); + auto analysis_fused = HloFusionAnalysis::Create(*root, device_info); auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); diff --git a/third_party/xla/xla/service/gpu/fusions/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc similarity index 72% rename from third_party/xla/xla/service/gpu/fusions/tiling_util.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc index 9ad085faff3a99..a1a7acb58388a7 100644 --- a/third_party/xla/xla/service/gpu/fusions/tiling_util.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/fusions/legacy/tiling_util.h" #include #include @@ -31,7 +31,11 @@ limitations under the License. #include "llvm/IR/Instructions.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/MLIRContext.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/target_util.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/service/llvm_ir/kernel_support_library.h" @@ -46,6 +50,10 @@ namespace xla { namespace gpu { namespace { +using mlir::AffineExpr; +using mlir::AffineMap; +using mlir::MLIRContext; + void EmitTileRec(const TilingThreadIdInfo& thread_id_info, const Tiling& tiling, int dim, absl::InlinedVector tile_idx, absl::Span tile_dimensions, @@ -199,6 +207,13 @@ absl::StatusOr EmitThreadIdInfo(llvm::IRBuilder<>* builder, return info; } +AffineMap GetTilingAffineMap(llvm::ArrayRef exprs, + int64_t num_symbols) { + return AffineMap::get( + /*dimCount=*/6, /*symbolCount=*/num_symbols, exprs, + exprs[0].getContext()); +} + } // namespace absl::StatusOr EmitTilingKernel( @@ -255,5 +270,82 @@ absl::StatusOr EmitTilingKernel( return {{tile_dimensions, tile_offset, thread_id_info}}; } +AffineMap GetBlockOffsetsForTiling( + absl::Span num_blocks, + absl::Span tile_sizes_per_block, int64_t rank, + MLIRContext* mlir_context) { + auto offsets = + DelinearizeInBoundsIndex(getAffineDimExpr(3, mlir_context), num_blocks); + for (auto&& [offset, tile_size] : llvm::zip(offsets, tile_sizes_per_block)) { + offset = offset * tile_size; + } + return GetTilingAffineMap(offsets, rank); +} + +AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, + MLIRContext* mlir_context) { + return GetBlockOffsetsForTiling(tiling.GetBlockCounts(), + tiling.GetBlockTileSize(), + tiling.GetShape().size(), mlir_context); +} + +AffineMap GetThreadOffsetsForTiling( + absl::Span num_threads, + absl::Span tile_sizes_per_thread, int64_t rank, + MLIRContext* mlir_context) { + auto offsets = + DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), num_threads); + for (int dim = 0; dim < rank; ++dim) { + if (tile_sizes_per_thread[dim] > 1) { + offsets[dim] = offsets[dim] + + getAffineSymbolExpr(dim, mlir_context) * num_threads[dim]; + } + } + return GetTilingAffineMap(offsets, rank); +} + +AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, + MLIRContext* mlir_context) { + return GetThreadOffsetsForTiling(tiling.GetThreadsPerBlock(), + tiling.GetThreadTileSize(), + tiling.GetShape().size(), mlir_context); +} + +IndexingMap GetIndexingMapForTiling(const Tiling& tiling, + MLIRContext* mlir_context) { + return GetIndexingMapForTiling( + GetBlockOffsetsForTiling(tiling, mlir_context), + GetThreadOffsetsForTiling(tiling, mlir_context), + tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(), + tiling.GetThreadTileSize(), tiling.GetShape()); +} + +IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, + AffineMap thread_offsets, + int64_t threads_per_block, + int64_t num_blocks, + absl::Span thread_tile_sizes, + absl::Span tiled_shape) { + auto* mlir_context = block_offsets.getContext(); + llvm::SmallVector offsets; + offsets.reserve(block_offsets.getNumResults()); + for (auto [block, thread] : + llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) { + offsets.push_back(block + thread); + } + std::vector dimension_ranges{ + {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {}, + }; + auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(), + block_offsets.getNumSymbols(), offsets, + mlir_context); + IndexingMap map{affine_map, dimension_ranges, + RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}}; + for (int i = 0; i < tiled_shape.size(); ++i) { + map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1}); + } + return map; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/tiling_util.h b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h similarity index 76% rename from third_party/xla/xla/service/gpu/fusions/tiling_util.h rename to third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h index 66014ae48b16d6..de367e36addb61 100644 --- a/third_party/xla/xla/service/gpu/fusions/tiling_util.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_ -#define XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_ +#define XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_ #include #include @@ -27,6 +27,9 @@ limitations under the License. #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/ir_array.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -177,7 +180,36 @@ absl::StatusOr EmitTilingKernel( llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, const TileGenerator& tile_element_generator); +// Creates an indexing map from thread and block IDs to elements of the tiled +// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2 +// are thread indices (currently only 0 is used), dimensions 3 to 5 are block +// indices (currently only 3 is used). +mlir::AffineMap GetBlockOffsetsForTiling( + absl::Span num_blocks, + absl::Span tile_sizes_per_block, int64_t rank, + mlir::MLIRContext* mlir_context); +mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, + mlir::MLIRContext* mlir_context); +mlir::AffineMap GetThreadOffsetsForTiling( + absl::Span num_threads, + absl::Span tile_sizes_per_thread, int64_t rank, + mlir::MLIRContext* mlir_context); +mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, + mlir::MLIRContext* mlir_context); + +// Convenience functions for the two functions above +// (`GetBlockOffsestsForTiling` + `GetThreadOffsetsForTiling`). Also sets up +// the ranges of dimensions and symbols. +IndexingMap GetIndexingMapForTiling(const Tiling& tiling, + mlir::MLIRContext* mlir_context); +IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets, + mlir::AffineMap thread_offsets, + int64_t threads_per_block, + int64_t num_blocks, + absl::Span thread_tile_sizes, + absl::Span tiled_shape); + } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_TILING_UTIL_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc similarity index 97% rename from third_party/xla/xla/service/gpu/fusions/transpose.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc index 611099d8280092..d6cbdecf4bfceb 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/transpose.h" +#include "xla/service/gpu/fusions/legacy/transpose.h" #include #include @@ -39,7 +39,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/permutation_util.h" #include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/fusions/legacy/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" @@ -52,7 +52,6 @@ limitations under the License. #include "xla/service/llvm_ir/llvm_util.h" #include "xla/service/llvm_ir/loop_emitter.h" #include "xla/shape_util.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -66,12 +65,13 @@ Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info, static_assert(WarpSize() % kNumRows == 0); // 3D view over the output shape. - Vector3 transposed_dims = tiled_transpose.dimensions; - Vector3 permutation = tiled_transpose.permutation; + absl::InlinedVector transposed_dims = tiled_transpose.dimensions; + absl::InlinedVector permutation = tiled_transpose.permutation; // Note: the supported permutations are their own inverses. Therefore we // always use the permutation, even when we want the inverse. - CHECK((permutation == Vector3{0, 2, 1}) || (permutation == Vector3{2, 1, 0})); + CHECK((permutation == absl::InlinedVector{0, 2, 1}) || + (permutation == absl::InlinedVector{2, 1, 0})); absl::InlinedVector input_dims{transposed_dims[permutation[0]], transposed_dims[permutation[1]], @@ -189,7 +189,7 @@ absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, } absl::flat_hash_map tiles; - Vector3 permutation; + absl::InlinedVector permutation; for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) { permutation = tr.permutation; auto tile_size = tiling_.GetBlockTileSize(); diff --git a/third_party/xla/xla/service/gpu/fusions/transpose.h b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h similarity index 92% rename from third_party/xla/xla/service/gpu/fusions/transpose.h rename to third_party/xla/xla/service/gpu/fusions/legacy/transpose.h index 3f369a47c5fd08..3366130c05546b 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose.h +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h @@ -12,25 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ -#define XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_ +#define XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_ #include #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "llvm/IR/IRBuilder.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/tiling_util.h" +#include "xla/service/gpu/fusions/legacy/tiling_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/llvm_ir/ir_array.h" -#include "xla/util.h" namespace xla { namespace gpu { @@ -82,10 +82,10 @@ class TransposeFusion : public KernelFusionEmitterBase { private: const HloFusionAnalysis& analysis_; Tiling tiling_; - Vector3 permutation_; + absl::InlinedVector permutation_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_TRANSPOSE_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc similarity index 89% rename from third_party/xla/xla/service/gpu/fusions/transpose_test.cc rename to third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc index f94246916406c9..43a417843858db 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/transpose.h" +#include "xla/service/gpu/fusions/legacy/transpose.h" #include #include @@ -71,7 +71,7 @@ TEST_F(TransposeTest, ThreadIndexing021) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; @@ -118,23 +118,23 @@ TEST_F(TransposeTest, ThreadIndexing021) { )")); } -TEST_F(TransposeTest, ThreadIndexing201) { +TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { auto module = ParseAndReturnVerifiedModule(R"( HloModule module fusion { - %input = f32[100,64,32] parameter(0) - ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1} + %input = f32[1,6400,32] parameter(0) + ROOT transpose = f32[1,32,6400] transpose(%input), dimensions={0,2,1} } ENTRY entry { - %input = f32[100,64,32] parameter(0) - ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion + %input = f32[1,6400,32] parameter(0) + ROOT %fusion = f32[1,32,6400] fusion(%input), kind=kInput, calls=fusion })") .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; @@ -142,8 +142,8 @@ TEST_F(TransposeTest, ThreadIndexing201) { fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32, + 0, + d3 * 32 + s1 * 4 + d0 floordiv 32, d0 mod 32 ) domain: @@ -162,9 +162,9 @@ TEST_F(TransposeTest, ThreadIndexing201) { fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( + 0, d0 floordiv 32 + s1 * 4, - d3 floordiv 2, - (d3 mod 2) * 32 + d0 mod 32 + d3 * 32 + d0 mod 32 ) domain: d0 in [0, 127] @@ -185,20 +185,20 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { HloModule m fused_computation { - %p0 = f64[24,2,6,4] parameter(0) - ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} + %p0 = f64[24,2,24] parameter(0) + ROOT %t = f64[24,2,24] transpose(%p0), dimensions={2,1,0} } ENTRY main { - %p0 = f64[24,2,6,4] parameter(0) - ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput, + %p0 = f64[24,2,24] parameter(0) + ROOT %fusion = f64[24,2,24] fusion(%p0), kind=kInput, calls=%fused_computation } )") .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; @@ -208,8 +208,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( d0 floordiv 32 + s0 * 4, d3, - (d0 floordiv 4) mod 8, - d0 mod 4 + d0 mod 32 ) domain: d0 in [0, 127] @@ -227,8 +226,7 @@ TEST_F(TransposeTest, ThreadIndexingPartialBlock) { fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - s0, - d0 floordiv 32, + d0 floordiv 32 + s0 * 4, d3, d0 mod 32 ) @@ -264,7 +262,7 @@ TEST_F(TransposeTest, SameInputIndexingForRealHeroAndSideOutput) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; @@ -294,7 +292,7 @@ TEST_F(TransposeTest, ThreadIndexingSideOutput) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); mlir::MLIRContext mlir_context; diff --git a/third_party/xla/xla/service/gpu/fusions/loop.cc b/third_party/xla/xla/service/gpu/fusions/loop.cc deleted file mode 100644 index e9b7933b1c7895..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/loop.cc +++ /dev/null @@ -1,293 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/fusions/loop.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/numeric/bits.h" -#include "absl/status/status.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Type.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout_util.h" -#include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/gpu_fusible.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/gpu/parallel_loop_emitter.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/util.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -const Shape& GetElementShape(const HloFusionAnalysis& analysis) { - const Shape* shape = &analysis.fusion_root(0).shape(); - while (shape->IsTuple()) { - shape = &shape->tuple_shapes(0); - } - return *shape; -} - -// Computes the maximum valid unroll factor for a given instruction. -int ComputeMaxUnrollFactor(int64_t num_elements) { - constexpr int kMaxUnrollFactor = 4; - for (int i = kMaxUnrollFactor; i > 1; i /= 2) { - if (num_elements % i == 0) { - return i; - } - } - return 1; -} - -// Determines if we enable the row optimized codegen. When we have a fusion with -// only pointwise operations, scalar broadcasting and row broadcasting, we can -// trigger a kernel that vectorizes the row loads. This speeds up the kernel, in -// particular on A100. The int is the number of inputs with rank `out_rank`. Its -// value is only defined if row vectorization is enabled. -std::pair RowVectorizationEnabled( - const HloFusionAdaptor& fusion, int64_t out_rank) { - auto roots = fusion.GetRoots(); - const auto is_row_major = [](auto instr) { - // Only tested when the inputs are row-major. So only enable that case. - // Maybe it would work if only the inner dimensions is contiguous. - return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout()); - }; - bool row_vectorized = roots.size() == 1 && !roots[0].shape().IsTuple() && - is_row_major(roots[0]); - if (!row_vectorized) { - return {false, 0}; - } - - // Check that the operations in the fusion are supported. Each - // supported operation (or category) must be manually vetted as XLA - // only unrolls and relies on LLVM to vectorize. But this is brittle. - // Currently tested and supported operations: - // Elementwise, scalar and row broadcasting. - // - // We also detect at the same time if there is a row broadcasting - // operation. - int num_big_inputs = 0; - bool some_row_broadcasting = false; - HloBfsConsumersFirstTraversal( - roots, fusion, - [&](auto node) -> TraversalResult { - if (!row_vectorized) { - return TraversalResult::kInterrupt; - } - - if (node.instruction().IsElementwise()) { - return TraversalResult::kAdvance; - } - - switch (node.opcode()) { - case HloOpcode::kConstant: - return TraversalResult::kSkip; - case HloOpcode::kParameter: - return TraversalResult::kAdvance; - case HloOpcode::kBroadcast: { - auto dims = node.instruction().dimensions(); - if (dims.empty()) { - return TraversalResult::kAdvance; - } - - if (dims.size() == 1 && dims.front() == node.shape().rank() - 1) { - some_row_broadcasting = true; - return TraversalResult::kAdvance; - } - TF_FALLTHROUGH_INTENDED; - } - default: - VLOG(2) << "Row vectorization not enabled due to: " - << node.ToString(); - row_vectorized = false; - return TraversalResult::kInterrupt; - } - }, - [&](auto argument) { - if (argument.shape().rank() == out_rank) { - ++num_big_inputs; - } - if (!is_row_major(argument)) { - row_vectorized = false; - } - }); - // Trigger only when there is a row broadcasting. - return std::make_pair(row_vectorized && some_row_broadcasting, - num_big_inputs); -} - -} // namespace - -LaunchDimensionsConfig ComputeLoopFusionConfig( - const HloFusionAnalysis& analysis) { - return ComputeLoopFusionConfig(analysis, GetElementShape(analysis)); -} - -LaunchDimensionsConfig ComputeLoopFusionConfig( - const HloFusionAnalysis& analysis, const Shape& element_shape) { - int unroll_factor = 1; - // Unrolling is good to read large inputs with small elements - // due to vector loads, but increases the register pressure when one - // thread has to produce multiple output elements. - // Therefore for fusions with small outputs prefer to use one thread - // per output element = no unroll. - // Call 'small' fusions that use less threads than the GPU has. - int64_t num_elements = ShapeUtil::ElementsIn(element_shape); - int64_t n_threads_max = analysis.device_info().threads_per_core_limit() * - analysis.device_info().core_count(); - if (num_elements >= n_threads_max && - !MayPreventVectorization(analysis.fusion())) { - unroll_factor = ComputeMaxUnrollFactor(num_elements); - } - // CHECK that unroll_factor is a power-of-2, as needed by the logic below. - CHECK(absl::has_single_bit(static_cast(unroll_factor))); - // Ensure a single thread writes to a byte containing multiple values by - // setting unroll_factor to an appropriate number. Setting unroll_factor is - // safe even if the new unroll_factor doesn't divide the number of elements, - // as the parallel loop emitter will insert a bounds check in this case to - // ensure the out-of-bounds element is not computed and written. Setting - // unroll_factor is safe even if MayPreventVectorization returns false, as - // the MayPreventVectorization check is an optimization, not a correctness - // requirement. - unroll_factor = std::max( - unroll_factor, - CeilOfRatio(8, analysis.input_output_info().smallest_output_dtype_bits)); - CHECK(absl::has_single_bit(static_cast(unroll_factor))); - VLOG(2) << "Unroll factor: " << unroll_factor; - - bool row_vectorized; - int num_big_inputs; - std::tie(row_vectorized, num_big_inputs) = - RowVectorizationEnabled(analysis.fusion(), element_shape.rank()); - bool few_waves = !HloAnyOf(analysis.fusion(), [&](auto instr) { - if (instr.opcode() == HloOpcode::kParameter || - instr.opcode() == HloOpcode::kConstant || - HloInstruction::IsOpElementwise(instr.opcode())) { - return false; - } - if (auto broadcast = - DynCast(&instr.instruction())) { - if (broadcast->dimensions().empty() || - // More than 3 big inputs cause a speed regression. - (row_vectorized && num_big_inputs <= 3)) { - return false; - } - } - VLOG(2) << "few_waves not enabled due to: " - << instr.instruction().ToString(); - return true; - }); - - LaunchDimensionsConfig launch_config{unroll_factor, few_waves, - row_vectorized}; - // Check that the shapes is supported. - if (launch_config.row_vectorized && - ThreadsPerBlockRowVectorized(element_shape, analysis.device_info(), - launch_config) <= 0) { - VLOG(2) << "Cancelling row_vectorization as the shape isn't supported."; - launch_config.row_vectorized = false; - launch_config.few_waves = false; - } - return launch_config; -} - -LoopFusion::LoopFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} - -std::optional LoopFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { - auto launch_dims = launch_dimensions(); - return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor, - GetElementShape(analysis_), ctx); -} - -std::optional LoopFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { - std::optional thread_id_to_output_indexing = - ComputeThreadIdToOutputIndexing(root_index, ctx); - if (!thread_id_to_output_indexing.has_value()) { - return std::nullopt; - } - const HloInstruction* fusion_root = - &analysis_.fusion_root(root_index).instruction(); - auto output_to_input_indexing = - ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); - IndexingMapSet output_to_input_indexing_set = - output_to_input_indexing.indexing_maps[hero_operand_index]; - // Since we are computing the indexing for a non-fusion op, there is only one - // indexing map per operand. - CHECK_EQ(output_to_input_indexing_set.size(), 1); - IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( - *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); - thread_id_to_input_indexing_map.Simplify(); - return thread_id_to_input_indexing_map; -} - -absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const { - GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); - FusedIrEmitter fused_emitter(elemental_emitter); - for (int i = 0; i < fusion.fused_parameters().size(); i++) { - fused_emitter.BindGenerator( - *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) { - return inputs[i].EmitReadArrayElement(index, builder); - }); - } - TF_ASSIGN_OR_RETURN( - auto element_generator, - fused_emitter.GetGenerator(*fusion.fused_expression_root())); - - llvm::Type* index_type = - GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); - - return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder, - config_) - .EmitLoop(fusion.name(), index_type); -} - -LaunchDimensions LoopFusion::launch_dimensions() const { - return CalculateLaunchDimensions(GetElementShape(analysis_), - analysis_.device_info(), config_); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc index 9db9173ab05141..4c6bdac0ce01ad 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.cc @@ -35,9 +35,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/launch_dimensions.h" diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h index 029c67bea36a56..ecb1591b553dc7 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir.h @@ -22,9 +22,9 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/loop.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" diff --git a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc index 08dcb4df490e54..37efa18945e58d 100644 --- a/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/loop_mlir_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -29,100 +29,6 @@ namespace { using MlirLoopFusionTest = MlirEmitterTestBase; -TEST_F(MlirLoopFusionTest, ThreadId_IndexingUnrolled) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( - HloModule module - - neg { - %input = f32[100,200,300] parameter(0) - ROOT neg = f32[100,200,300] negate(%input) - } - ENTRY entry { - %input = f32[100,200,300] parameter(0) - ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg - } - )")); - thread_id_printer_.SetSymbolName(0, "chunk_id"); - thread_id_printer_.SetSymbolName(1, "unroll_id"); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - MlirLoopFusion fusion(analysis); - auto thread_id_to_output_indexing = - fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - - EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (bl_x * 128 + chunk_id * 129024 + th_x) floordiv 15000, - ((bl_x * 128 + chunk_id * 129024 + th_x) floordiv 75) mod 200, - ((bl_x * 128 + chunk_id * 129024 + th_x) mod 75) * 4 + unroll_id - ) - domain: - th_x in [0, 127] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 1007] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 11] - unroll_id in [0, 3] - bl_x * 128 + chunk_id * 129024 + th_x in [0, 1499999] -)")); -} - -TEST_F(MlirLoopFusionTest, ThreadId_IndexingNotUnrolled) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( - HloModule module - - neg { - %input = f32[20] parameter(0) - ROOT neg = f32[20] negate(%input) - } - ENTRY entry { - %input = f32[20] parameter(0) - ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg - } - )")); - thread_id_printer_.SetSymbolName(0, "chunk_id"); - thread_id_printer_.SetSymbolName(1, "unroll_id"); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - - MlirLoopFusion fusion(analysis); - auto thread_id_to_output_indexing = - fusion.ComputeThreadIdToOutputIndexing(/*root_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(thread_id_printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) - domain: - th_x in [0, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - )")); - auto thread_id_to_input_indexing = fusion.ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_input_indexing->ToString(thread_id_printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x) - domain: - th_x in [0, 19] - th_y in [0, 0] - th_z in [0, 0] - bl_x in [0, 0] - bl_y in [0, 0] - bl_z in [0, 0] - chunk_id in [0, 0] - unroll_id in [0, 0] - )")); -} - TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule module @@ -140,7 +46,7 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { thread_id_printer_.SetSymbolName(1, "unroll_id"); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirLoopFusion fusion(analysis); auto thread_id_to_output_indexing = @@ -182,42 +88,6 @@ TEST_F(MlirLoopFusionTest, ThreadId_Broadcast) { )")); } -TEST_F(MlirLoopFusionTest, Constant_Broadcast) { - auto kHloString = R"( - HloModule module - - bcast { - zero = bf16[] constant(0) - ROOT broadcast = bf16[2,16,48]{2,1,0} broadcast(zero), dimensions={} - } - - ENTRY entry { - ROOT %fusion = bf16[2,16,48]{2,1,0} fusion(), kind=kLoop, calls=bcast - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d1 * 1024 + d0)> - // CHECK: #[[MAP1:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) floordiv 768)> - // CHECK: #[[MAP2:.*]] = affine_map<(d0, d1) -> (((d1 * 1024 + d0) floordiv 48) mod 16)> - // CHECK: #[[MAP3:.*]] = affine_map<(d0, d1) -> ((d1 * 1024 + d0) mod 48)> - // CHECK: func.func @fused_computation(%[[ARG0:.*]]: tensor<2x16x48xbf16> - // CHECK: %[[UPPER_BOUND:.*]] = arith.constant 1535 : index - // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id - // CHECK: %[[BLOCK_ID:.*]] = gpu.block_id - // CHECK: %[[LINEAR:.*]] = xla_gpu.apply_indexing #[[MAP0]] - // CHECL: %[[IN_BOUNDS:.*]] = arith.cmpi sle, %[[LINEAR]], %[[UPPER_BOUND]] : index - // scf.if %[[IN_BOUNDS]] - // CHECK: %[[I0:.*]] = xla_gpu.apply_indexing #[[MAP1]] - // CHECK: %[[I1:.*]] = xla_gpu.apply_indexing #[[MAP2]] - // CHECK: %[[I2:.*]] = xla_gpu.apply_indexing #[[MAP3]] - // CHECK: %[[BCAST:.*]] = xla_gpu.pure_call @bcast_broadcast - // CHECK: %[[INSERTED:.*]] = tensor.insert %[[BCAST]] into %[[ARG0]][%[[I0]], %[[I1]], %[[I2]]] - // CHECK: func.func private @bcast_broadcast - // CHECK: arith.constant 0.000000e+00 - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0})); -} - TEST_F(MlirLoopFusionTest, NoCodeDuplication) { // This test HLO is copied from // xla/service/fusion_node_indexing_evaluation_test.cc. @@ -253,85 +123,6 @@ TEST_F(MlirLoopFusionTest, NoCodeDuplication) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirLoopFusionTest, TwoUsersConsistentIndexing) { - auto kHloString = R"( - HloModule test_module - - %fused_computation (param: f32[6]) -> f32[2] { - %p0 = f32[2]{0} parameter(0) - %p1 = f32[2]{0} parameter(1) - %add = f32[2] add(%p0, %p1) - %sub = f32[2] subtract(%p0, %p1) - %mul = f32[2] multiply(%add, %sub) - %div = f32[2] divide(%add, %sub) - ROOT %atan2 = f32[2] atan2(%mul, %div) - } - ENTRY entry_computation { - p0 = f32[2] parameter(0) - p1 = f32[2] parameter(1) - ROOT %fusion = f32[2] fusion(p0, p1), kind=kLoop, calls=%fused_computation - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: func.func @fused_computation - // CHECK-NEXT: gpu.thread_id - // CHECK-NEXT: pure_call @fused_computation_atan2 - // CHECK-NEXT: tensor.insert - // CHECK-NEXT: return - - // CHECK: func.func private @fused_computation_atan2 - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: addf - // CHECK-NEXT: subf - // CHECK-NEXT: mulf - // CHECK-NEXT: divf - // CHECK-NEXT: atan2 - // CHECK-NEXT: return - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, ComplexOps) { - auto kHloString = R"( - HloModule test_module - - %fused_computation { - %p0 = f32[2]{0} parameter(0) - %p1 = f32[2]{0} parameter(1) - %p2 = c64[2]{0} parameter(2) - %complex = c64[2] complex(%p0, %p1) - %add = c64[2] add(%complex, %p2) - %cst = c64[2]{0} constant({(2.0, 0.0), (0.0, 2.0)}) - ROOT %mul = c64[2] multiply(%add, %cst) - } - ENTRY entry_computation { - p0 = f32[2] parameter(0) - p1 = f32[2] parameter(1) - p2 = c64[2] parameter(2) - ROOT %fusion = c64[2] fusion(p0, p1, p2), kind=kLoop, calls=%fused_computation - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: func.func @fused_computation - // CHECK-NEXT: gpu.thread_id - // CHECK-NEXT: pure_call @fused_computation_mul - // CHECK-NEXT: tensor.insert - // CHECK-NEXT: return - - // CHECK: func.func private @fused_computation_mul - // CHECK-NEXT: arith.constant - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: complex.create - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: complex.add - // CHECK-NEXT: tensor.extract - // CHECK-NEXT: complex.mul - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirLoopFusionTest, IotaCopyBitcastBroadcastReshapeReverseTranspose) { auto kHloString = R"( HloModule test_module @@ -359,137 +150,6 @@ TEST_F(MlirLoopFusionTest, IotaCopyBitcastBroadcastReshapeReverseTranspose) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirLoopFusionTest, VariadicReduce) { - auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - scalar_lhs.0 = f32[] parameter(0) - scalar_lhs.1 = f32[] parameter(1) - scalar_rhs.0 = f32[] parameter(2) - scalar_rhs.1 = f32[] parameter(3) - add = f32[] add(scalar_lhs.0, scalar_rhs.0) - mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) - ROOT t = (f32[], f32[]) tuple(add, mul) - } - fused_computation { - param_0 = f32[3,4,5]{2,1,0} parameter(0) - param_1 = f32[3,4,5]{2,1,0} parameter(1) - param_2 = f32[] parameter(2) - ROOT d.1 = (f32[4], f32[4]) reduce(f32[3,4,5]{2,1,0} param_0, - f32[3,4,5]{2,1,0} %param_1, f32[] param_2, f32[] param_2), - dimensions={0,2}, to_apply=Add - } - ENTRY main { - a = f32[3,4,5]{2,1,0} parameter(0) - b = f32[3,4,5]{2,1,0} parameter(1) - c = f32[] constant(0) - ROOT fusion = (f32[4]{0}, f32[4]{0}) fusion(a, b, c), - kind=kLoop, calls=fused_computation - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: func @fused_computation( - // CHECK: %[[TID_X:.*]] = gpu.thread_id x - // CHECK: %[[SCALARS_0:.*]], %[[SCALARS_1:.*]] = xla_gpu.pure_call @fused_computation_d_1 - // CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[SCALARS_0]] into %{{.*}}[%[[TID_X]]] - // CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[SCALARS_1]] into %{{.*}}[%[[TID_X]]] - // CHECK: return %[[INSERTED_1]], %[[INSERTED_2]] - - // CHECK: func private @fused_computation_d_1 - // CHECK: %[[RET:.*]]:2 = func.call @Add_t - // CHECK: yield %[[RET]]#0, %[[RET]]#1 - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, MinimumMaximum) { - auto kHloString = R"( - HloModule Test - - fused_computation { - param0 = f64[] parameter(0) - param1 = f64[] parameter(1) - - minimum = f64[] minimum(f64[] param0, f64[] param1) - maximum = f64[] maximum(f64[] param0, f64[] param1) - ROOT tuple = (f64[], f64[]) tuple(minimum, maximum) - } - - ENTRY main { - param0 = f64[] parameter(0) - param1 = f64[] parameter(1) - ROOT fusion = (f64[], f64[]) fusion(f64[] param0, f64[] param1), kind=kLoop, calls=fused_computation - } - )"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: func.func @fused_computation - // CHECK: xla_gpu.pure_call @fused_computation_tuple - // CHECK: func.func private @fused_computation_tuple - // CHECK-DAG: arith.minimumf - // CHECK-DAG: arith.maximumf - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, TupleBitcast) { - auto kHloString = R"( - HloModule Test - - fused_computation { - param0 = f64[8] parameter(0) - param1 = f64[8] parameter(1) - - minimum = f64[8] minimum(param0, param1) - maximum = f64[8] maximum(param0, param1) - bc = f64[2, 4] bitcast(maximum) - ROOT tuple = (f64[8], f64[2,4]) tuple(minimum, bc) - } - - ENTRY main { - param0 = f64[8] parameter(0) - param1 = f64[8] parameter(1) - ROOT fusion = (f64[8], f64[2,4]) fusion(param0, param1), - kind=kLoop, calls=fused_computation - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, NestedTuple) { - auto kHloString = R"( - add { - scalar_lhs.0 = f32[] parameter(0) - scalar_lhs.1 = f32[] parameter(1) - scalar_rhs.0 = f32[] parameter(2) - scalar_rhs.1 = f32[] parameter(3) - add = f32[] add(scalar_lhs.0, scalar_rhs.0) - mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) - ROOT t = (f32[], f32[]) tuple(add, mul) - } - fused_computation { - param_0 = f32[3,4,5]{2,1,0} parameter(0) - param_1 = f32[3,4,5]{2,1,0} parameter(1) - param_2 = f32[] parameter(2) - param_3 = f32[4] parameter(3) - reduce = (f32[4], f32[4]) reduce(f32[3,4,5]{2,1,0} param_0, - f32[3,4,5]{2,1,0} %param_1, f32[] param_2, f32[] param_2), - dimensions={0,2}, to_apply=add - log = f32[4] log(param_3) - ROOT tuple = ((f32[4], f32[4]), f32[4]) tuple(reduce, log) - } - ENTRY main { - a = f32[3,4,5]{2,1,0} parameter(0) - b = f32[3,4,5]{2,1,0} parameter(1) - c = f32[] constant(0) - d = f32[4] parameter(2) - ROOT fusion = ((f32[4], f32[4]), f32[4]) fusion(a, b, c, d), - kind=kLoop, calls=fused_computation - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirLoopFusionTest, DynamicSliceWith64BitInput) { // Lowering this kernel with 32 bit indices causes an underflow of `c`, // resulting in slicing the last four elements instead of the first four. @@ -511,63 +171,6 @@ TEST_F(MlirLoopFusionTest, DynamicSliceWith64BitInput) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirLoopFusionTest, DynamicUpdateSlice) { - constexpr auto kHloString = R"( - %fused_computation { - in = c64[2,3] parameter(0) - updates = c64[2,2] parameter(1) - i0 = s32[] parameter(2) - i1 = s32[] parameter(3) - updated = c64[2,3] dynamic-update-slice(in, updates, i0, i1) - ROOT transpose = c64[3,2] transpose(updated), dimensions={1,0} - } - - ENTRY main { - p0 = c64[2,3] parameter(0) - p1 = c64[2,2] parameter(1) - p2 = s32[] parameter(2) - p3 = s32[] parameter(3) - ROOT %fusion = c64[3,2] fusion(p0, p1, p2, p3), kind=kLoop, calls=%fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: scf.if - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, NotPred) { - constexpr auto kHloString = R"( - %fused_computation { - p0 = s8[1000] parameter(0) - cvt = pred[1000] convert(p0) - ROOT not = pred[1000] not(cvt) - } - - ENTRY main { - p0 = s8[1000] parameter(0) - ROOT %fusion = pred[1000] fusion(p0), kind=kLoop, calls=%fused_computation - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirLoopFusionTest, MulPred) { - constexpr auto kHloString = R"( - %fused_computation { - p0 = s8[1000] parameter(0) - p1 = s8[1000] parameter(1) - cvt0 = pred[1000] convert(p0) - cvt1 = pred[1000] convert(p1) - ROOT mul = pred[1000] multiply(cvt0, cvt1) - } - - ENTRY main { - p0 = s8[1000] parameter(0) - p1 = s8[1000] parameter(1) - ROOT %fusion = pred[1000] fusion(p0, p1), kind=kLoop, calls=%fused_computation - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD index 6c20a91a020355..08a159f9b268f3 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/mlir/BUILD @@ -1,4 +1,3 @@ -load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("//xla:xla.bzl", "xla_cc_test") package( @@ -76,7 +75,7 @@ cc_library( "//xla/mlir_hlo:type_conversion", "//xla/service:algorithm_util", "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", @@ -119,13 +118,14 @@ xla_cc_test( "//xla/mlir_hlo", "//xla/service:hlo_parser", "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/model:indexing_analysis", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:launch_dim", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest", "@llvm-project//llvm:Support", @@ -140,7 +140,6 @@ xla_cc_test( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:Transforms", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -153,7 +152,6 @@ cc_library( deps = [ ":computation_partitioner", ":elemental_hlo_to_mlir", - ":passes", ":type_util", "//xla:shape_util", "//xla:status_macros", @@ -172,7 +170,8 @@ cc_library( "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:target_util", "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/fusions/ir:xla_gpu", + "//xla/service/gpu/fusions/transforms:passes", "//xla/service/gpu/model:indexing_analysis", "//xla/service/gpu/runtime:kernel_thunk", "//xla/service/llvm_ir:llvm_util", @@ -261,95 +260,6 @@ xla_cc_test( ], ) -gentbl_cc_library( - name = "passes_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=GpuFusionTransforms", - ], - "passes.h.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "passes.td", - visibility = ["//visibility:private"], - deps = ["@llvm-project//mlir:PassBaseTdFiles"], -) - -cc_library( - name = "passes", - srcs = [ - "convert_xla_gpu_pure_call_ops.cc", - "erase_dead_functions.cc", - "expand_float_ops.cc", - "flatten_tensors.cc", - "lower_tensors.cc", - "lower_to_llvm.cc", - "lower_xla_gpu_to_scf.cc", - "merge_pointers_to_same_slice.cc", - "optimize_loops.cc", - "propagate_slice_indices.cc", - "simplify_affine.cc", - "simplify_arith.cc", - "unswitch_loops.cc", - "vectorize_loads_stores.cc", - ], - hdrs = ["passes.h"], - deps = [ - ":passes_inc_gen", - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/mlir_hlo", - "//xla/mlir_hlo:map_mhlo_to_scalar_op", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", - "//xla/service/gpu/model:indexing_analysis", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:AffineToStandard", - "@llvm-project//mlir:AffineUtils", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ArithToLLVM", - "@llvm-project//mlir:ArithTransforms", - "@llvm-project//mlir:CallOpInterfaces", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:ComplexToLLVM", - "@llvm-project//mlir:ControlFlowToLLVM", - "@llvm-project//mlir:DataLayoutInterfaces", - "@llvm-project//mlir:DialectUtils", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncToLLVM", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:GPUToNVVMTransforms", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMCommonConversion", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MathToLLVM", - "@llvm-project//mlir:MathTransforms", - "@llvm-project//mlir:NVVMDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:SCFToControlFlow", - "@llvm-project//mlir:SCFUtils", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:VectorDialect", - "@llvm-project//mlir:VectorToLLVM", - "@llvm-project//mlir:VectorTransforms", - ], -) - cc_library( name = "type_util", srcs = ["type_util.cc"], diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc index c1cc0de31de574..53d8678e953074 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/computation_partitioner.cc @@ -300,12 +300,12 @@ PartitionedComputation::PartitionedComputation( absl::StrJoin(roots, "_", [](std::string* out, const auto* root) { absl::StrAppend(out, root->name()); }))); - subgraphs_.push_back( - Subgraph{.name = std::move(name), - .instructions = {instructions.begin(), instructions.end()}, - .roots = std::move(roots), - .index_ranges = std::move(ranges), - .root_indexing = std::move(root_indexing)}); + subgraphs_.push_back(Subgraph{ + /* .name = */ std::move(name), + /* .instructions = */ {instructions.begin(), instructions.end()}, + /* .roots = */ std::move(roots), + /* .index_ranges = */ std::move(ranges), + /* .root_indexing = */ std::move(root_indexing)}); } for (const auto& subgraph : subgraphs_) { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc index 59471a3fb337ea..839dd96ab48fea 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/algorithm/container.h" @@ -67,24 +66,18 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "xla/mlir_hlo/mhlo/utils/type_conversion.h" #include "xla/primitive_util.h" #include "xla/service/algorithm_util.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/type_util.h" -#include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/stream_executor/device_description.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" -#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -284,43 +277,45 @@ absl::StatusOr> EmitConcat( PrimitiveTypeToMlirType(instr->shape().element_type(), b); int concat_dim = Cast(instr)->concatenate_dimension(); - int64_t offset = 0; - IfOp outermost_if = nullptr; SmallVector operand_indices = indices; - for (auto [index, operand] : llvm::enumerate(instr->operands())) { - int64_t limit = offset + operand->shape().dimensions(concat_dim); - auto ins = b.create(CmpIPredicate::ult, indices[concat_dim], - b.create(limit)); + SmallVector offsets{0}; + for (auto* operand : instr->operands()) { + offsets.push_back(offsets.back() + operand->shape().dimensions(concat_dim)); + } - auto generate_operand = [&, index = index]() { + std::function>(int64_t, int64_t)> + generate_concat; + generate_concat = [&](int64_t begin, + int64_t end) -> absl::StatusOr> { + // If there's just one operand in the range, emit it. + if (begin == end - 1) { operand_indices[concat_dim] = b.create( - indices[concat_dim], b.create(offset)); + indices[concat_dim], b.create(offsets[begin])); TF_ASSIGN_OR_RETURN(auto operand, - operand_provider(instr, index, operand_indices)); - b.create(operand); - return absl::OkStatus(); - }; - - if (index < instr->operand_count() - 1) { - auto if_op = - b.create(mlir::TypeRange{result_element_type}, ins, true, true); - if (outermost_if == nullptr) { - outermost_if = if_op; - } else { - b.create(if_op.getResults()); - } - - b.setInsertionPointToStart(if_op.getBody(0)); - TF_RETURN_IF_ERROR(generate_operand()); - b.setInsertionPointToStart(if_op.getBody(1)); - } else { - TF_RETURN_IF_ERROR(generate_operand()); + operand_provider(instr, begin, operand_indices)); + return operand; } - offset = limit; - } - b.setInsertionPointAfter(outermost_if); - return outermost_if.getResults(); + int64_t mid = (begin + end) / 2; // No risk of overflow. + auto if_op = b.create( + mlir::TypeRange{result_element_type}, + b.create(CmpIPredicate::ult, indices[concat_dim], + b.create(offsets[mid])), + true, true); + + b.setInsertionPointToStart(if_op.getBody(0)); + TF_ASSIGN_OR_RETURN(auto left_val, generate_concat(begin, mid)); + b.create(left_val); + + b.setInsertionPointToStart(if_op.getBody(1)); + TF_ASSIGN_OR_RETURN(auto right_val, generate_concat(mid, end)); + b.create(right_val); + b.setInsertionPointAfter(if_op); + + return if_op.getResults(); + }; + + return generate_concat(0, instr->operand_count()); } absl::StatusOr> EmitDynamicSlice( @@ -665,9 +660,10 @@ Value ApplyAffineExpr(mlir::AffineExpr expr, ValueRange dims, return b.createOrFold(expr, args); } -SmallVector ApplyIndexing(const IndexingMap& map, ValueRange dims, +SmallVector ApplyIndexing(IndexingMap map, ValueRange dims, ValueRange symbols, ImplicitLocOpBuilder& b) { + map.ClearConstraints(); SmallVector results; for (unsigned int i = 0; i < map.GetAffineMap().getNumResults(); ++i) { SmallVector result; @@ -1176,57 +1172,6 @@ absl::StatusOr> HloToMlir( } // namespace -bool IsHloOpSupported(const HloInstruction* instr, - se::CudaComputeCapability compute_capability) { - return !(kUnsupportedOps.contains(instr->opcode()) || - IsUnsupportedGather(instr)); -} - -bool IsHloConversionSupported(const HloComputation* computation, - se::GpuComputeCapability compute_capability) { - if (!std::holds_alternative(compute_capability)) { - // ROCM is not tested. - return false; - } - auto cuda_compute_capability = - std::get(compute_capability); - - return absl::c_all_of( - computation->instructions(), - [=](const HloInstruction* instr) { - return absl::c_all_of(instr->called_computations(), - [&](const HloComputation* called) { - return IsHloConversionSupported( - called, compute_capability); - }) && - IsHloOpSupported(instr, cuda_compute_capability); - }) && - (computation->IsFusionComputation() || - (absl::c_all_of( - computation->parameter_instructions(), [](auto* param) { - return param->shape().IsArray() && param->shape().rank() == 0; - }))); -} - -bool IsHloConversionSupported(const HloFusionAdaptor& fusion, - se::GpuComputeCapability compute_capability) { - if (!std::holds_alternative(compute_capability)) { - // ROCM is not tested. - return false; - } - auto cuda_compute_capability = - std::get(compute_capability); - - return !HloAnyOf(fusion, [=](HloInstructionAdaptor instr) { - return !absl::c_all_of(instr.instruction().called_computations(), - [&](const HloComputation* called) { - return IsHloConversionSupported( - called, compute_capability); - }) || - !IsHloOpSupported(&instr.instruction(), cuda_compute_capability); - }); -} - ValueRange ProvideParameter(const PartitionedComputation& computation, const HloInstruction* instr, int operand_index, ValueRange indices, @@ -1465,6 +1410,8 @@ absl::StatusOr> SubgraphToMlir( .Convert(); } +} // namespace + void GetLoopBoundsFromIndexingMap(ImplicitLocOpBuilder& b, const IndexingMap& indexing_map, SmallVectorImpl* lbs, @@ -1479,8 +1426,6 @@ void GetLoopBoundsFromIndexingMap(ImplicitLocOpBuilder& b, } } -} // namespace - absl::Status SubgraphToMlirFunction( const PartitionedComputation& computation, const PartitionedComputation::Subgraph& subgraph, mlir::func::FuncOp& func, @@ -1513,20 +1458,6 @@ absl::Status SubgraphToMlirFunction( namespace { -bool IsSymbolConstrained(const IndexingMap& map, int symbol_id) { - for (const auto& [expr, _] : map.GetConstraints()) { - bool result = false; - expr.walk([&](mlir::AffineExpr leaf) { - auto sym = mlir::dyn_cast(leaf); - if (sym && sym.getPosition() == symbol_id) { - result = true; - } - }); - if (result) return true; - } - return false; -} - ValueRange EmitLoopNestImpl( ImplicitLocOpBuilder& b, ValueRange dim_values, ValueRange iter_args_inits, const IndexingMap& indexing_map, @@ -1623,7 +1554,7 @@ ValueRange EmitLoopNest(ImplicitLocOpBuilder& b, ValueRange dim_values, sym_index >= 0 && cumulative_loop_size < 64; --sym_index) { auto& bound = indexing_map.GetSymbolBound(sym_index); cumulative_loop_size *= bound.GetLoopTripCount(); - if (!IsSymbolConstrained(indexing_map, sym_index)) continue; + if (!indexing_map.IsSymbolConstrained(sym_index)) continue; IndexingMap peeled_map = indexing_map; if (bound.upper == bound.lower) continue; @@ -1632,7 +1563,7 @@ ValueRange EmitLoopNest(ImplicitLocOpBuilder& b, ValueRange dim_values, peeled_map.Simplify(); // If the symbol is still constrained, peeling does not help. - if (IsSymbolConstrained(peeled_map, sym_index)) continue; + if (peeled_map.IsSymbolConstrained(sym_index)) continue; auto first_results = EmitLoopNestImpl(b, dim_values, iter_args_inits, peeled_map, create_body, vectorize); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h index 1f52109e34c883..82811ea56fa97a 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h @@ -62,19 +62,6 @@ llvm::SmallVector ProvideParameterRange( const CallTargetProvider& call_target_provider, mlir::func::FuncOp this_fn, mlir::ImplicitLocOpBuilder& builder); -// Checks whether the given HLO instruction can be converted to MLIR. -bool IsHloOpSupported(const HloInstruction* instr, - se::CudaComputeCapability compute_capability); - -// Checks whether the given HLO computation is supported by the MLIR converter: -// - all instructions in it are supported -// - the signature is supported: if the computation is not a fusion computation, -// all arguments have rank 0. -bool IsHloConversionSupported(const HloComputation* computation, - se::GpuComputeCapability compute_capability); -bool IsHloConversionSupported(const HloFusionAdaptor& fusion, - se::GpuComputeCapability compute_capability); - // Converts a function (subgraph) to an MLIR function producing one element of // the result. The function must have the correct interface. absl::Status SubgraphToMlirFunction( @@ -94,7 +81,7 @@ mlir::Value ApplyAffineExpr(mlir::AffineExpr expr, mlir::ValueRange dims, mlir::ImplicitLocOpBuilder& b); // Creates an `apply_indexing` op for the given map. -llvm::SmallVector ApplyIndexing(const IndexingMap& map, +llvm::SmallVector ApplyIndexing(IndexingMap map, mlir::ValueRange dims, mlir::ValueRange symbols, mlir::ImplicitLocOpBuilder& b); @@ -148,6 +135,13 @@ mlir::SmallVector InlineBlock(mlir::OpBuilder& builder, mlir::Block& src_block, mlir::ValueRange mapped_args); +// Populates `lbs`, `ubs` and `steps` with the loop bounds from `indexing_map`. +void GetLoopBoundsFromIndexingMap(mlir::ImplicitLocOpBuilder& b, + const IndexingMap& indexing_map, + llvm::SmallVectorImpl* lbs, + llvm::SmallVectorImpl* ubs, + llvm::SmallVectorImpl* steps); + } // namespace mlir_converter } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc index 6a27e548ca932f..eab1568fa2eb38 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "mlir/Transforms/Passes.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/service/gpu/model/indexing_map.h" @@ -47,7 +47,7 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -235,10 +235,10 @@ TEST_F(ElementalHloToMlirTest, ReduceWindow) { // CHECK: %[[INIT:.*]] = tensor.extract %[[ARG1]][] // CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C7]] // CHECK-SAME: step %[[C1]] iter_args(%[[ACC:.*]] = %[[INIT]]) - // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 4)> - // CHECK-SAME: (%[[Y]] in [0, 2]) - // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0 - 3)> - // CHECK-SAME: (%[[Z]] in [0, 7])[%[[I]] in [0, 6]] + // CHECK: %[[J0:.*]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 4), domain: d0 in [0, 2]>(%[[Y]]) + // CHECK: %[[J1:.*]] = xla_gpu.apply_indexing + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 3), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 6]>(%[[Z]])[%[[I]]] // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]] // CHECK-SAME: [%[[X]], %[[J0]], %[[J1]]] // CHECK: %[[UPD:.*]] = func.call @add_sum(%[[ACC]], @@ -285,8 +285,8 @@ TEST_F(ElementalHloToMlirTest, ReduceWindowWithRescaling) { // If symbol rescaling wasn't working we would have a // `s0 floordiv ` in the map: // CHECK: %[[K:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[X]] in [0, 18])[%[[I]] in [0, 3]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + // CHECK-SAME: d0 in [0, 18], s0 in [0, 3]>(%[[X]])[%[[I]]] // CHECK: tensor.extract %[[ARG0]][%[[K]], %[[Y]], %[[Z]]] )")); @@ -333,6 +333,79 @@ TEST_F(ElementalHloToMlirTest, Concatenate) { )")); } +TEST_F(ElementalHloToMlirTest, ConcatenateMany) { + TF_EXPECT_OK(Run(R"( + ENTRY main { + p0 = f32[10,1,30] parameter(0) + p1 = f32[10,2,30] parameter(1) + p2 = f32[10,3,30] parameter(2) + p3 = f32[10,4,30] parameter(3) + p4 = f32[10,5,30] parameter(4) + p5 = f32[10,6,30] parameter(5) + p6 = f32[10,7,30] parameter(6) + ROOT r = f32[10,28,30] concatenate(p0, p1, p2, p3, p4, p5, p6), + dimensions={1} + })", + R"( + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index + // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index + // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index + // CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index + // CHECK-DAG: %[[C21:.*]] = arith.constant 21 : index + // CHECK: %[[P0TO2:.*]] = arith.cmpi ult, %[[I:.*]], %[[C6]] + // CHECK: %[[CONCAT:.*]] = scf.if %[[P0TO2]] -> (f32) + // CHECK-NEXT: %[[P0:.*]] = arith.cmpi ult, %[[I]], %[[C1]] + // CHECK-NEXT: scf.if %[[P0]] + // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[I]], {{.*}}] : tensor<10x1x30xf32> + // CHECK-NEXT: yield + // CHECK-NEXT: } else { + // CHECK-NEXT: %[[P1:.*]] = arith.cmpi ult, %[[I]], %[[C3]] + // CHECK-NEXT: scf.if %[[P1]] + // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C1]] + // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x2x30xf32> + // CHECK-NEXT: yield + // CHECK-NEXT: } else { + // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C3]] + // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x3x30xf32> + // CHECK-NEXT: yield + // CHECK-NEXT: } + // CHECK-NEXT: yield + // CHECK-NEXT: } + // CHECK-NEXT: yield + // CHECK-NEXT: } else { + // CHECK-NEXT: %[[P3TO4:.*]] = arith.cmpi ult, %[[I]], %[[C15]] + // CHECK-NEXT: scf.if %[[P3TO4]] + // CHECK-NEXT: %[[P3:.*]] = arith.cmpi ult, %[[I]], %[[C10]] + // CHECK-NEXT: scf.if %[[P3]] + // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C6]] + // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x4x30xf32> + // CHECK-NEXT: yield + // CHECK-NEXT: } else { + // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C10]] + // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x5x30xf32> + // CHECK-NEXT: yield + // CHECK-NEXT: } + // CHECK-NEXT: yield + // CHECK-NEXT: } else { + // CHECK-NEXT: %[[P5:.*]] = arith.cmpi ult, %[[I]], %[[C21]] + // CHECK-NEXT: scf.if %[[P5]] + // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C15]] + // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x6x30xf32> + // CHECK-NEXT: yield + // CHECK-NEXT: } else { + // CHECK-NEXT: %[[O:.*]] = arith.subi %[[I]], %[[C21]] + // CHECK-NEXT: tensor.extract {{.*}}[{{.*}}, %[[O]], {{.*}}] : tensor<10x7x30xf32> + // CHECK-NEXT: yield + // CHECK-NEXT: } + // CHECK-NEXT: yield + // CHECK-NEXT: } + // CHECK-NEXT: yield + // CHECK-NEXT: } + // CHECK-NEXT: return %[[CONCAT]] + )")); +} + TEST_F(ElementalHloToMlirTest, ConcatenateUnsigned) { TF_EXPECT_OK(Run(R"( ENTRY main { @@ -433,7 +506,7 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]>(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -445,11 +518,9 @@ TEST_F(ElementalHloToMlirTest, Pad) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)> - // CHECK-SAME: (%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]>(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> (d0 - 4)> - // CHECK-SAME: (%[[Y]] in [4, 7]) + // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7]>(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -477,7 +548,7 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK-DAG: %[[C4:.*]] = arith.constant 4 // CHECK-DAG: %[[C7:.*]] = arith.constant 7 // CHECK: %[[CONSTRAINT_VAL:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2)>(%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) mod 2), domain: d0 in [1, 7]>(%[[X]]) // CHECK: %[[CONSTRAINT:.*]] = arith.cmpi eq, %[[CONSTRAINT_VAL]], %[[C0]] // CHECK-DAG: %[[X_L:.*]] = arith.cmpi sge, %[[X]], %[[C1]] // CHECK-DAG: %[[X_H:.*]] = arith.cmpi sle, %[[X]], %[[C7]] @@ -489,11 +560,9 @@ TEST_F(ElementalHloToMlirTest, PadUnsigned) { // CHECK: %[[FROM_INPUT:.*]] = arith.andi %[[X_AND_CONSTRAINT]], %[[Y_BOUNDS]] // CHECK: %[[RET:.*]] = scf.if %[[FROM_INPUT]] // CHECK: %[[IN0:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2)> - // CHECK-SAME: (%[[X]] in [1, 7]) + // CHECK-SAME: <(d0) -> ((d0 - 1) floordiv 2), domain: d0 in [1, 7]>(%[[X]]) // CHECK: %[[IN1:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: <(d0) -> (d0 - 4)> - // CHECK-SAME: (%[[Y]] in [4, 7]) + // CHECK-SAME: <(d0) -> (d0 - 4), domain: d0 in [4, 7]>(%[[Y]]) // CHECK: %[[VAL:.*]] = tensor.extract %[[ARG0]][%[[IN0]], %[[IN1]]] // CHECK: scf.yield %[[VAL]] // CHECK: } else { @@ -810,11 +879,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionSimple) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 5])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -856,11 +925,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithWindowStrides) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[W]] in [0, 2])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + // CHECK-SAME: d0 in [0, 2], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 * 2 + s0)> - // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + // CHECK-SAME: d0 in [0, 3], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -903,21 +972,21 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithPadding) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[W]] in [0, 7])[%[[X]] in [0, 2]] + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 7], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK-DAG: %[[TXGE:.+]] = arith.cmpi sge, %[[TESTX]], %[[C1]] : index // CHECK-DAG: %[[TXLE:.+]] = arith.cmpi sle, %[[TESTX]], %[[C8]] : index // CHECK-DAG: %[[TX:.+]] = arith.andi %[[TXGE]], %[[TXLE]] : i1 - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> (d0 + s0)>(%[[H]] in [0, 11])[%[[Y]] in [0, 4]] + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), domain: d0 in [0, 11], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[TYGE:.+]] = arith.cmpi sge, %[[TESTY]], %[[C2]] : index // CHECK-DAG: %[[TYLE:.+]] = arith.cmpi sle, %[[TESTY]], %[[C13]] : index // CHECK-DAG: %[[TY:.+]] = arith.andi %[[TYGE]], %[[TYLE]] : i1 // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 1)> - // CHECK-SAME: (%[[W]] in [0, 7])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 1), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 - 2)> - // CHECK-SAME: (%[[H]] in [0, 11])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 - 2), + // CHECK-SAME: d0 in [0, 11], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -957,17 +1026,17 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithLhsDilation) { // CHECK: %[[R0:.+]] = scf.for %[[X:.+]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[A0:.+]] = %[[INIT]]) -> (f32) { // CHECK-NEXT: %[[R1:.+]] = scf.for %[[Y:.+]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[A1:.+]] = %[[A0]]) -> (f32) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { - // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[W]] in [0, 12])[%[[X]] in [0, 2]] + // CHECK-DAG: %[[TESTX:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) mod 2), domain: d0 in [0, 12], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK-DAG: %[[TX:.+]] = arith.cmpi eq, %[[TESTX]], %[[C0]] : index - // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing affine_map<(d0)[s0] -> ((d0 + s0) mod 2)>(%[[H]] in [0, 18])[%[[Y]] in [0, 4]] + // CHECK-DAG: %[[TESTY:.+]] = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) mod 2), domain: d0 in [0, 18], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[TY:.+]] = arith.cmpi eq, %[[TESTY]], %[[C0]] : index // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)> - // CHECK-SAME: (%[[W]] in [0, 12])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) floordiv 2), + // CHECK-SAME: d0 in [0, 12], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 + s0) floordiv 2)> - // CHECK-SAME: (%[[H]] in [0, 18])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 + s0) floordiv 2), + // CHECK-SAME: d0 in [0, 18], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1009,11 +1078,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithRhsDilation) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)> - // CHECK-SAME: (%[[W]] in [0, 3])[%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), + // CHECK-SAME: d0 in [0, 3], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0 * 2)> - // CHECK-SAME: (%[[H]] in [0, 3])[%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), + // CHECK-SAME: d0 in [0, 3], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1055,17 +1124,14 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithFeatureGroupCount) { // CHECK-NEXT: %[[R2:.+]] = scf.for %[[I:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A1]]) -> (f32) { // CHECK: %[[R3:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 5]) - // CHECK-SAME: [%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7]) - // CHECK-SAME: [%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK: %[[XX2:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0)> - // CHECK-SAME: (%[[O]] in [0, 15]) - // CHECK-SAME: [%[[I]] in [0, 1]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> ((d0 floordiv 8) * 2 + s0), + // CHECK-SAME: d0 in [0, 15], s0 in [0, 1]>(%[[O]])[%[[I]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[B]], %[[XX0]], %[[XX1]], %[[XX2]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<2x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1109,13 +1175,11 @@ TEST_F(ElementalHloToMlirTest, ConvolutionWithBatchGroupCount) { // CHECK-NEXT: %[[R3:.+]] = scf.for %[[G:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[A2]]) -> (f32) { // CHECK: %[[R4:.+]] = scf.if {{.+}} -> (f32) { // CHECK: %[[XX0:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[W]] in [0, 5]) - // CHECK-SAME: [%[[X]] in [0, 2]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 5], s0 in [0, 2]>(%[[W]])[%[[X]]] // CHECK: %[[XX1:.+]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0)[s0] -> (d0 + s0)> - // CHECK-SAME: (%[[H]] in [0, 7]) - // CHECK-SAME: [%[[Y]] in [0, 4]] + // CHECK-SAME: #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0), + // CHECK-SAME: d0 in [0, 7], s0 in [0, 4]>(%[[H]])[%[[Y]]] // CHECK-DAG: %[[VL:.+]] = tensor.extract %[[LHS]][%[[G]], %[[XX0]], %[[XX1]], %[[I]]] : tensor<2x8x12x4xf32> // CHECK-DAG: %[[VR:.+]] = tensor.extract %[[RHS]][%[[I]], %[[X]], %[[Y]], %[[O]]] : tensor<4x3x5x16xf32> // CHECK: %[[MUL:.+]] = arith.mulf %[[VL]], %[[VR]] : f32 @@ -1581,8 +1645,8 @@ TEST_F(ElementalHloToMlirTest, MixedIndexingTuple) { // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[A:.*]] = tensor.extract %[[P0]][%[[X]], %[[Y]]] // CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing - // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)> - // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9]) + // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]>(%[[X]], %[[Y]]) // CHECK: %[[B:.*]] = tensor.extract %[[P1]][%[[IDX]]] // CHECK: return %[[A]], %[[B]] )")); @@ -1605,8 +1669,8 @@ TEST_F(ElementalHloToMlirTest, NestedTuple) { // CHECK-SAME: %[[X:.*]]: index {{{.*}}}, %[[Y:.*]]: index {{{.*}}} // CHECK: %[[P0_V:.*]] = xla_gpu.pure_call @main_p0 // CHECK: %[[IDX:.*]] = - // CHECK-SAME: affine_map<(d0, d1) -> (d0 * 10 + d1)> - // CHECK-SAME: (%[[X]] in [0, 9], %[[Y]] in [0, 9]) + // CHECK-SAME: #xla_gpu.indexing_map<(d0, d1) -> (d0 * 10 + d1), + // CHECK-SAME: d0 in [0, 9], d1 in [0, 9]>(%[[X]], %[[Y]]) // CHECK: %[[P1_V:.*]] = xla_gpu.pure_call @main_p1 // CHECK-SAME: (%[[P0]], %[[P1]], %[[IDX]]) // CHECK: return %[[P0_V]], %[[P1_V]], %[[P1_V]], %[[P1_V]], %[[P0_V]] diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 251d3ff56b9f85..efb13ae94e090f 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -83,11 +83,11 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/dump.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/passes.h" #include "xla/service/gpu/fusions/mlir/type_util.h" +#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/ir_emitter_context.h" #include "xla/service/gpu/kernel_arguments.h" #include "xla/service/gpu/kernel_reuse_cache.h" @@ -222,7 +222,7 @@ llvm::SmallVector MlirFusionEmitterBase::EmitThreadAndBlockIds( absl::StatusOr MlirFusionEmitterBase::Emit( IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion) const { - VLOG(5) << "Fusion: " << fusion.fused_instructions_computation()->ToString(); + VLOG(4) << "Fusion: " << fusion.fused_instructions_computation()->ToString(); TF_ASSIGN_OR_RETURN( auto args, KernelArguments::Create(ir_emitter_context.buffer_assignment(), &fusion)); @@ -305,13 +305,14 @@ MlirFusionEmitterBase::CreateLLVMModule( mlir::PassManager pm(&mlir_context); pm.addPass(CreateEraseDeadFunctionsPass()); pm.addPass(mlir::createCSEPass()); - pm.addPass(CreateLowerXlaGpuToScfPass()); + pm.addNestedPass(CreateLowerXlaGpuToScfPass()); pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) { // CSE after inlining because inlining can introduce duplicates. pm.addPass(mlir::createCSEPass()); })); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); + pm.addNestedPass(CreateLowerXlaGpuLoopsToScfPass()); pm.addPass(mlir::mhlo::createConvertToSignlessPass()); pm.addPass(CreatePropagateSliceIndicesPass()); // We need LICM before unswitching loops, because our loop unswitcher only @@ -322,6 +323,7 @@ MlirFusionEmitterBase::CreateLLVMModule( // opportunities for LICM. This would not be necessary if LICM also moved // instructions over ifs. pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(CreateFlattenTensorsPass()); pm.addNestedPass(CreateVectorizeLoadsAndStoresPass()); pm.addNestedPass(CreateOptimizeLoopsPass()); pm.addNestedPass(CreateConvertPureCallOpsPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc index b76a953cf80076..4921e745b5176f 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter_test.cc @@ -76,7 +76,7 @@ class DummyCopyFusionEmitter : public MlirFusionEmitterBase { const mlir_converter::PartitionedComputations& computations, const mlir_converter::CallTargetProvider& call_targets, mlir::func::FuncOp entry_function, - const HloFusionInstruction& fusion) const { + const HloFusionInstruction& fusion) const override { mlir::ImplicitLocOpBuilder b(entry_function.getLoc(), entry_function); b.setInsertionPointToStart(entry_function.addEntryBlock()); auto thread_id = EmitThreadId(b, 0); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD deleted file mode 100644 index 69b4bd09eed93a..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/BUILD +++ /dev/null @@ -1,43 +0,0 @@ -load("//xla:lit.bzl", "lit_test_suite") -load("//xla:xla.bzl", "xla_cc_binary") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - licenses = ["notice"], -) - -xla_cc_binary( - name = "mlir_fusions_opt", - srcs = ["mlir_fusions_opt.cc"], - deps = [ - "//xla/mlir_hlo", - "//xla/service/gpu/fusions/mlir:passes", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", - "@llvm-project//mlir:AffineDialect", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:ComplexDialect", - "@llvm-project//mlir:DLTIDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:FuncExtensions", - "@llvm-project//mlir:GPUDialect", - "@llvm-project//mlir:LLVMDialect", - "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:NVVMDialect", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorDialect", - ], -) - -lit_test_suite( - name = "tests", - srcs = glob(["*.mlir"]), - cfg = "//xla:lit.cfg.py", - tools = [ - ":mlir_fusions_opt", - "@llvm-project//llvm:FileCheck", - ], -) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir deleted file mode 100644 index 17b0f8d9b45c88..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/canonicalize.mlir +++ /dev/null @@ -1,179 +0,0 @@ -// RUN: mlir_fusions_opt %s --split-input-file -canonicalize | FileCheck %s - -#map0 = affine_map<()[s0, s1] -> (1 + s0 + s1 mod 3 - s1, s0 mod 2)> -func.func @simplify_apply_indexing(%s0: index, %s1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [-10, 10], %s1 in [0, 2]] - func.return %0#0, %0#1 : index, index -} -// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 + 1, s0 mod 2)> - -// CHECK-LABEL: func.func @simplify_apply_indexing -// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]][%[[ARG_0]] in [-10, 10]] - -// ----- - -#map0 = affine_map<(d0, d1, d2)[s0, s1] -> (1 + s0 + s1 mod 4 - s1, s0 mod 2, d0 + d2)> -func.func @simplify_apply_indexing_remove_dims(%d0: index, %d1: index, - %d2: index, %s0: index, %s1: index) -> (index, index, index) { - %0:3 = xla_gpu.apply_indexing #map0 - (%d0 in [0, 1], %d1 in [0, 2], %d2 in [0, 3]) - [%s0 in [-11, 11], %s1 in [0, 3]] - func.return %0#0, %0#1, %0#2 : index, index, index -} -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0] -> (s0 + 1, s0 mod 2, d0 + d1)> - -// CHECK-LABEL: func.func @simplify_apply_indexing_remove_dims -// CHECK-SAME: %[[ARG_0:[a-zA-Z0-9_]+]]: index, -// CHECK-SAME: %[[ARG_1:[a-zA-Z0-9_]+]]: index, -// CHECK-SAME: %[[ARG_2:[a-zA-Z0-9_]+]]: index, -// CHECK-SAME: %[[ARG_3:[a-zA-Z0-9_]+]]: index, -// CHECK-SAME: %[[ARG_4:[a-zA-Z0-9_]+]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG_0]] in [0, 1], %[[ARG_2]] in [0, 3]) -// CHECK-SAME: [%[[ARG_3]] in [-11, 11]] - -// ----- - -#map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, 4, d1, 1, s0)> -func.func @fold_indexing_map_results(%d0: index, %d1: index, %s0: index) - -> (index, index, index, index, index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]] - func.return %0#0, %0#1, %0#2, %0#3, %0#4 : index, index, index, index, index -} -// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> - -// CHECK-LABEL: func.func @fold_indexing_map_results -// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) - -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - -// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]] -// CHECK: return %[[NEW_RESULT]], %[[C4]], %[[ARG_1]], %[[C1]], %[[ARG_2]] - -// ----- - -#map0 = affine_map<(d0, d1)[s0] -> (d0 + s0, s0 + 4, d1 mod 2, 1 + d1, s0)> -func.func @remove_unused_results(%d0: index, %d1: index, %s0: index) -> (index) { - %0:5 = xla_gpu.apply_indexing #map0 (%d0 in [-10, 10], %d1 in [0, 2])[%s0 in [-1, 1]] - func.return %0#2 : index -} -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> - -// CHECK-LABEL: func.func @remove_unused_results -// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) - -// CHECK: %[[NEW_RESULT:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG_1]] in [0, 2]) -// CHECK: return %[[NEW_RESULT]] - -// ----- - -#map0 = affine_map<(d0, d1)[s0, s1] -> (d0 + d1 + s0 + s1 mod 3)> -func.func @fold_operands(%d0: index) -> index { - %d1 = arith.constant 1 : index - %s0 = arith.constant 2 : index - %s1 = arith.constant 3 : index - %0 = xla_gpu.apply_indexing #map0 (%d0 in [0, 10], %d1 in [0, 5]) - [%s0 in [-10, 10], %s1 in [0, 4]] - func.return %0 : index -} -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 3)> - -// CHECK-LABEL: func.func @fold_operands -// CHECK-SAME: %[[ARG_0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP]](%[[ARG_0]] in [0, 10]) - -// ----- - -func.func @fold_operands_and_results(%arg0: index, %arg1: index) - -> (index, index) { - %0:2 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (0, d1)> - (%arg0 in [0, 4], %arg1 in [0, 5]) - return %0#0, %0#1 : index, index -} - -// CHECK-LABEL: func.func @fold_operands_and_results -// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -// CHECK-NEXT: %[[C0:.*]] = arith.constant 0 -// CHECK-NEXT: return %[[C0]], %[[ARG_1]] : index, index - -// ----- - -func.func @fold_sequence(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) - %1 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 100 + 42)> - (%0 in [0, 10000]) - func.return %1 : index -} - -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 + 42)> -// CHECK-LABEL: func.func @fold_sequence -// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4]) - -// ----- - -func.func @fold_sequence_sym(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) - %1 = xla_gpu.apply_indexing affine_map<()[s0] -> (s0 mod 100 + 42)> - [%0 in [0, 10000]] - func.return %1 : index -} - -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 + d1 + 42)> -// CHECK-LABEL: func.func @fold_sequence_sym -// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG0]] in [0, 5], %[[ARG1]] in [0, 4]) - -// ----- - -func.func @fold_sequence_shared_operands(%arg0: index, %arg1: index) -> index { - %0 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg0 in [0, 5], %arg1 in [0, 4]) - %1 = xla_gpu.apply_indexing affine_map<(d0, d1) -> (d0 + d1)> - (%arg1 in [0, 4], %0 in [0, 10000]) - func.return %1 : index -} - -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> -// CHECK-LABEL: func.func @fold_sequence_shared_operands -// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -// CHECK-NEXT: xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[ARG1]] in [0, 4], %[[ARG0]] in [0, 5]) - -// ----- - -func.func @atomic_rmw_empty(%in: tensor<2x3xf32>, %i: index, %j: index) - -> (tensor<2x3xf32>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { - ^bb0(%current : f32): - xla_gpu.yield %current : f32 - } - return %ret : tensor<2x3xf32> -} -// CHECK-LABEL: func.func @atomic_rmw_empty -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32> -// CHECK: return %[[ARG0]] - - -// ----- - -func.func @atomic_rmw_cst(%in: tensor<2x3xf32>, %i: index, %j: index) - -> (tensor<2x3xf32>) { - %cst = arith.constant 0.0 : f32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { - ^bb0(%current : f32): - xla_gpu.yield %cst : f32 - } - return %ret : tensor<2x3xf32> -} -// CHECK-LABEL: func.func @atomic_rmw_cst -// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3xf32> -// CHECK-NEXT: %[[CST:.*]] = arith.constant -// CHECK-NEXT: atomic_rmw -// CHECK: xla_gpu.yield %[[CST]] diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir deleted file mode 100644 index 3ea853dc8d0d19..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/indexing_map_attr.mlir +++ /dev/null @@ -1,139 +0,0 @@ -// RUN: mlir_fusions_opt %s -split-input-file | mlir_fusions_opt -split-input-file | FileCheck %s - -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0, d1, d2)[s0] -> (d0) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [1, 2] -// CHECK-NEXT: d1 in [5, 8] -// CHECK-NEXT: d2 in [10, 12] -// CHECK-NEXT: s0 in [0, 32] -// CHECK-NEXT: d0 mod 2 in [0, 1] -// CHECK-NEXT: d0 + s0 in [1, 10] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) - domain: - d0 in [1, 2] - d1 in [5, 8] - d2 in [10, 12] - s0 in [0, 32] - d0 mod 2 in [0, 1] - d0 + s0 in [1, 10] - > - -func.func private @indexing_map_attr(tensor<32xf64, #map>) -// CHECK-LABEL: @indexing_map_attr -// CHECK: tensor<32xf64, #[[$INDEX_MAP]]> - -// ----- - -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [1, 2] -// CHECK-NEXT: d1 in [5, 8] -// CHECK-NEXT: s0 in [0, 10] -// CHECK-NEXT: s1 in [0, 5] -// CHECK-NEXT: s2 in [0, 32] -// CHECK-NEXT: d0 mod 2 in [0, 1] -// CHECK-NEXT: d0 + s0 in [1, 10] -// CHECK-NEXT: d1 + s1 + s2 in [1, 32] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1, s2] -> (d0 + s0, d1 + s1, d1 + s2) - domain: - d0 in [1, 2] - d1 in [5, 8] - s0 in [0, 10] - s1 in [0, 5] - s2 in [0, 32] - d0 mod 2 in [0, 1] - d0 + s0 in [1, 10] - d1 + s1 + s2 in [1, 32] - > -func.func private @more_range_vars(tensor<32xf64, #map>) -// CHECK-LABEL: @more_range_vars -// CHECK: tensor<32xf64, #[[$INDEX_MAP]]> - -// ----- - -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0)[s0] -> (d0) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [0, 100] -// CHECK-NEXT: s0 in [-3, -1] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0) - domain: - d0 in [0, 100] - s0 in [-3, -1] - > -func.func private @indexing_map_small(tensor<100xf64, #map>) -// CHECK-LABEL: @indexing_map_small -// CHECK: tensor<100xf64, #[[$INDEX_MAP]]> - -// ----- - -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0, d1, d2)[s0] -> (d0) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [1, 2] -// CHECK-NEXT: d1 in [5, 8] -// CHECK-NEXT: d2 in [10, 12] -// CHECK-NEXT: s0 in [0, 32] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0) - domain: - d0 in [1, 2] - d1 in [5, 8] - d2 in [10, 12] - s0 in [0, 32] - > -func.func private @no_constraints(tensor<32xf64, #map>) -// CHECK-LABEL: @no_constraints -// CHECK: tensor<32xf64, #[[$INDEX_MAP]]> - -// ----- - -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: ()[s0] -> (s0) -// CHECK-NEXT: domain: -// CHECK-NEXT: s0 in [3, 5] -// CHECK-NEXT: s0 mod 2 in [0, 1] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<()[s0] -> (s0) - domain: - s0 in [3, 5] - s0 mod 2 in [0, 1] - > -func.func private @no_dimensions(tensor<100xf64, #map>) -// CHECK-LABEL: @no_dimensions -// CHECK: tensor<100xf64, #[[$INDEX_MAP]]> - -// ----- - -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: (d0) -> (d0) -// CHECK-NEXT: domain: -// CHECK-NEXT: d0 in [3, 5] -// CHECK-NEXT: d0 mod 2 in [0, 1] -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<(d0) -> (d0) - domain: - d0 in [3, 5] - d0 mod 2 in [0, 1] - > -func.func private @no_symbols(tensor<100xf64, #map>) -// CHECK-LABEL: @no_symbols -// CHECK: tensor<100xf64, #[[$INDEX_MAP]]> - -// ----- - -// CHECK: #[[$INDEX_MAP:.*]] = #xla_gpu.indexing_map< -// CHECK-NEXT: () -> () -// CHECK-NEXT: domain: -// CHECK-NEXT: > -#map = #xla_gpu.indexing_map<() -> () - domain: - > -func.func private @empty(tensor<100xf64, #map>) -// CHECK-LABEL: @empty -// CHECK: tensor<100xf64, #[[$INDEX_MAP]]> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir deleted file mode 100644 index fbef7c049db487..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/invalid.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: mlir_fusions_opt %s -split-input-file -verify-diagnostics - -#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> -func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { - // expected-error @+1 {{operand, lower_bounds, upper_bounds count and affine map dimension and symbol count must match}} - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2]) - func.return %0#0, %0#1 : index, index -} diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir deleted file mode 100644 index c7f15073b5e0ed..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/ops.mlir +++ /dev/null @@ -1,96 +0,0 @@ -// R-UN: mlir_fusions_opt %s --split-input-file | FileCheck %s -// Verify the printed output can be parsed. -// RU-N: mlir_fusions_opt %s --split-input-file | mlir_fusions_opt --split-input-file | FileCheck %s -// Verify the generic form can be parsed. -// RUN: mlir_fusions_opt %s --split-input-file --mlir-print-op-generic | mlir_fusions_opt --split-input-file | FileCheck %s - -func.func @shared_and_sync() -> (tensor<2xf32>, tensor<2xf32>) { - %shared1 = xla_gpu.allocate_shared : tensor<2xf32> - %shared2 = xla_gpu.allocate_shared : tensor<2xf32> - %sync:2 = xla_gpu.sync_threads %shared1, %shared2 - : tensor<2xf32>, tensor<2xf32> - return %sync#0, %sync#1 : tensor<2xf32>, tensor<2xf32> -} -// CHECK-LABEL: @shared_and_sync -// CHECK-NEXT: allocate_shared -// CHECK-NEXT: allocate_shared -// CHECK-NEXT: sync_threads -// CHECK-NEXT: return - -// ----- - -func.func @atomic_rmw(%in: tensor<2x3xf32>, %i: index, %j: index) - -> (tensor<2x3xf32>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x3xf32> { - ^bb0(%current : f32): - %c42 = arith.constant 42.0 : f32 - %add = arith.addf %current, %c42 : f32 - xla_gpu.yield %add : f32 - } - return %ret : tensor<2x3xf32> -} -// CHECK-LABEL: @atomic_rmw -// CHECK: xla_gpu.atomic_rmw - -// ----- - -func.func private @add(%a: f32, %b: f32) -> f32 { - %ret = arith.addf %a, %b : f32 - return %ret : f32 -} - -func.func @caller(%a: f32, %b: f32) -> f32 { - %c = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %d = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) - %ret = arith.addf %c, %d : f32 - return %ret : f32 -} -// CHECK-LABEL: @caller -// CHECK: %[[C:.*]] = xla_gpu.pure_call @add -// CHECK: %[[D:.*]] = xla_gpu.pure_call @add -// CHECK: arith.addf %[[C]], %[[D]] - -// CHECK-CSE: @caller -// CHECK-CSE: %[[C:.*]] = xla_gpu.pure_call @add -// CHECK-CSE: arith.addf %[[C]], %[[C]] - -// ----- - -#map0 = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> -func.func @apply_indexing(%d0: index, %d1: index, %s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3])[%s0 in [2, 4]] - func.return %0#0, %0#1 : index, index -} -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0, d1 + s0)> - -// CHECK-LABEL: @apply_indexing -// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index, %[[s0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3])[%[[s0]] in [2, 4]] - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d1)> -func.func @apply_indexing_no_symbols(%d0: index, %d1: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 (%d0 in [0, 2], %d1 in [1, 3]) - func.return %0#0, %0#1 : index, index -} -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> - -// CHECK-LABEL: @apply_indexing_no_symbols -// CHECK: (%[[d0:.*]]: index, %[[d1:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]] -// CHECK-SAME: (%[[d0]] in [0, 2], %[[d1]] in [1, 3]) - -// ----- - -#map0 = affine_map<()[s0] -> (s0, s0)> -func.func @apply_indexing_no_dims(%s0: index) -> (index, index) { - %0:2 = xla_gpu.apply_indexing #map0 [%s0 in [2, 4]] - func.return %0#0, %0#1 : index, index -} -// CHECK: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0, s0)> - -// CHECK-LABEL: @apply_indexing_no_dims -// CHECK: (%[[s0:.*]]: index) -// CHECK: xla_gpu.apply_indexing #[[$MAP0]][%[[s0]] in [2, 4]] diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir deleted file mode 100644 index 1141d1581505ea..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/vectorize_loads_stores.mlir +++ /dev/null @@ -1,405 +0,0 @@ -// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file -xla-gpu-vectorize-loads-stores -canonicalize -cse | FileCheck %s - -#map = affine_map<(d0)[s0] -> (d0 * 2 + s0)> -module { - func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 * 2)> -// CHECK-LABEL: @simple_read -// CHECK-SAME: (%[[ARG0:.*]]: tensor -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index -// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #map(%[[I]] in [0, 63]) -// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] -// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] -// CHECK-NEXT: vector.extract %[[V]][%[[J]]] -// CHECK-NEXT: addf - -// ----- - -module { - func.func @simple_read_2d(%arg0: tensor<64x2xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK-LABEL: @simple_read_2d -// CHECK-SAME: (%[[ARG0:.*]]: tensor -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[C0]]] -// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] -// CHECK-NEXT: vector.extract %[[V]][%[[J]]] - -// ----- - -#map = affine_map<(d0)[s0] -> (d0 * 2 + s0 + 1)> -module { - func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c63 = arith.constant 63 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK-LABEL: @misaligned_indexing_map -// CHECK-NOT: vector.transfer_read - -// ----- - -#map = affine_map<(d0)[s0] -> (d0 * 3 + s0)> -module { - func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c63 = arith.constant 63 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK-LABEL: @misaligned_indexing_map_2 -// CHECK-NOT: vector.transfer_read - -// ----- - -module { - func.func @misaligned_shape(%arg0: tensor<64x3xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %extracted = tensor.extract %arg0[%i, %j] : tensor<64x3xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK-LABEL: @misaligned_shape -// CHECK-NOT: vector.transfer_read - -// ----- - -#map = affine_map<(d0)[s0] -> (d0 + s0 * 2)> -module { - func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c63 = arith.constant 63 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK-LABEL: @wrong_stride -// CHECK-NOT: vector.transfer_read - -// ----- - -// We could vectorize this as a float vector load of double the size, but we -// don't currently. -module { - func.func @simple_read_complex(%arg0: tensor<64x2xcomplex>, %i: index) -> (complex) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex - %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex { - %extracted = tensor.extract %arg0[%i, %j] : tensor<64x2xcomplex> - %added = complex.add %iter, %extracted : complex - scf.yield %added : complex - } - return %loop : complex - } -} - -// CHECK-LABEL: @simple_read_complex -// CHECK-NOT: vector.transfer_read - -// ----- - -// This is vectorizable, but not currently supported. -module { - func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c64 = arith.constant 64 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %extracted = tensor.extract %arg0[%j, %i] - : tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK-LABEL: @layout -// CHECK-NOT: vector.transfer_read - -// ----- - -module { - func.func @simple_write(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> { - %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32> - scf.yield %inserted : tensor<16x4xf32> - } - return %loop : tensor<16x4xf32> - } -} - -// CHECK-LABEL: @simple_write -// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[I:.*]]: index -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[V:.*]] = scf.for -// CHECK-NEXT: vector.insert -// CHECK-NEXT: scf.yield -// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[I]], %[[C0]]] -// CHECK-NEXT: return %[[WRITTEN]] - -// ----- - -module { - func.func @write_with_use(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> { - %inserted = tensor.insert %cst into %iter[%i, %j] : tensor<16x4xf32> - "dummy.op1"(%inserted) : (tensor<16x4xf32>) -> () - scf.yield %inserted : tensor<16x4xf32> - } - return %loop : tensor<16x4xf32> - } -} - -// CHECK-LABEL: @write_with_use -// CHECK-NOT: transfer_write - -// ----- - -module { - func.func @write_not_to_iter_arg(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> { - %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32> - scf.yield %inserted : tensor<16x4xf32> - } - return %loop : tensor<16x4xf32> - } -} - -// CHECK-LABEL: @write_not_to_iter_arg -// CHECK-NOT: transfer_write - -// ----- - -module { - func.func @write_not_yielded(%arg0: tensor<16x4xf32>, %i: index) -> tensor<16x4xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c4 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<16x4xf32> { - %inserted = tensor.insert %cst into %arg0[%i, %j] : tensor<16x4xf32> - scf.yield %arg0 : tensor<16x4xf32> - } - return %loop : tensor<16x4xf32> - } -} - -// CHECK-LABEL: @write_not_yielded -// CHECK-NOT: transfer_write - -// ----- - -#map = affine_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)> -module { - func.func @multiple(%arg0: tensor<32x4096xf32>, %arg1: tensor<4096xbf16>, - %arg2: tensor<32xf32>, %arg3: tensor<32x4096xf32>, - %arg4: index) -> (tensor<32x4096xf32>, f32) { - %cst = arith.constant 1.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c8 = arith.constant 8 : index - %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32> - %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<32x4096xf32>, f32) { - %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<32x4096xf32>, f32) { - %2 = xla_gpu.apply_indexing #map(%j in [0, 1], %arg4 in [0, 255])[%i in [0, 7]] - %extracted2 = tensor.extract %arg0[%i, %2] : tensor<32x4096xf32> - %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16> - %3 = arith.extf %extracted3 : bf16 to f32 - %4 = arith.addf %extracted2, %3 : f32 - %5 = arith.addf %extracted1, %4 : f32 - %6 = arith.addf %iter3, %5 : f32 - %inserted = tensor.insert %5 into %iter2[%i, %2] : tensor<32x4096xf32> - scf.yield %inserted, %6 : tensor<32x4096xf32>, f32 - } - scf.yield %1#0, %1#1 : tensor<32x4096xf32>, f32 - } - return %0#0, %0#1 : tensor<32x4096xf32>, f32 - } -} - -// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0] -> (d0 * 2 + s0 * 512)> -// CHECK-LABEL: @multiple -// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]] in [0, 255])[%[[I]] in [0, 7]] -// CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]] -// CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[I]], %[[BASE]]] -// CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>) -// CHECK-DAG: vector.extract %[[READ1]][%[[J]]] -// CHECK-DAG: vector.extract %[[READ2]][%[[J]]] -// CHECK: extf -// CHECK-NEXT: addf -// CHECK-NEXT: %[[TO_INSERT:.*]] = arith.addf -// CHECK-NEXT: %[[TO_YIELD:.*]] = arith.addf -// CHECK-NEXT: %[[V_NEXT:.*]] = vector.insert %[[TO_INSERT]], %[[V]] [%[[J]]] -// CHECK-NEXT: scf.yield %[[TO_YIELD]], %[[V_NEXT]] -// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[INNER]]#1, %{{.*}}[%[[I]], %[[BASE]]] -// CHECK: scf.yield %[[WRITTEN]], %[[INNER]]#0 - -// ----- - -#map = affine_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0)> -module { - func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c63 = arith.constant 63 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> ((d0 mod 16) * 4)> -// CHECK-LABEL: @remainder_with_modulo -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: scf.for %[[I:.*]] = %[[C0]] -// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] -// CHECK: vector.transfer_read {{.*}}[%[[BASE]]] - -// ----- - -#map = affine_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0)> -module { - func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c63 = arith.constant 63 : index - %cst = arith.constant 0.0 : f32 - %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { - %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { - %idx = xla_gpu.apply_indexing #map(%i in [0, 63])[%j in [0, 1]] - %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> - %added = arith.addf %iter1, %extracted : f32 - scf.yield %added : f32 - } - scf.yield %inner : f32 - } - return %outer : f32 - } -} - -// CHECK-LABEL: @remainder_with_modulo_misaligned -// CHECK-NOT: vector.transfer_read diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc index 2d2b203bd53383..810d794120bfdc 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.cc @@ -38,7 +38,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/affine_map_printer.h" #include "xla/tests/filecheck.h" @@ -84,7 +84,7 @@ absl::StatusOr MlirEmitterTestBaseImpl::EmitIR( TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_string)); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); auto fusion_emitter = GetEmitter(analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h index 0006dc53683c2a..3b0c78cc760631 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir_emitter_test_base.h @@ -72,7 +72,7 @@ class MlirEmitterTestBase : public MlirEmitterTestBaseImpl { auto& module = modules_.emplace_back(ParseAndReturnVerifiedModule(hlo_string).value()); auto* root = module->entry_computation()->root_instruction(); - analyses_.push_back(AnalyzeFusion(*root, device_info_)); + analyses_.push_back(HloFusionAnalysis::Create(*root, device_info_)); return GetEmitter(analyses_.back()); } diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc index cf2aa130082125..e0af3b4ea2aabe 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc @@ -36,7 +36,6 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" @@ -183,11 +182,12 @@ ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis, } absl::flat_hash_set instructions; - + for (const HloInstruction* operand : analysis.fusion().GetParameters()) { + instructions.insert(HloInstructionAdaptor{*operand, &analysis.fusion()}); + } auto visit = [&](absl::Span roots) { HloBfsConsumersFirstTraversal( - roots, analysis.fusion(), - [&](HloInstructionAdaptor consumer) { + roots, analysis.fusion(), [&](HloInstructionAdaptor consumer) { auto& consumer_reachable = reachable_outputs[consumer]; for (auto producer : consumer.GetOperands()) { reachable_outputs[producer].insert(consumer_reachable.begin(), @@ -195,8 +195,7 @@ ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis, } instructions.insert(consumer); return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor argument) { instructions.insert(argument); }); + }); }; // The legacy emitter grouping is buggy: it does not visit instructions in the diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index d58f809b1c3c41..075678d2e58605 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -47,9 +47,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/fusions/reduction_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" @@ -79,79 +79,6 @@ constexpr int kRowMajorReduced = ReductionDimensions::kRowMajorReducedDimension; constexpr int kRowKept = ReductionDimensions::kRowKeptDimension; constexpr int kRowMinorReduced = ReductionDimensions::kRowMinorReducedDimension; -LaunchDimensions MlirReductionFusion::launch_dimensions() const { - size_t blocks_y = groups_.grouped_roots.size(); - return {se::BlockDim(/*x=*/Product(num_blocks_), - /*y=*/static_cast(blocks_y), /*z=*/1), - se::ThreadDim(/*x=*/Product(num_threads_), - /*y=*/1, /*z=*/1)}; -} - -MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis) { - auto* hero_reduction = analysis.FindHeroReduction(); - CHECK_NE(hero_reduction, nullptr); - Shape input_shape = hero_reduction->operand(0)->shape(); - reduction_dimensions_ = - GetReductionKindAndContiguousComponents(*hero_reduction); - VLOG(10) << reduction_dimensions_; - - CHECK(ReductionIsRaceFree(hero_reduction->GetModule()->config(), - reduction_dimensions_)) - << "Non-race-free reductions should have been decomposed. Did " - "tree_reduction_rewriter run?"; - - groups_ = GroupDisjointReductions(analysis, /*for_mlir=*/true); - first_reduce_ = hero_reduction; - - const auto& groups = GetGroups(); - int num_groups = groups.grouped_roots.size(); - side_output_roots_.resize(num_groups); - reduction_heroes_.resize(num_groups); - reduction_roots_.resize(num_groups); - - absl::flat_hash_set seen_heroes; - for (auto [root_adaptor, hero_adaptor, is_reduction, group_id] : - llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes(), - groups.is_reduction_root, groups.group_id_per_root)) { - const HloInstruction* root = &root_adaptor.instruction(); - const HloInstruction* hero = &hero_adaptor.instruction(); - if (is_reduction) { - if (seen_heroes.insert(hero).second) { - reduction_heroes_[group_id].push_back(hero); - } - reduction_roots_[group_id].push_back(root); - } else { - side_output_roots_[group_id].push_back(root); - } - } -} - -IndexingMap MlirReductionFusion::GetIndexingMap( - llvm::ArrayRef results, - absl::Span symbol_sizes) const { - auto* ctx = results.front().getContext(); - auto num_groups = static_cast(reduction_heroes_.size()); - return IndexingMap{ - AffineMap::get(6, symbol_sizes.size(), results, ctx), - DimVarsFromTensorSizes( - {Product(num_threads_), 1, 1, Product(num_blocks_), num_groups, 1}), - RangeVarsFromTensorSizes(symbol_sizes), - /*rt_vars=*/{}}; -} - -IndexingMap MlirReductionFusion::GetThreadIndexingMap( - llvm::ArrayRef results, - absl::Span const> constraints, - absl::Span symbol_sizes) const { - auto affine_map = AffineMap::get(1, symbol_sizes.size(), results, - results.front().getContext()); - return IndexingMap{affine_map, - DimVarsFromTensorSizes({Product(num_threads_)}), - RangeVarsFromTensorSizes(symbol_sizes), - /*rt_vars=*/{}, constraints}; -} - struct PerThreadOutputs { // The partially reduced scalars for each thread. HloValueMap reduction_scalars; @@ -232,140 +159,6 @@ struct MlirReductionFusion::EmitterState { SmallVector thread_and_block_ids; }; -std::vector -MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion, - MLIRContext* mlir_context) const { - std::vector epilogues; - epilogues.reserve(reduction_heroes_.size()); - for (const auto& [heroes, roots] : - llvm::zip(reduction_heroes_, reduction_roots_)) { - epilogues.push_back( - mlir_converter::EpilogueSpecification::FromOutputIndexing( - analysis_, heroes, roots, *this, mlir_context)); - } - // Add empty epilogues for the side outputs. This ensures their roots don't - // get "fused" into the tuple function. - for (const auto& roots : side_output_roots_) { - for (const auto* root : roots) { - epilogues.push_back( - mlir_converter::EpilogueSpecification::FromIdentityIndexing( - root, root, mlir_context)); - } - } - return epilogues; -} - -absl::Status MlirReductionFusion::EmitEntryFunction( - const PartitionedComputations& computations, - const mlir_converter::CallTargetProvider& call_targets, - mlir::func::FuncOp entry_function, - const HloFusionInstruction& fusion) const { - EmitterState state{*this, entry_function, fusion, computations, call_targets}; - auto& b = state.builder; - b.setInsertionPointToStart(entry_function.addEntryBlock()); - state.thread_and_block_ids = EmitThreadAndBlockIds(b); - if (reduction_heroes_.size() == 1) { - b.create(EmitReduction(0, state)); - return absl::OkStatus(); - } - SmallVector cases(reduction_heroes_.size() - 1); - absl::c_iota(cases, 1); // `default` is region 0. - auto switch_op = b.create( - entry_function.getResultTypes(), EmitBlockId(b, 1), cases, cases.size()); - b.create(switch_op.getResults()); - for (auto [id, region] : llvm::enumerate(switch_op->getRegions())) { - b.setInsertionPointToStart(®ion.emplaceBlock()); - b.create(EmitReduction(id, state)); - } - return absl::OkStatus(); -} - -IndexingMap MlirRowReductionFusion::ComputeReductionInputIndexing( - mlir::MLIRContext* ctx) const { - auto thread_id = - DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); - auto block_id = - DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_); - auto major_reduced = getAffineSymbolExpr(0, ctx); - auto minor_reduced = getAffineSymbolExpr(1, ctx); - auto vector_index = getAffineSymbolExpr(2, ctx); - - SmallVector indices{ - major_reduced, - block_id[0] * tile_sizes_per_block_[0] + thread_id[0], - block_id[1] * tile_sizes_per_block_[1] + - (minor_reduced * num_threads_[1]) + thread_id[1], - vector_index, - }; - - auto map = GetIndexingMap(indices, tile_sizes_per_thread_); - for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) { - map.AddConstraint(result, {0, input_dim - 1}); - } - return map; -} - -IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing( - MLIRContext* ctx) const { - auto thread_id = - DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); - auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx) - : mlir::getAffineDimExpr(3, ctx); - IndexingMap projected_index = - GetIndexingMap(block_id * num_threads_[0] + thread_id[0]); - projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()), - {0, 0}); - // We don't need a constraint on the loop dimensions, because they are removed - // by GetIndexingMap (since they don't show up in the output index - // computation). - return projected_index; -} - -IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing( - mlir::MLIRContext* ctx) const { - auto thread_id = - DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); - auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx) - : mlir::getAffineDimExpr(3, ctx); - auto major_reduced = getAffineSymbolExpr(0, ctx); - auto vector_index = getAffineSymbolExpr(1, ctx); - - SmallVector indices{ - major_reduced, block_id * num_threads_[0] + thread_id[0], - thread_id[1] * tile_sizes_per_thread_[1] + vector_index}; - - auto map = GetIndexingMap(indices, tile_sizes_per_thread_); - for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) { - map.AddConstraint(result, {0, input_dim - 1}); - } - return map; -} - -IndexingMap MlirRowReductionFusion::ComputeReductionOutputIndexing( - MLIRContext* ctx) const { - auto thread_id = - DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); - auto block_id = - DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_); - IndexingMap projected_index = - GetIndexingMap(block_id[0] * tile_sizes_per_block_[0] + thread_id[0]); - projected_index.AddConstraint(thread_id[1], {0, 0}); - return projected_index; -} - -HloValueMap MlirReductionFusion::GetInits(int group_id, - EmitterState& state) const { - HloValueMap result; - const auto& reductions = reduction_heroes_[group_id]; - for (auto* hero : reductions) { - int arity = hero->operand_count() / 2; - result[hero] = ProvideParameterRange(state.computation, hero, arity, arity, - {}, state.call_target, - state.entry_function, state.builder); - } - return result; -} - PerThreadOutputs MlirReductionFusion::EmitterState::EmitPerThreadElements( int group_id, const HloValueMap& inits, const SmallVector& outputs) { auto tile_indexing = @@ -558,45 +351,179 @@ mlir::ValueRange MlirReductionFusion::EmitterState::ReduceViaSharedMemory( }); } -std::optional MlirReductionFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, MLIRContext* ctx) const { - const auto& hero = analysis_.fusion_hero(root_index).instruction(); - if (groups_.is_reduction_root[root_index] && - hero_operand_index >= hero.operand_count() / 2) { - // We don't have indexing for the init values. - return std::nullopt; - } - if (!groups_.is_reduction_root[root_index]) { - return ComposeIndexingMaps( - *ComputeThreadIdToOutputIndexing(root_index, ctx), - *ComputeOutputToInputIndexing( - &analysis_.fusion_root(root_index).instruction(), 0, ctx) - .indexing_maps[hero_operand_index] - .begin()); - } - auto projected_map = ComputeReductionInputIndexing(ctx); - AddGroupIdConstraint(projected_map, root_index, groups_); - auto map = projected_map * - GetBitcastMap(input_shape_, - hero.operand(hero_operand_index)->shape(), ctx); - map.Simplify(); - return map; -} +MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis) + : analysis_(analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + Shape input_shape = hero_reduction->operand(0)->shape(); + reduction_dimensions_ = + GetReductionKindAndContiguousComponents(*hero_reduction); + VLOG(10) << reduction_dimensions_; -std::optional MlirReductionFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, MLIRContext* ctx) const { - if (!groups_.is_reduction_root[root_index]) { - auto map = ComposeIndexingMaps( - ComputeReductionInputIndexing(ctx), - GetBitcastMap(input_shape_, analysis_.fusion_root(root_index).shape(), - ctx)); - AddGroupIdConstraint(map, root_index, groups_); - map.Simplify(); - return map; - } + CHECK(ReductionIsRaceFree(hero_reduction->GetModule()->config(), + reduction_dimensions_)) + << "Non-race-free reductions should have been decomposed. Did " + "tree_reduction_rewriter run?"; - auto projected_indexing = ComputeReductionOutputIndexing(ctx); - auto output_shape = reduction_dimensions_.GetOutputShape(); + groups_ = GroupDisjointReductions(analysis, /*for_mlir=*/true); + first_reduce_ = hero_reduction; + + const auto& groups = GetGroups(); + int num_groups = groups.grouped_roots.size(); + side_output_roots_.resize(num_groups); + reduction_heroes_.resize(num_groups); + reduction_roots_.resize(num_groups); + + absl::flat_hash_set seen_heroes; + for (auto [root_adaptor, hero_adaptor, is_reduction, group_id] : + llvm::zip(analysis.fusion_roots(), analysis.fusion_heroes(), + groups.is_reduction_root, groups.group_id_per_root)) { + const HloInstruction* root = &root_adaptor.instruction(); + const HloInstruction* hero = &hero_adaptor.instruction(); + if (is_reduction) { + if (seen_heroes.insert(hero).second) { + reduction_heroes_[group_id].push_back(hero); + } + reduction_roots_[group_id].push_back(root); + } else { + side_output_roots_[group_id].push_back(root); + } + } +} + +IndexingMap MlirReductionFusion::GetIndexingMap( + llvm::ArrayRef results, + absl::Span symbol_sizes) const { + auto* ctx = results.front().getContext(); + auto num_groups = static_cast(reduction_heroes_.size()); + return IndexingMap{ + AffineMap::get(6, symbol_sizes.size(), results, ctx), + DimVarsFromTensorSizes( + {Product(num_threads_), 1, 1, Product(num_blocks_), num_groups, 1}), + RangeVarsFromTensorSizes(symbol_sizes), + /*rt_vars=*/{}}; +} + +IndexingMap MlirReductionFusion::GetThreadIndexingMap( + llvm::ArrayRef results, + absl::Span const> constraints, + absl::Span symbol_sizes) const { + auto affine_map = AffineMap::get(1, symbol_sizes.size(), results, + results.front().getContext()); + return IndexingMap{affine_map, + DimVarsFromTensorSizes({Product(num_threads_)}), + RangeVarsFromTensorSizes(symbol_sizes), + /*rt_vars=*/{}, constraints}; +} + +LaunchDimensions MlirReductionFusion::launch_dimensions() const { + size_t blocks_y = groups_.grouped_roots.size(); + return {se::BlockDim(/*x=*/Product(num_blocks_), + /*y=*/static_cast(blocks_y), /*z=*/1), + se::ThreadDim(/*x=*/Product(num_threads_), + /*y=*/1, /*z=*/1)}; +} + +std::vector +MlirReductionFusion::GetEpilogues(const HloFusionInstruction& fusion, + MLIRContext* mlir_context) const { + std::vector epilogues; + epilogues.reserve(reduction_heroes_.size()); + for (const auto& [heroes, roots] : + llvm::zip(reduction_heroes_, reduction_roots_)) { + epilogues.push_back( + mlir_converter::EpilogueSpecification::FromOutputIndexing( + analysis_, heroes, roots, *this, mlir_context)); + } + // Add empty epilogues for the side outputs. This ensures their roots don't + // get "fused" into the tuple function. + for (const auto& roots : side_output_roots_) { + for (const auto* root : roots) { + epilogues.push_back( + mlir_converter::EpilogueSpecification::FromIdentityIndexing( + root, root, mlir_context)); + } + } + return epilogues; +} + +absl::Status MlirReductionFusion::EmitEntryFunction( + const PartitionedComputations& computations, + const mlir_converter::CallTargetProvider& call_targets, + mlir::func::FuncOp entry_function, + const HloFusionInstruction& fusion) const { + EmitterState state{*this, entry_function, fusion, computations, call_targets}; + auto& b = state.builder; + b.setInsertionPointToStart(entry_function.addEntryBlock()); + state.thread_and_block_ids = EmitThreadAndBlockIds(b); + if (reduction_heroes_.size() == 1) { + b.create(EmitReduction(0, state)); + return absl::OkStatus(); + } + SmallVector cases(reduction_heroes_.size() - 1); + absl::c_iota(cases, 1); // `default` is region 0. + auto switch_op = b.create( + entry_function.getResultTypes(), EmitBlockId(b, 1), cases, cases.size()); + b.create(switch_op.getResults()); + for (auto [id, region] : llvm::enumerate(switch_op->getRegions())) { + b.setInsertionPointToStart(®ion.emplaceBlock()); + b.create(EmitReduction(id, state)); + } + return absl::OkStatus(); +} + +HloValueMap MlirReductionFusion::GetInits(int group_id, + EmitterState& state) const { + HloValueMap result; + const auto& reductions = reduction_heroes_[group_id]; + for (auto* hero : reductions) { + int arity = hero->operand_count() / 2; + result[hero] = ProvideParameterRange(state.computation, hero, arity, arity, + {}, state.call_target, + state.entry_function, state.builder); + } + return result; +} + +std::optional MlirReductionFusion::ComputeThreadIdToInputIndexing( + int64_t root_index, int64_t hero_operand_index, MLIRContext* ctx) const { + const auto& hero = analysis_.fusion_hero(root_index).instruction(); + if (groups_.is_reduction_root[root_index] && + hero_operand_index >= hero.operand_count() / 2) { + // We don't have indexing for the init values. + return std::nullopt; + } + if (!groups_.is_reduction_root[root_index]) { + return ComposeIndexingMaps( + *ComputeThreadIdToOutputIndexing(root_index, ctx), + *ComputeOutputToInputIndexing( + &analysis_.fusion_root(root_index).instruction(), 0, ctx) + .indexing_maps[hero_operand_index] + .begin()); + } + auto projected_map = ComputeReductionInputIndexing(ctx); + AddGroupIdConstraint(projected_map, root_index, groups_); + auto map = projected_map * + GetBitcastMap(input_shape_, + hero.operand(hero_operand_index)->shape(), ctx); + map.Simplify(); + return map; +} + +std::optional MlirReductionFusion::ComputeThreadIdToOutputIndexing( + int64_t root_index, MLIRContext* ctx) const { + if (!groups_.is_reduction_root[root_index]) { + auto map = ComposeIndexingMaps( + ComputeReductionInputIndexing(ctx), + GetBitcastMap(input_shape_, analysis_.fusion_root(root_index).shape(), + ctx)); + AddGroupIdConstraint(map, root_index, groups_); + map.Simplify(); + return map; + } + + auto projected_indexing = ComputeReductionOutputIndexing(ctx); + auto output_shape = reduction_dimensions_.GetOutputShape(); CHECK_EQ(output_shape.size(), projected_indexing.GetAffineMap().getNumResults()); for (auto [result, dim_size] : llvm::zip( @@ -636,189 +563,13 @@ SmallVector MlirReductionFusion::EvaluateEpilogue( auto output_indices = mlir_converter::ApplyIndexing( epilogue.root_indexing[index], state.thread_and_block_ids, symbol_values, b); - for (auto [result_index, result] : llvm::enumerate(values.at(root))) { - auto& output = outputs[state.OutputIndex(root, result_index)]; - output = b.create(thread_has_output, result, output, - output_indices); - } - } - return outputs; -} - -MlirRowReductionFusion::MlirRowReductionFusion( - const HloFusionAnalysis& analysis) - : MlirReductionFusion(analysis) { - CHECK(reduction_dimensions_.is_row_reduction); - Vector3 shape = reduction_dimensions_.dimensions; - CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1); - constexpr int64_t kMinorReducedElementsPerThread = 16; - - int64_t num_threads_kept = 1; - int64_t num_threads_reduced = [&] { - int64_t max_block_size = - MinThreadsXRowReduction(first_reduce_->GetModule()->config()); - return std::min(max_block_size, - RoundUpTo(CeilOfRatio(shape[kRowMinorReduced], - kMinorReducedElementsPerThread), - WarpSize())); - }(); - - // If we're limited by the size of the x dimension, add additional parallelism - // in the y dimension. The code generator doesn't currently support - // parallelizing the z dimension (major reduced dimensions). The general - // recommendation is to use between 128 and 512 threads, so we just go for - // 256. See https://forums.developer.nvidia.com/t/55529 - constexpr int64_t kThreadsPerBlockTarget = 256; - if (num_threads_reduced * 2 <= kThreadsPerBlockTarget) { - int64_t kept_size = reduction_dimensions_.dimensions[kRowKept]; - // Increase the size of the y dimension as long as there's remaining - // parallelism. - if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { - num_threads_kept = kept_size; - } else { - num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; - } - } - - int vector_size = GetVectorSizeForMlir(analysis, reduction_dimensions_, - num_threads_reduced); - num_threads_ = {num_threads_kept, num_threads_reduced}; - // TODO(jreiffers): Get rid of `vector_size` in here. - input_shape_ = {shape[0], shape[1], shape[2] / vector_size, vector_size}; - // TODO(jreiffers): Tighten ranges based on constraints when simplifying - // instead of using min here. For example, based on - // - // s1 in [0, 127] - // d0 floordiv 32 + s1 * 32 in [0, 63] - // - // Tighten the bound of s1 to [0, 1]. - int minor_reduced_tile_size = - std::min(kMinorReducedElementsPerThread / vector_size, - CeilOfRatio(input_shape_[2], num_threads_[1])); - - tile_sizes_per_thread_ = {shape[0], minor_reduced_tile_size, vector_size}; - tile_sizes_per_block_ = {num_threads_kept, - minor_reduced_tile_size * num_threads_reduced}; - num_blocks_ = {CeilOfRatio(input_shape_[1], tile_sizes_per_block_[0]), - CeilOfRatio(input_shape_[2], tile_sizes_per_block_[1])}; -} - -MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( - const HloFusionAnalysis& analysis) - : MlirReductionFusion(analysis) { - CHECK(reduction_dimensions_.is_row_reduction); - Vector3 shape = reduction_dimensions_.dimensions; - int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]); - input_shape_ = {shape[0], shape[1], shape[2]}; - CHECK_GT(rows_per_warp, 1); - - auto compute_block_size = [&](int vector_size) { - int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size; - - constexpr int64_t kThreadsPerBlockTarget = 256; - int64_t kept_size = reduction_dimensions_.dimensions[kRowKept]; - int64_t num_threads_kept = 1; - if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { - num_threads_kept = kept_size; - } else { - num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; - } - num_threads_ = {num_threads_kept, num_threads_reduced}; - tile_sizes_per_thread_ = {shape[0], vector_size}; - num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)}; - }; - - // Compute the launch grid without vectorization. We use the results to - // compute the vectorized launch grid. - compute_block_size(1); - - // Normally, we only consider input types for vectorization. However, in - // multi-row reductions, the input:output ratio is much higher, so we consider - // both inputs and outputs. - int smallest_input_or_output_bits = - std::min(analysis.input_output_info().smallest_input_dtype_bits, - analysis.input_output_info().smallest_output_dtype_bits); - - // This vector size is always valid: we know that the reduced dimension is a - // power of 2, since otherwise RowReductionGetRowsPerWarp would have - // returned 1. - // Our codegen can't currently deal with vectorization across rows, so we - // limit the vector size to the size of the row. Note that this emitter - // essentially reverts to the loop emitter in this case, except for side - // outputs. - int vector_size = std::min(static_cast(input_shape_[kRowMinorReduced]), - 32 / smallest_input_or_output_bits); - - // We target 8 warps per block, which means there could be up to 8 blocks per - // SM, but we have no good way of knowing. In practice, enabling vectorization - // for decently sized reductions at least does not hurt. - if (num_blocks_.front() > analysis.device_info().core_count() && - vector_size > 1) { - compute_block_size(vector_size); - } -} - -int MlirMultiRowReductionFusion::GetRowsPerWarp() const { - return RowReductionGetRowsPerWarp( - input_shape_[ReductionDimensions::kRowMinorReducedDimension]) * - tile_sizes_per_thread_[1]; -} - -int MlirRowReductionFusion::GetWarpsPerRow() const { - return CeilOfRatio(num_threads_[1], WarpSize()); -} - -IndexingMap MlirRowReductionFusion::GetSharedMemoryReductionReadMap( - mlir::MLIRContext* ctx) const { - auto thread_id = - DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_); - auto lane_id = thread_id[1] % WarpSize(); - return GetThreadIndexingMap({thread_id[0], lane_id}, - {{thread_id[1], {0, GetWarpsPerRow() - 1}}}); -} - -IndexingMap MlirRowReductionFusion::GetSharedMemoryWriteMap( - mlir::MLIRContext* ctx) const { - auto thread_id = - DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_); - // The reduced dimension is tiled; each warp writes one element to shared - // memory (from lane 0). - auto lane_id = thread_id[1] % WarpSize(); - auto warp_id = thread_id[1].floorDiv(WarpSize()); - return GetThreadIndexingMap({thread_id[0], warp_id}, {{lane_id, {0, 0}}}); -} - -llvm::SmallVector MlirRowReductionFusion::EmitReduction( - int group_id, EmitterState& state) const { - const auto& reductions = reduction_heroes_[group_id]; - - HloValueMap inits = GetInits(group_id, state); - auto per_thread = - state.EmitPerThreadElements(group_id, inits, state.FusionOutputs()); - per_thread.reduction_scalars = - state.ShuffleReduce(reductions, per_thread.reduction_scalars); - - if (GetWarpsPerRow() == 1) { - // If only a single warp works on an element, we don't need to go through - // shared memory. - return EvaluateEpilogue(per_thread.reduction_scalars, - std::move(per_thread.outputs), state, group_id, - /*symbol_values=*/{}); - } - - return state.ReduceViaSharedMemory(group_id, per_thread, inits); -} - -llvm::SmallVector MlirMultiRowReductionFusion::EmitReduction( - int group_id, EmitterState& state) const { - HloValueMap inits = GetInits(group_id, state); - const auto& reductions = reduction_heroes_[group_id]; - auto per_thread = - state.EmitPerThreadElements(group_id, inits, state.FusionOutputs()); - auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars, - WarpSize() / 2 / GetRowsPerWarp()); - return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state, - group_id, /*symbol_values=*/{}); + for (auto [result_index, result] : llvm::enumerate(values.at(root))) { + auto& output = outputs[state.OutputIndex(root, result_index)]; + output = b.create(thread_has_output, result, output, + output_indices); + } + } + return outputs; } MlirColumnReductionFusion::MlirColumnReductionFusion( @@ -930,5 +681,254 @@ std::unique_ptr CreateMlirReductionFusion( return std::make_unique(analysis); } +MlirRowReductionFusion::MlirRowReductionFusion( + const HloFusionAnalysis& analysis) + : MlirReductionFusion(analysis) { + CHECK(reduction_dimensions_.is_row_reduction); + Vector3 shape = reduction_dimensions_.dimensions; + CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1); + constexpr int64_t kMinorReducedElementsPerThread = 16; + + int64_t num_threads_kept = 1; + int64_t num_threads_reduced = [&] { + int64_t max_block_size = + MinThreadsXRowReduction(first_reduce_->GetModule()->config()); + return std::min(max_block_size, + RoundUpTo(CeilOfRatio(shape[kRowMinorReduced], + kMinorReducedElementsPerThread), + WarpSize())); + }(); + + // If we're limited by the size of the x dimension, add additional parallelism + // in the y dimension. The code generator doesn't currently support + // parallelizing the z dimension (major reduced dimensions). The general + // recommendation is to use between 128 and 512 threads, so we just go for + // 256. See https://forums.developer.nvidia.com/t/55529 + constexpr int64_t kThreadsPerBlockTarget = 256; + if (num_threads_reduced * 2 <= kThreadsPerBlockTarget) { + int64_t kept_size = reduction_dimensions_.dimensions[kRowKept]; + // Increase the size of the y dimension as long as there's remaining + // parallelism. + if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { + num_threads_kept = kept_size; + } else { + num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; + } + } + + int vector_size = GetVectorSizeForMlir(analysis, reduction_dimensions_, + num_threads_reduced); + num_threads_ = {num_threads_kept, num_threads_reduced}; + // TODO(jreiffers): Get rid of `vector_size` in here. + input_shape_ = {shape[0], shape[1], shape[2] / vector_size, vector_size}; + // TODO(jreiffers): Tighten ranges based on constraints when simplifying + // instead of using min here. For example, based on + // + // s1 in [0, 127] + // d0 floordiv 32 + s1 * 32 in [0, 63] + // + // Tighten the bound of s1 to [0, 1]. + int minor_reduced_tile_size = + std::min(kMinorReducedElementsPerThread / vector_size, + CeilOfRatio(input_shape_[2], num_threads_[1])); + + tile_sizes_per_thread_ = {shape[0], minor_reduced_tile_size, vector_size}; + tile_sizes_per_block_ = {num_threads_kept, + minor_reduced_tile_size * num_threads_reduced}; + num_blocks_ = {CeilOfRatio(input_shape_[1], tile_sizes_per_block_[0]), + CeilOfRatio(input_shape_[2], tile_sizes_per_block_[1])}; +} + +IndexingMap MlirRowReductionFusion::ComputeReductionInputIndexing( + mlir::MLIRContext* ctx) const { + auto thread_id = + DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); + auto block_id = + DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_); + auto major_reduced = getAffineSymbolExpr(0, ctx); + auto minor_reduced = getAffineSymbolExpr(1, ctx); + auto vector_index = getAffineSymbolExpr(2, ctx); + + SmallVector indices{ + major_reduced, + block_id[0] * tile_sizes_per_block_[0] + thread_id[0], + block_id[1] * tile_sizes_per_block_[1] + + (minor_reduced * num_threads_[1]) + thread_id[1], + vector_index, + }; + + auto map = GetIndexingMap(indices, tile_sizes_per_thread_); + for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) { + map.AddConstraint(result, {0, input_dim - 1}); + } + return map; +} + +IndexingMap MlirRowReductionFusion::ComputeReductionOutputIndexing( + MLIRContext* ctx) const { + auto thread_id = + DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); + auto block_id = + DelinearizeInBoundsIndex(mlir::getAffineDimExpr(3, ctx), num_blocks_); + IndexingMap projected_index = + GetIndexingMap(block_id[0] * tile_sizes_per_block_[0] + thread_id[0]); + projected_index.AddConstraint(thread_id[1], {0, 0}); + return projected_index; +} + +int MlirRowReductionFusion::GetWarpsPerRow() const { + return CeilOfRatio(num_threads_[1], WarpSize()); +} + +IndexingMap MlirRowReductionFusion::GetSharedMemoryReductionReadMap( + mlir::MLIRContext* ctx) const { + auto thread_id = + DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_); + auto lane_id = thread_id[1] % WarpSize(); + return GetThreadIndexingMap({thread_id[0], lane_id}, + {{thread_id[1], {0, GetWarpsPerRow() - 1}}}); +} + +IndexingMap MlirRowReductionFusion::GetSharedMemoryWriteMap( + mlir::MLIRContext* ctx) const { + auto thread_id = + DelinearizeInBoundsIndex(getAffineDimExpr(0, ctx), num_threads_); + // The reduced dimension is tiled; each warp writes one element to shared + // memory (from lane 0). + auto lane_id = thread_id[1] % WarpSize(); + auto warp_id = thread_id[1].floorDiv(WarpSize()); + return GetThreadIndexingMap({thread_id[0], warp_id}, {{lane_id, {0, 0}}}); +} + +llvm::SmallVector MlirRowReductionFusion::EmitReduction( + int group_id, EmitterState& state) const { + const auto& reductions = reduction_heroes_[group_id]; + + HloValueMap inits = GetInits(group_id, state); + auto per_thread = + state.EmitPerThreadElements(group_id, inits, state.FusionOutputs()); + per_thread.reduction_scalars = + state.ShuffleReduce(reductions, per_thread.reduction_scalars); + + if (GetWarpsPerRow() == 1) { + // If only a single warp works on an element, we don't need to go through + // shared memory. + return EvaluateEpilogue(per_thread.reduction_scalars, + std::move(per_thread.outputs), state, group_id, + /*symbol_values=*/{}); + } + + return state.ReduceViaSharedMemory(group_id, per_thread, inits); +} + +MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( + const HloFusionAnalysis& analysis) + : MlirReductionFusion(analysis) { + CHECK(reduction_dimensions_.is_row_reduction); + Vector3 shape = reduction_dimensions_.dimensions; + int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]); + input_shape_ = {shape[0], shape[1], shape[2]}; + CHECK_GT(rows_per_warp, 1); + + auto compute_block_size = [&](int vector_size) { + int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size; + + constexpr int64_t kThreadsPerBlockTarget = 256; + int64_t kept_size = reduction_dimensions_.dimensions[kRowKept]; + int64_t num_threads_kept = 1; + if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { + num_threads_kept = kept_size; + } else { + num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; + } + num_threads_ = {num_threads_kept, num_threads_reduced}; + tile_sizes_per_thread_ = {shape[0], vector_size}; + num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)}; + }; + + // Compute the launch grid without vectorization. We use the results to + // compute the vectorized launch grid. + compute_block_size(1); + + // Normally, we only consider input types for vectorization. However, in + // multi-row reductions, the input:output ratio is much higher, so we consider + // both inputs and outputs. + int smallest_input_or_output_bits = + std::min(analysis.input_output_info().smallest_input_dtype_bits, + analysis.input_output_info().smallest_output_dtype_bits); + + // This vector size is always valid: we know that the reduced dimension is a + // power of 2, since otherwise RowReductionGetRowsPerWarp would have + // returned 1. + // Our codegen can't currently deal with vectorization across rows, so we + // limit the vector size to the size of the row. Note that this emitter + // essentially reverts to the loop emitter in this case, except for side + // outputs. + int vector_size = std::min(static_cast(input_shape_[kRowMinorReduced]), + 32 / smallest_input_or_output_bits); + + // We target 8 warps per block, which means there could be up to 8 blocks per + // SM, but we have no good way of knowing. In practice, enabling vectorization + // for decently sized reductions at least does not hurt. + if (num_blocks_.front() > analysis.device_info().core_count() && + vector_size > 1) { + compute_block_size(vector_size); + } +} + +IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing( + mlir::MLIRContext* ctx) const { + auto thread_id = + DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); + auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx) + : mlir::getAffineDimExpr(3, ctx); + auto major_reduced = getAffineSymbolExpr(0, ctx); + auto vector_index = getAffineSymbolExpr(1, ctx); + + SmallVector indices{ + major_reduced, block_id * num_threads_[0] + thread_id[0], + thread_id[1] * tile_sizes_per_thread_[1] + vector_index}; + + auto map = GetIndexingMap(indices, tile_sizes_per_thread_); + for (auto [result, input_dim] : llvm::zip(indices, input_shape_)) { + map.AddConstraint(result, {0, input_dim - 1}); + } + return map; +} + +IndexingMap MlirMultiRowReductionFusion::ComputeReductionOutputIndexing( + MLIRContext* ctx) const { + auto thread_id = + DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), num_threads_); + auto block_id = num_blocks_.front() == 1 ? mlir::getAffineConstantExpr(0, ctx) + : mlir::getAffineDimExpr(3, ctx); + IndexingMap projected_index = + GetIndexingMap(block_id * num_threads_[0] + thread_id[0]); + projected_index.AddConstraint(thread_id[1] % (WarpSize() / GetRowsPerWarp()), + {0, 0}); + // We don't need a constraint on the loop dimensions, because they are removed + // by GetIndexingMap (since they don't show up in the output index + // computation). + return projected_index; +} + +int MlirMultiRowReductionFusion::GetRowsPerWarp() const { + return RowReductionGetRowsPerWarp( + input_shape_[ReductionDimensions::kRowMinorReducedDimension]) * + tile_sizes_per_thread_[1]; +} + +llvm::SmallVector MlirMultiRowReductionFusion::EmitReduction( + int group_id, EmitterState& state) const { + HloValueMap inits = GetInits(group_id, state); + const auto& reductions = reduction_heroes_[group_id]; + auto per_thread = + state.EmitPerThreadElements(group_id, inits, state.FusionOutputs()); + auto reduced = state.ShuffleReduce(reductions, per_thread.reduction_scalars, + WarpSize() / 2 / GetRowsPerWarp()); + return EvaluateEpilogue(reduced, std::move(per_thread.outputs), state, + group_id, /*symbol_values=*/{}); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc index 6ba7431530309e..479852851322c0 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir_test.cc @@ -23,13 +23,13 @@ limitations under the License. #include #include "absl/status/status.h" #include "absl/strings/str_cat.h" -#include "absl/strings/substitute.h" #include "absl/types/span.h" #include "xla/error_spec.h" #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { @@ -54,69 +54,8 @@ class ReductionTest : public MlirEmitterTestBase { } }; -using MlirRowReductionTest = ReductionTest; -using MlirColumnReductionTest = ReductionTest; using MlirMultiRowReductionTest = ReductionTest; -constexpr std::string_view kVariadicRowReduction = R"( - Add { - scalar_lhs.0 = f32[] parameter(0) - scalar_rhs.0 = f32[] parameter(1) - scalar_lhs.1 = f32[] parameter(2) - scalar_rhs.1 = f32[] parameter(3) - add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) - add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) - ROOT t = (f32[], f32[]) tuple(add.0, add.1) - } - fused_computation { - param_0 = f32[2, 3, 2048] parameter(0) - param_1 = f32[2, 3, 2048] parameter(1) - param_2 = f32[] parameter(2) - ROOT d.1 = (f32[2, 3], f32[2, 3]) - reduce(param_0, param_1, param_2, param_2), dimensions={2}, to_apply=Add - } - ENTRY main { - a = f32[2, 3, 2048] parameter(0) - b = f32[2, 3, 2048] parameter(1) - c = f32[] constant(0) - ROOT fusion = (f32[2, 3], f32[2, 3]) fusion(a, b, c), - kind=kInput, calls=fused_computation - })"; - -constexpr std::string_view kF64RowReduction = R"( - Add { - lhs = f64[] parameter(0) - rhs = f64[] parameter(1) - ROOT add = f64[] add(lhs, rhs) - } - fused_computation { - param_0 = f64[100,128] parameter(0) - param_1 = f64[] parameter(1) - ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add - } - ENTRY main { - a = f64[100,128] parameter(0) - c = f64[] constant(0) - ROOT fusion = f64[100] fusion(a, c), kind=kInput, calls=fused_computation - })"; - -constexpr auto kRowReductionMinorAndMajor = R"( - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[7,100,128] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={0,2}, to_apply=Add - } - ENTRY main { - a = f32[7,100,128] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation - })"; - constexpr auto kMultiRowReductionX8 = R"( Add { lhs = f32[] parameter(0) @@ -179,181 +118,6 @@ constexpr auto kMultiRowReductionX16VectorX2 = R"( ROOT fusion = pred[76800] fusion(p0), kind=kInput, calls=fusion })"; -constexpr std::string_view kRowReductionSideOutput = R"( - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[8,2048] parameter(0) - param_1 = f32[] parameter(1) - exp = f32[8,2048] exponential(param_0) - reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add - ROOT t = (f32[8], f32[8,2048]) tuple(reduce, exp) - } - ENTRY main { - a = f32[8,2048] parameter(0) - c = f32[] constant(0) - ROOT fusion = (f32[8], f32[8,2048]) fusion(a, c), kind=kInput, - calls=fused_computation - })"; - -TEST_F(MlirRowReductionTest, VariadicRowReductionIndexing) { - auto fusion = GetEmitter(kVariadicRowReduction); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - {2, 3, 2048})); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {2, 3})); -} - -TEST_F(MlirRowReductionTest, VariadicRowReductionCorrectness) { - EXPECT_TRUE(RunAndCompareNoHloPasses(kVariadicRowReduction, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, RowReduceEpilogue) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[8,2048] parameter(0) - param_1 = f32[] parameter(1) - reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add - ROOT log = f32[8] log(reduce) - } - ENTRY main { - a = f32[8,2048] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[8] fusion(a, c), kind=kInput, calls=fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: pure_call @Add_add - // CHECK: shuffle_reduce - // CHECK: allocate_shared - // CHECK: sync_threads - // CHECK: shuffle_reduce - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, RowReduceMOFEpilogue) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - Mul { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT mul = f32[] multiply(lhs, rhs) - } - fused_computation { - param_0 = f32[8,1024] parameter(0) - param_1 = f32[] parameter(1) - reduce1 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add - reduce2 = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Mul - log = f32[8] log(reduce1) - abs = f32[8] abs(reduce1) - neg = f32[8] negate(reduce2) - ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs) - } - ENTRY main { - a = f32[8,1024] parameter(0) - c = f32[] constant(0) - ROOT fusion = (f32[8], f32[8], f32[8]) fusion(a, c), kind=kInput, - calls=fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: pure_call @Add_add - // CHECK-DAG: shuffle_reduce @Add_add - // CHECK-DAG: pure_call @Mul_mul - // CHECK-DAG: shuffle_reduce @Mul_mul - // CHECK: allocate_shared - // CHECK: allocate_shared - // CHECK: sync_threads - // CHECK-DAG: shuffle_reduce @Add_add - // CHECK-DAG: shuffle_reduce @Mul_mul - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, RowReduceMOFGroups) { - constexpr auto kHloString = R"( - %add_f32 { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add = f32[] add(%x, %y) - } - - %fused_computation { - %param0 = f32[1024] parameter(0) - %param1 = f32[1024] parameter(1) - %constant0 = f32[] constant(0) - %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 - %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 - ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2) - } - - ENTRY %cluster { - %param0 = f32[1024] parameter(0) - %param1 = f32[1024] parameter(1) - ROOT %fusion = (f32[], f32[]) - fusion(%param0, %param1), kind=kInput, calls=%fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: scf.index_switch %block_id_y - // CHECK: case 1 { - // CHECK: default { - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, F64RowReductionIndexing) { - auto fusion = GetEmitter(kF64RowReduction); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - /*shape=*/{100, 128})); - TF_EXPECT_OK( - TestBijection(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), - /*shape=*/{100})); -} - -TEST_F(MlirRowReductionTest, F64RowReductionIr) { - // This reduction is small enough not to require shared memory. - TF_ASSERT_OK(EmitAndCheckIR(kF64RowReduction, R"( - // CHECK-NOT: allocate_shared - )")); -} - -TEST_F(MlirRowReductionTest, F64RowReductionCorrectness) { - EXPECT_TRUE(RunAndCompareNoHloPasses(kF64RowReduction, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, RowReductionMinorAndMajorIndexing) { - auto fusion = GetEmitter(kRowReductionMinorAndMajor); - - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - /*shape=*/{7, 100, 128})); - TF_EXPECT_OK( - TestBijection(*fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), - /*shape=*/{100})); -} - -TEST_F(MlirRowReductionTest, RowReductionMinorAndMajorCorrectness) { - EXPECT_TRUE( - RunAndCompareNoHloPasses(kRowReductionMinorAndMajor, ErrorSpec{1e-3})); -} - TEST_F(MlirMultiRowReductionTest, MultiRowReductionIndexing) { auto fusion = GetEmitter(kMultiRowReductionX8); @@ -379,207 +143,6 @@ TEST_F(MlirMultiRowReductionTest, MultiRowReductionCorrectness) { EXPECT_TRUE(RunAndCompareNoHloPasses(kMultiRowReductionX8, ErrorSpec{1e-3})); } -TEST_F(MlirRowReductionTest, NonPowerOfTwoRowReduction) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[100,568] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={1}, to_apply=Add - } - ENTRY main { - a = f32[100,568] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[100] fusion(a, c), kind=kInput, calls=fused_computation - })"; - TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0)> - // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512)> - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index - // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - // CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] - // CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] - // CHECK-NOT: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]] in [0, 1], %thread_id_x in [0, 255])[%[[I]] in [0, 3]] - // CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]]) - // CHECK: scf.if - // CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]] in [0, 1], %thread_id_x in [0, 255]) - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirMultiRowReductionTest, NonTrivialEpilogueCorrectness) { - constexpr auto kHloString = R"( - HloModule module - add { - p0 = f64[] parameter(0) - p1 = f64[] parameter(1) - ROOT add = f64[] add(p0, p1) - } - fusion { - %p0 = f64[4] parameter(0) - %p1 = f64[4] parameter(1) - %c0 = f64[] constant(-inf) - %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add - %bc0 = f64[4] broadcast(reduce0), dimensions={} - %compare0 = pred[4] compare(p1, bc0), direction=EQ - %c1 = f64[] constant(0) - %bc1 = f64[4] broadcast(c1), dimensions={} - %select.3.1 = f64[4] select(compare0, p0, bc1) - %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add - %convert0 = f64[4] convert(compare0) - %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add - ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2) - } - ENTRY main { - %p0 = f64[4] parameter(0) - %p1 = f64[4] parameter(1) - ROOT %fusion = (f64[], f64[], f64[]) fusion(%p0, %p1), kind=kInput, - calls=fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, SideOutputIndexing) { - auto fusion = GetEmitter(kRowReductionSideOutput); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - {8, 2048})); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {8})); - TF_EXPECT_OK( - TestBijection(*fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context_), - {8, 2048})); // Side output. -} - -TEST_F(MlirRowReductionTest, SideOutputIr) { - TF_ASSERT_OK(EmitAndCheckIR(kRowReductionSideOutput, R"( - // CHECK: @fused_computation - // CHECK: scf.for - // CHECK: scf.for - // CHECK: %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fused_computation_exp - // CHECK-NEXT: tensor.insert %[[SIDE_OUTPUT]] - )")); -} - -TEST_F(MlirRowReductionTest, SideOutputCorrectness) { - EXPECT_TRUE( - RunAndCompareNoHloPasses(kRowReductionSideOutput, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, UnsignedSideOutputCorrectness) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = u32[] parameter(0) - rhs = u32[] parameter(1) - ROOT add = u32[] add(lhs, rhs) - } - fused_computation { - param_0 = u32[8,2048] parameter(0) - param_1 = u32[] parameter(1) - add = u32[8,2048] add(param_0, param_0) - reduce = u32[8] reduce(param_0, param_1), dimensions={1}, to_apply=Add - ROOT t = (u32[8], u32[8,2048]) tuple(reduce, add) - } - ENTRY main { - a = u32[8,2048] parameter(0) - c = u32[] constant(0) - ROOT fusion = (u32[8], u32[8,2048]) fusion(a, c), kind=kInput, - calls=fused_computation - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, BroadcastSideOutputCorrectness) { - constexpr auto kHloString = R"( - %add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - %fusion { - %p0 = f32[6,6] parameter(0) - %c0 = f32[] constant(0) - %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add - %broadcast = f32[6,6] broadcast(%reduce), dimensions={} - ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce) - } - ENTRY main { - %p0 = f32[6,6] parameter(0) - ROOT %fusion = (f32[6,6], f32[]) fusion(%p0), kind=kInput, calls=%fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, VariadicMOFCorrectness) { - constexpr auto kHloString = R"( - %reducer1 { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - %reducer2 { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - p2 = f32[] parameter(2) - p3 = f32[] parameter(3) - add0 = f32[] add(p0, p2) - add1 = f32[] add(p1, p3) - ROOT tuple = (f32[], f32[]) tuple(add0, add1) - } - %fusion { - %p0 = f32[6,6] parameter(0) - %c0 = f32[] constant(0) - %neg = f32[6,6] negate(%p0) - %reduce1 = f32[] reduce(%neg, %c0), dimensions={0,1}, to_apply=%reducer1 - %reduce2 = (f32[], f32[]) reduce(%p0, %p0, %c0, %c0), dimensions={0,1}, to_apply=%reducer2 - ROOT %tuple = (f32[], (f32[], f32[]), f32[6,6]) tuple(%reduce1, %reduce2, %neg) - } - ENTRY main { - %p0 = f32[6,6] parameter(0) - ROOT %fusion = (f32[], (f32[], f32[]), f32[6,6]) fusion(%p0), kind=kInput, calls=%fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, OutputLayoutCorrectness) { - constexpr std::string_view kHloString = R"( - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - fusion { - %input = f32[17,19,127] parameter(0) - %c0 = f32[] constant(0) - ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add - } - - ENTRY entry { - %input = f32[17,19,127] parameter(0) - ROOT %fusion = f32[17,19]{0,1} fusion(%input), kind=kInput, calls=fusion - })"; - - auto fusion = GetEmitter(kHloString); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context_), - {17, 19, 127})); - TF_EXPECT_OK(TestBijection( - *fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_), {17, 19})); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirMultiRowReductionTest, TwoGroups) { auto module = ParseAndReturnVerifiedModule(R"( add { @@ -604,7 +167,7 @@ TEST_F(MlirMultiRowReductionTest, TwoGroups) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirMultiRowReductionFusion fusion(analysis); EXPECT_THAT(fusion.GetGroups().grouped_roots, @@ -635,231 +198,12 @@ TEST_F(MlirMultiRowReductionTest, OneGroup) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirMultiRowReductionFusion mlir_fusion(analysis); EXPECT_THAT(mlir_fusion.GetGroups().grouped_roots, SizeIs(1)); } -constexpr absl::string_view kColumnVectorizationTemplate = R"( - add { - b = $0[] parameter(1) - a = $0[] parameter(0) - ROOT out = $0[] add(a, b) - } - fusion { - %p0 = $0[192,64,1536] parameter(0) - %p1 = $0[] parameter(1) - ROOT reduce = $0[192,1536] reduce(p0, p1), dimensions={1}, to_apply=add - } - ENTRY entry { - %p0 = $0[192,64,1536] parameter(0) - %p1 = $0[] parameter(1) - ROOT %fusion = $0[192,1536] fusion(p0, p1), kind=kInput, calls=fusion - })"; - -TEST_F(MlirColumnReductionTest, ColumnReduction) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[13,1051,321] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=Add - } - ENTRY main { - a = f32[13,1051,321] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[13,321] fusion(a, c), kind=kInput, calls=fused_computation - })"; - - auto module = ParseAndReturnVerifiedModule(kHloString).value(); - auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); - MlirColumnReductionFusion fusion(analysis); - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( - d3 floordiv 11, - d0 floordiv 32 + s0 * 32, - (d3 mod 11) * 32 + d0 mod 32 - ) - domain: - d0 in [0, 1023] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 142] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 32] - s1 in [0, 0] - (d3 mod 11) * 32 + d0 mod 32 in [0, 320] - d0 floordiv 32 + s0 * 32 in [0, 1050] - )")); - EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0] -> ( - d3 floordiv 11, (d3 mod 11) * 32 + d0 floordiv 32 - ) - domain: - d0 in [0, 992] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 142] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 0] - (d3 mod 11) * 32 + d0 floordiv 32 in [0, 320] - d0 mod 32 in [0, 0] - )")); - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: xla_gpu.pure_call @Add_add - // CHECK: allocate_shared - // CHECK: tensor.insert - // CHECK: sync_threads - // CHECK: predicated_extract - // CHECK: shuffle_reduce - // CHECK: predicated_insert - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirColumnReductionTest, SmallColumnReduction) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[3,128,4] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[3,4] reduce(param_0, param_1), dimensions={1}, to_apply=Add - } - ENTRY main { - a = f32[3,128,4] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[3,4] fusion(a, c), kind=kInput, calls=fused_computation - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirColumnReductionTest, MixedIndexing) { - constexpr auto kHloString = R"( - HloModule module - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - fusion { - %param_0 = f32[64,128] parameter(0) - %constant_0 = f32[] constant(0) - %reduce.1 = f32[128] reduce(f32[64,128] %param_0, f32[] %constant_0), dimensions={0}, to_apply=%add - %neg = f32[64,128] negate(f32[64,128] %param_0) - %bitcast = f32[8,8,128]{2,1,0} bitcast(f32[64,128] %neg) - %reduce.2 = f32[128] reduce(f32[8,8,128]{2,1,0} %bitcast, f32[] %constant_0), dimensions={0,1}, to_apply=%add - ROOT %tuple.12 = (f32[128], f32[128]) tuple(f32[128] %reduce.1, f32[128] %reduce.2) - } - ENTRY entry { - %param_0 = f32[64,128] parameter(0) - ROOT %fusion = (f32[128], f32[128]) fusion(%param_0), kind=kInput, calls=fusion - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirColumnReductionTest, ColumnReductionVectorizationCorrectness) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - Add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) - } - fused_computation { - param_0 = f32[2048,16384] parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce = f32[16384] reduce(param_0, param_1), dimensions={0}, to_apply=Add - } - ENTRY main { - a = f32[2048,16384] parameter(0) - c = f32[] constant(0) - ROOT fusion = f32[16384] fusion(a, c), kind=kInput, calls=fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: vector<2xf32> - )")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirColumnReductionTest, ColumnReductionVectorization_v4) { - constexpr auto kHloString = R"( - HloModule Test, is_scheduled=true - Add { - lhs = s16[] parameter(0) - rhs = s16[] parameter(1) - ROOT add = s16[] add(lhs, rhs) - } - fused_computation { - param_0 = s16[2048,16384] parameter(0) - param_1 = s16[] parameter(1) - ROOT reduce = s16[16384] reduce(param_0, param_1), dimensions={0}, to_apply=Add - } - ENTRY main { - a = s16[2048,16384] parameter(0) - c = s16[] constant(0) - ROOT fusion = s16[16384] fusion(a, c), kind=kInput, calls=fused_computation - })"; - TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: vector<4xi16> - )")); - // We don't use RunAndCompareNoHloPasses because the interpreter is too slow - // for this input. -} - -TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v2) { - const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f32"); - auto fusion = GetEmitter(hlo_string); - EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing( - 0, 0, &mlir_context_)), - ElementsAre(2 /* major reduced */, 2 /* vector size */)); -} - -TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_v4) { - const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f16"); - auto fusion = GetEmitter(hlo_string); - EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing( - 0, 0, &mlir_context_)), - ElementsAre(2 /* major reduced */, 4 /* vector size */)); -} - -TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_f64) { - // Verifies that we do not use the vectorized indexing for f64. - const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "f64"); - auto fusion = GetEmitter(hlo_string); - EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing( - 0, 0, &mlir_context_)), - ElementsAre(2 /* major reduced */, 1 /* vector size */)); -} - -TEST_F(MlirColumnReductionTest, ThreadIndexingColumn_Complex) { - // Verifies that we do not use the vectorized indexing for complex types. - const auto hlo_string = absl::Substitute(kColumnVectorizationTemplate, "c64"); - auto fusion = GetEmitter(hlo_string); - EXPECT_THAT(GetLoopTripCounts(*fusion->ComputeThreadIdToInputIndexing( - 0, 0, &mlir_context_)), - ElementsAre(2 /* major reduced */, 1 /* vector size */)); -} - TEST_F(MlirMultiRowReductionTest, VectorizedX4Indexing) { auto fusion = GetEmitter(kMultiRowReductionX2VectorX4); @@ -883,61 +227,6 @@ TEST_F(MlirMultiRowReductionTest, VectorizedX4Correctness) { RunAndCompareNoHloPasses(kMultiRowReductionX2VectorX4, ErrorSpec{1e-3})); } -TEST_F(MlirRowReductionTest, LargeToUnit) { - // Regression test for a bug where not all threads in the warp produced a - // valid value for the final warp shuffle. - constexpr auto kHloString = R"( - and { - p0 = pred[] parameter(0) - p1 = pred[] parameter(1) - ROOT and = pred[] and(p0, p1) - } - - %fused_reduce { - c1 = pred[] constant(true) - p0 = pred[10000] broadcast(c1), dimensions={} - ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and - } - )"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(MlirRowReductionTest, MOFTwoVariadic) { - // Regression test for a compilation crash with a MOF with two variadic - // reductions. - constexpr auto kHloString = R"( - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - p2 = f32[] parameter(2) - p3 = f32[] parameter(3) - a = f32[] add(p0, p2) - b = f32[] add(p1, p3) - ROOT out = (f32[], f32[]) tuple(a, b) - } - - fused_reduce { - p0 = f32[3,2] parameter(0) - p1 = f32[3,2] parameter(1) - c0 = f32[] constant(0) - iota0 = f32[3,2] iota(), iota_dimension=1 - iota1 = f32[3,2] iota(), iota_dimension=1 - reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1}, - to_apply=add - reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1}, - to_apply=add - ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, %reduce1) - } - - ENTRY main { - p0 = f32[3,2] parameter(0) - p1 = f32[3,2] parameter(1) - ROOT fusion = ((f32[3], f32[3]), (f32[3], f32[3])) fusion(p0, p1), - kind=kInput, calls=fused_reduce - })"; - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc index a281c0e92f2f0f..85e1e504e79d73 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.cc @@ -39,9 +39,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/primitive_util.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" @@ -62,7 +63,6 @@ using mlir::OpBuilder; using mlir::Value; using mlir::ValueRange; using mlir::func::ReturnOp; -using mlir::tensor::InsertOp; using mlir_converter::CallTargetProvider; using mlir_converter::PartitionedComputations; using mlir_converter::ProvideParameter; @@ -174,7 +174,8 @@ mlir::Value EmitScatterComputation( auto reduced_val = mlir_converter::InlineBlock( b, reducer.getBody().front(), {operand_elem, update_elem})[0]; - return b.create(reduced_val, output_tensor, indices); + return b.create(reduced_val, output_tensor, + indices); } auto atomic_rmw = b.create(output_tensor, indices); mlir::OpBuilder body_builder = atomic_rmw.getBodyBuilder(); diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h index de9743a079a849..3efaa0a827fbea 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir.h @@ -25,7 +25,6 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/loop.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/hlo_fusion_analysis.h" diff --git a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc index 869d2335001825..2e9a11a78c2b84 100644 --- a/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/scatter_mlir_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -77,7 +77,7 @@ TEST_F(MlirScatterFusionTest, ThreadIdIndexing) { thread_id_printer_.SetSymbolName(2, "index_id"); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirScatterFusion fusion(analysis); constexpr auto kUpdatesIndexing = R"( @@ -187,8 +187,8 @@ TEST_F(MlirScatterFusionTest, Scatter_UniqueIndices) { } )"; TF_ASSERT_OK(EmitAndCheckIR(kHloString, R"( - // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 2)> - // CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 2)> + // CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 2) + // CHECK: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 2) // CHECK-LABEL: func.func @fused_computation( // CHECK-SAME: %[[OPERAND:[a-zA-Z0-9]*]]: tensor<10x5xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/tests/BUILD new file mode 100644 index 00000000000000..d3e3b665e75d3b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/BUILD @@ -0,0 +1,19 @@ +load("//xla:lit.bzl", "lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "tests", + srcs = glob(["**/*.hlo"]), + cfg = "//xla:lit.cfg.py", + default_tags = ["requires-gpu-sm80-only"], + tools = [ + "//xla/service/gpu/fusions/tools:fusion_to_mlir", + "//xla/service/gpu/fusions/tools:mlir_fusions_opt", + "//xla/service/gpu/fusions/tools:test_correctness", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo new file mode 100644 index 00000000000000..d0dd73d59081cd --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant.hlo @@ -0,0 +1,19 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline | FileCheck %s --dump-input=always +// RUN: test_correctness %s --bijection_outputs=broadcast + +bcast { + zero = bf16[] constant(0) + ROOT broadcast = bf16[2,16,48]{2,1,0} broadcast(zero), dimensions={} +} + +// CHECK-DAG: #[[MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 6) +// CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 48) mod 16) +// CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 48) +// CHECK: func.func @main(%[[ARG0:.*]]: tensor<2x16x48xbf16> +// CHECK: %[[THREAD_ID:.*]] = gpu.thread_id +// CHECK: %[[BLOCK_ID:.*]] = gpu.block_id +// CHECK: %[[I0:.*]] = xla_gpu.apply_indexing #[[MAP0]] +// CHECK: %[[I1:.*]] = xla_gpu.apply_indexing #[[MAP1]] +// CHECK: %[[I2:.*]] = xla_gpu.apply_indexing #[[MAP2]] +// CHECK: %[[CST:.*]] = arith.constant 0.000 +// CHECK: %[[INSERTED:.*]] = tensor.insert %[[CST]] into %[[ARG0]][%[[I0]], %[[I1]], %[[I2]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo new file mode 100644 index 00000000000000..2f6a5aa41c664f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/complex.hlo @@ -0,0 +1,28 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +%fused_computation { + %p0 = f32[2]{0} parameter(0) + %p1 = f32[2]{0} parameter(1) + %p2 = c64[2]{0} parameter(2) + %complex = c64[2] complex(%p0, %p1) + %add = c64[2] add(%complex, %p2) + %cst = c64[2]{0} constant({(2.0, 0.0), (0.0, 2.0)}) + ROOT %mul = c64[2] multiply(%add, %cst) +} + +// CHECK: func.func @main +// CHECK-NEXT: gpu.thread_id +// CHECK-NEXT: pure_call @fused_computation_mul +// CHECK-NEXT: tensor.insert +// CHECK-NEXT: return + +// CHECK: func.func private @fused_computation_mul +// CHECK-NEXT: arith.constant +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: complex.create +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: complex.add +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: complex.mul diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo new file mode 100644 index 00000000000000..f9976a51a3994c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/dynamic_update_slice.hlo @@ -0,0 +1,27 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline | FileCheck %s --dump-input=always +// RUN: test_correctness %s + +%fused_computation { + in = c64[2,3] parameter(0) + updates = c64[2,2] parameter(1) + i0 = s32[] parameter(2) + i1 = s32[] parameter(3) + updated = c64[2,3] dynamic-update-slice(in, updates, i0, i1) + // Add some random epilogue to prevent in-place DUS from triggering. + ROOT negated = c64[2,3] negate(updated) +} + +// CHECK: func.func @main +// CHECK-SAME: %[[IN:.*]]: tensor<2x3xcomplex> {xla.slice_index = 0 +// CHECK-SAME: %[[UPDATES:.*]]: tensor<2x2xcomplex> {xla.slice_index = 1 +// CHECK-SAME: %[[I0:.*]]: tensor {xla.slice_index = 2 +// CHECK-SAME: %[[I1:.*]]: tensor {xla.slice_index = 3 + +// No need to load i0, since its value is irrelevant. +// CHECK-NOT: tensor.extract %[[I0]] +// CHECK: tensor.extract %[[I1]] +// CHECK-NOT: tensor.extract %[[I0]] +// CHECK: scf.if +// CHECK: tensor.extract %[[UPDATES]] +// CHECK: } else { +// CHECK: tensor.extract %[[IN]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo new file mode 100644 index 00000000000000..7d7fdf79fe2b55 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/minimum_maximum.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +fused_computation { + param0 = f64[] parameter(0) + param1 = f64[] parameter(1) + + minimum = f64[] minimum(param0, param1) + maximum = f64[] maximum(param0, param1) + ROOT tuple = (f64[], f64[]) tuple(minimum, maximum) +} + +// CHECK: func.func @main +// CHECK: xla_gpu.pure_call @fused_computation_tuple +// CHECK: func.func private @fused_computation_tuple +// CHECK-DAG: arith.minimumf +// CHECK-DAG: arith.maximumf diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo new file mode 100644 index 00000000000000..017019e436d125 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_mul.hlo @@ -0,0 +1,15 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +%fused_computation { + p0 = s8[1000] parameter(0) + p1 = s8[1000] parameter(1) + cvt0 = pred[1000] convert(p0) + cvt1 = pred[1000] convert(p1) + ROOT mul = pred[1000] multiply(cvt0, cvt1) +} + +// CHECK: %[[A:.*]] = arith.cmpi ne, +// CHECK: %[[B:.*]] = arith.cmpi ne, +// CHECK: %[[R:.*]] = arith.andi %[[A]], %[[B]] +// CHECK: arith.extui %[[R]] : i1 to i8 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_not.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_not.hlo new file mode 100644 index 00000000000000..0597b3590cbbd4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/pred_not.hlo @@ -0,0 +1,13 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +%fused_computation { + p0 = s8[1000] parameter(0) + cvt = pred[1000] convert(p0) + ROOT not = pred[1000] not(cvt) +} + +// CHECK: %[[C0:.*]] = arith.constant 0 : i8 +// CHECK: %[[NONZERO:.*]] = arith.cmpi eq, {{.*}}, %[[C0]] +// CHECK: %[[CVT:.*]] = arith.extui %[[NONZERO]] : i1 to i8 +// CHECK: return %[[CVT]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo new file mode 100644 index 00000000000000..f77a3c38cd8ded --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_heterogeneous.hlo @@ -0,0 +1,21 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-to-inline | FileCheck %s +// RUN: test_correctness %s + +fused_computation { + param0 = f64[8] parameter(0) + param1 = f64[8] parameter(1) + + minimum = f64[8] minimum(param0, param1) + maximum = f64[8] maximum(param0, param1) + bc = f64[2, 4] bitcast(maximum) + ROOT tuple = (f64[8], f64[2,4]) tuple(minimum, bc) +} + +// CHECK: #[[MAJOR:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 4), +// CHECK: #[[MINOR:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 mod 4), + +// CHECK: %[[TID:.*]] = gpu.thread_id +// CHECK-DAG: %[[MAJOR_IDX:.*]] = xla_gpu.apply_indexing #[[MAJOR]] +// CHECK-DAG: %[[MINOR_IDX:.*]] = xla_gpu.apply_indexing #[[MINOR]] +// CHECK-DAG: tensor.insert {{.*}}[%[[MAJOR_IDX]], %[[MINOR_IDX]]] +// CHECK-DAG: tensor.insert {{.*}}[%[[TID]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo new file mode 100644 index 00000000000000..ac5f26682ec356 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/tuple_nested.hlo @@ -0,0 +1,35 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +add { + scalar_lhs.0 = f32[] parameter(0) + scalar_lhs.1 = f32[] parameter(1) + scalar_rhs.0 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add = f32[] add(scalar_lhs.0, scalar_rhs.0) + mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add, mul) +} + +fused_computation { + param_0 = f32[3,4,5]{2,1,0} parameter(0) + param_1 = f32[3,4,5]{2,1,0} parameter(1) + param_2 = f32[] parameter(2) + param_3 = f32[4] parameter(3) + reduce = (f32[4], f32[4]) reduce(f32[3,4,5]{2,1,0} param_0, + f32[3,4,5]{2,1,0} %param_1, f32[] param_2, f32[] param_2), + dimensions={0,2}, to_apply=add + log = f32[4] log(param_3) + ROOT tuple = ((f32[4], f32[4]), f32[4]) tuple(reduce, log) +} + +// CHECK: @main +// CHECK: %[[R0:.*]], %[[R1:.*]], %[[R2:.*]] = xla_gpu.pure_call @fused_computation_tuple +// CHECK-DAG: tensor.insert %[[R0]] +// CHECK-DAG: tensor.insert %[[R1]] +// CHECK-DAG: tensor.insert %[[R2]] + +// CHECK: @fused_computation_tuple +// CHECK: %[[REDUCTION:.*]]:2 = scf.for +// CHECK: %[[LOG:.*]] = math.log +// CHECK: return %[[REDUCTION]]#0, %[[REDUCTION]]#1, %[[LOG]] \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/two_users.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/two_users.hlo new file mode 100644 index 00000000000000..b16b005897b3f7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/two_users.hlo @@ -0,0 +1,30 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +// We have two users of add and sub, but they use consistent indexing, so they +// can be generated as a single function (fused_computation_atan2). +%fused_computation { + %p0 = f32[2] parameter(0) + %p1 = f32[2] parameter(1) + %add = f32[2] add(%p0, %p1) + %sub = f32[2] subtract(%p0, %p1) + %mul = f32[2] multiply(%add, %sub) + %div = f32[2] divide(%add, %sub) + ROOT %atan2 = f32[2] atan2(%mul, %div) +} + +// CHECK: func.func @main +// CHECK-NEXT: gpu.thread_id +// CHECK-NEXT: pure_call @fused_computation_atan2 +// CHECK-NEXT: tensor.insert +// CHECK-NEXT: return + +// CHECK: func.func private @fused_computation_atan2 +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: addf +// CHECK-NEXT: subf +// CHECK-NEXT: mulf +// CHECK-NEXT: divf +// CHECK-NEXT: atan2 +// CHECK-NEXT: return \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo new file mode 100644 index 00000000000000..8ac83ced80af31 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/variadic_reduce.hlo @@ -0,0 +1,31 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +add { + scalar_lhs.0 = f32[] parameter(0) + scalar_lhs.1 = f32[] parameter(1) + scalar_rhs.0 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add = f32[] add(scalar_lhs.0, scalar_rhs.0) + mul = f32[] multiply(scalar_lhs.1, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add, mul) +} + +fused_computation { + param_0 = f32[3,4,5] parameter(0) + param_1 = f32[3,4,5] parameter(1) + c = f32[] constant(0) + ROOT d.1 = (f32[4], f32[4]) reduce(param_0, param_1, c, c), dimensions={0,2}, + to_apply=add +} + +// CHECK: func @main( +// CHECK: %[[TID_X:.*]] = gpu.thread_id x +// CHECK: %[[SCALARS_0:.*]], %[[SCALARS_1:.*]] = xla_gpu.pure_call @fused_computation_d_1 +// CHECK: %[[INSERTED_1:.*]] = tensor.insert %[[SCALARS_0]] into %{{.*}}[%[[TID_X]]] +// CHECK: %[[INSERTED_2:.*]] = tensor.insert %[[SCALARS_1]] into %{{.*}}[%[[TID_X]]] +// CHECK: return %[[INSERTED_1]], %[[INSERTED_2]] + +// CHECK: func private @fused_computation_d_1 +// CHECK: %[[RET:.*]]:2 = func.call @add_t +// CHECK: yield %[[RET]]#0, %[[RET]]#1 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo new file mode 100644 index 00000000000000..75510894bcadd2 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x1_too_small.hlo @@ -0,0 +1,10 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline -xla-gpu-test-vectorize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=neg:0 --bijection_outputs=neg + +neg { + %input = f32[20] parameter(0) + ROOT neg = f32[20] negate(%input) +} + +// CHECK-NOT: vector. +// CHECK: tensor.extract diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo new file mode 100644 index 00000000000000..549231c7aa4447 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/vectorize_x4.hlo @@ -0,0 +1,13 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline -xla-gpu-test-vectorize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=neg:0 --bijection_outputs=neg + +neg { + %input = f32[20,40,300] parameter(0) + ROOT neg = f32[20,40,300] negate(%input) +} + +// CHECK-NOT: tensor. +// CHECK: vector.transfer_read {{.*}} vector<4xf32> +// CHECK-NOT: tensor. +// CHECK: vector.transfer_write {{.*}} vector<4xf32> +// CHECK-NOT: tensor. \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo new file mode 100644 index 00000000000000..1646aded57fdf4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/mof_heteorgeneous_input_shapes.hlo @@ -0,0 +1,19 @@ +// RUN: test_correctness %s --bijection_inputs=reduce.1:0 --bijection_inputs=reduce.2:0 --bijection_outputs=reduce.1 --bijection_outputs=reduce.2 + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +fusion { + %param_0 = f32[64,128] parameter(0) + %constant_0 = f32[] constant(0) + %reduce.1 = f32[128] reduce(param_0, constant_0), dimensions={0}, + to_apply=%add + %neg = f32[64,128] negate(param_0) + %bitcast = f32[8,8,128] bitcast(neg) + %reduce.2 = f32[128] reduce(bitcast, constant_0), dimensions={0,1}, + to_apply=%add + ROOT %tuple = (f32[128], f32[128]) tuple(reduce.1, reduce.2) +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo new file mode 100644 index 00000000000000..e7ae070f2938e7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/odd_sizes.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[13,1051,321] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[13,321] reduce(param_0, param_1), dimensions={1}, to_apply=add +} + +// CHECK: xla_gpu.pure_call @add_add +// CHECK: allocate_shared +// CHECK: tensor.insert +// CHECK: sync_threads +// CHECK: predicated_extract +// CHECK: shuffle_reduce +// CHECK: predicated_insert diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo new file mode 100644 index 00000000000000..958b391179001f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/small.hlo @@ -0,0 +1,13 @@ +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[3,128,4] parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[3,4] reduce(param_0, c0), dimensions={1}, to_apply=add +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo new file mode 100644 index 00000000000000..a2a22363108b10 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_c64.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = c64[] parameter(0) + rhs = c64[] parameter(1) + ROOT add = c64[] add(lhs, rhs) +} + +fused_computation { + param_0 = c64[128,64] parameter(0) + c0 = c64[] constant((0, 0)) + ROOT reduce = c64[64] reduce(param_0, c0), dimensions={0}, + to_apply=add +} + +// CHECK-NOT: vector< \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo new file mode 100644 index 00000000000000..660664bba95f37 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x1_f64.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f64[] parameter(0) + rhs = f64[] parameter(1) + ROOT add = f64[] add(lhs, rhs) +} + +fused_computation { + param_0 = f64[128,64] parameter(0) + c0 = f64[] constant(0) + ROOT reduce = f64[64] reduce(param_0, c0), dimensions={0}, + to_apply=add +} + +// CHECK-NOT: vector< \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo new file mode 100644 index 00000000000000..a142ad4a164100 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x2_f32.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[2048,64] parameter(0) + c0 = f32[] constant(0) + ROOT reduce = f32[64] reduce(param_0, c0), dimensions={0}, + to_apply=add +} + +// CHECK: vector<2xf32> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo new file mode 100644 index 00000000000000..81da088974132f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_column/vectorize_x4_s16.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = s16[] parameter(0) + rhs = s16[] parameter(1) + ROOT add = s16[] add(lhs, rhs) +} + +fused_computation { + param_0 = s16[256,128] parameter(0) + c0 = s16[] constant(0) + ROOT reduce = s16[128] reduce(param_0, c0), dimensions={0}, + to_apply=add +} + +// CHECK: vector<4xi16> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo new file mode 100644 index 00000000000000..f8a9e86ff48f65 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/epilogue.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -xla-erase-dead-functions -inline -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[8,2048] parameter(0) + param_1 = f32[] parameter(1) + reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add + ROOT log = f32[8] log(reduce) +} + +// CHECK: shuffle_reduce +// CHECK: allocate_shared +// CHECK: sync_threads +// CHECK: shuffle_reduce +// CHECK-NEXT: %[[OUT:.*]] = math.log +// CHECK: predicated_insert %[[OUT]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo new file mode 100644 index 00000000000000..bc841743d9d3f8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/inefficient_codegen.hlo @@ -0,0 +1,43 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-to-inline | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce1:0 --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 --bijection_outputs=reduce2 + +add { + p0 = f64[] parameter(0) + p1 = f64[] parameter(1) + ROOT add = f64[] add(p0, p1) +} + +// This fusion is valid, but we can't efficiently codegen it. +fusion { + %p0 = f64[4] parameter(0) + %p1 = f64[4] parameter(1) + %c0 = f64[] constant(-inf) + %reduce0 = f64[] reduce(p1, c0), dimensions={0}, to_apply=add + %bc0 = f64[4] broadcast(reduce0), dimensions={} + %compare0 = pred[4] compare(p1, bc0), direction=EQ + %c1 = f64[] constant(0) + %bc1 = f64[4] broadcast(c1), dimensions={} + %select.3.1 = f64[4] select(compare0, p0, bc1) + %reduce1 = f64[] reduce(select.3.1, c1), dimensions={0}, to_apply=add + %convert0 = f64[4] convert(compare0) + %reduce2 = f64[] reduce(convert0, c1), dimensions={0}, to_apply=add + ROOT %tuple.1 = (f64[], f64[], f64[]) tuple(%reduce1, reduce0, reduce2) +} + +// We read all of %p1 once from each thread, and then read one element again. +// CHECK: func.func @main +// CHECK-SAME: , %[[P1:.*]]: tensor<4xf64> {xla.slice_index = 1 : index} +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CST0:.*]] = arith.constant 0xFFF0000000000000 +// CHECK-DAG: %[[TID_X:.*]] = gpu.thread_id x + +// reduce0 in the context of reduce2 and reduce1's prologue: +// CHECK: scf.for %[[I:.*]] = %[[C0]] +// CHECK-NEXT: tensor.extract %[[P1]][%[[I]]] +// CHECK-NEXT: addf +// CHECK-NEXT: yield + +// reduce0 again, in the context of its status as a fusion hero: +// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[P1]][%[[TID_X]]] +// CHECK: %[[ADDED:.*]] = arith.addf %[[CST0]], %[[EXTRACTED]] +// CHECK: shuffle_reduce @add_add(%[[ADDED]]) to 2 diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo new file mode 100644 index 00000000000000..ee155c86e2bb54 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/large_to_unit.hlo @@ -0,0 +1,15 @@ +// Regression test for a bug where not all threads in the warp produced a valid +// value for the final warp shuffle. +// RUN: test_correctness %s + +and { + p0 = pred[] parameter(0) + p1 = pred[] parameter(1) + ROOT and = pred[] and(p0, p1) +} + +fused_reduce { + c1 = pred[] constant(true) + p0 = pred[10000] broadcast(c1), dimensions={} + ROOT reduce = pred[] reduce(p0, c1), dimensions={0}, to_apply=and +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo new file mode 100644 index 00000000000000..102e32b861e648 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/layout.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +fusion { + %input = f32[17,19,127] parameter(0) + %c0 = f32[] constant(0) + // The output is physically transposed. + ROOT reduce = f32[17,19]{0,1} reduce(%input, %c0), dimensions={2}, to_apply=add +} + +// CHECK: xla_gpu.predicated_insert {{.*}} : tensor<17x19xf32, dense<[0, 1]> : tensor<2xi64>> diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo new file mode 100644 index 00000000000000..c9481f35bf7fe3 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/minor_and_major.hlo @@ -0,0 +1,20 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -inline -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[7,100,128] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={0,2}, to_apply=add +} + +// Our codegen doesn't support parallelizing the major reduction dimension. In +// principle, this could be done via shared memory. +// CHECK-NOT: allocate_shared +// CHECK: shuffle_reduce +// CHECK-NOT: allocate_shared diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo new file mode 100644 index 00000000000000..315d604b563ebe --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_epilogue.hlo @@ -0,0 +1,40 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -xla-erase-dead-functions -inline -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce1:0 --bijection_inputs=reduce2:0 --bijection_outputs=reduce1 --bijection_outputs=reduce2 + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} + +fused_computation { + param_0 = f32[8,1024] parameter(0) + c0 = f32[] constant(0) + c1 = f32[] constant(1) + reduce1 = f32[8] reduce(param_0, c0), dimensions={1}, to_apply=add + reduce2 = f32[8] reduce(param_0, c1), dimensions={1}, to_apply=mul + log = f32[8] log(reduce1) + abs = f32[8] abs(reduce1) + neg = f32[8] negate(reduce2) + ROOT tuple = (f32[8], f32[8], f32[8]) tuple(log, neg, abs) +} + +// CHECK-DAG: shuffle_reduce @add_add +// CHECK-DAG: shuffle_reduce @mul_mul +// CHECK: allocate_shared +// CHECK: allocate_shared +// CHECK: sync_threads +// CHECK-DAG: %[[ADDED:.*]] = xla_gpu.shuffle_reduce @add_add +// CHECK-DAG: %[[MULTIPLIED:.*]] = xla_gpu.shuffle_reduce @mul_mul +// CHECK-DAG: %[[LOG:.*]] = math.log %[[ADDED]] +// CHECK-DAG: %[[ABS:.*]] = math.absf %[[ADDED]] +// CHECK-DAG: %[[NEG:.*]] = arith.negf %[[MULTIPLIED]] +// CHECK-DAG: xla_gpu.predicated_insert %[[LOG]] +// CHECK-DAG: xla_gpu.predicated_insert %[[ABS]] +// CHECK-DAG: xla_gpu.predicated_insert %[[NEG]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo new file mode 100644 index 00000000000000..48a20334c7ea03 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_scalar_variadic.hlo @@ -0,0 +1,26 @@ +// RUN: test_correctness %s + +%reducer1 { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +%reducer2 { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + p2 = f32[] parameter(2) + p3 = f32[] parameter(3) + add0 = f32[] add(p0, p2) + add1 = f32[] add(p1, p3) + ROOT tuple = (f32[], f32[]) tuple(add0, add1) +} + +%fusion { + %p0 = f32[6,6] parameter(0) + %c0 = f32[] constant(0) + %neg = f32[6,6] negate(%p0) + %reduce1 = f32[] reduce(%neg, %c0), dimensions={0,1}, to_apply=%reducer1 + %reduce2 = (f32[], f32[]) reduce(%p0, %p0, %c0, %c0), dimensions={0,1}, to_apply=%reducer2 + ROOT %tuple = (f32[], (f32[], f32[]), f32[6,6]) tuple(%reduce1, %reduce2, %neg) +} \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo new file mode 100644 index 00000000000000..6d47fc6b842b9f --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/mof_variadic_variadic.hlo @@ -0,0 +1,26 @@ +// Regression test for a compilation crash with a MOF with two variadic +// reductions. +// RUN: test_correctness %s + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + p2 = f32[] parameter(2) + p3 = f32[] parameter(3) + a = f32[] add(p0, p2) + b = f32[] add(p1, p3) + ROOT out = (f32[], f32[]) tuple(a, b) +} + +fused_reduce { + p0 = f32[3,2] parameter(0) + p1 = f32[3,2] parameter(1) + c0 = f32[] constant(0) + iota0 = f32[3,2] iota(), iota_dimension=1 + iota1 = f32[3,2] iota(), iota_dimension=1 + reduce0 = (f32[3], f32[3]) reduce(p0, iota0, c0, c0), dimensions={1}, + to_apply=add + reduce1 = (f32[3], f32[3]) reduce(p1, iota1, c0, c0), dimensions={1}, + to_apply=add + ROOT tuple = ((f32[3], f32[3]), (f32[3], f32[3])) tuple(reduce0, %reduce1) +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo new file mode 100644 index 00000000000000..30202d0f2613b8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/non_power_of_two.hlo @@ -0,0 +1,31 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-simplify-arith -inline -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} +fused_computation { + param_0 = f32[100,568] parameter(0) + param_1 = f32[] parameter(1) + ROOT reduce = f32[100] reduce(param_0, param_1), dimensions={1}, to_apply=add +} + +// CHECK-DAG: #[[MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d1 mod 64) * 2 + s0 * 128 + d0), domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 3]> +// CHECK-DAG: #[[MAP2:.*]] = #xla_gpu.indexing_map<(d0, d1) -> ((d1 mod 64) * 2 + d0 + 512), domain: d0 in [0, 1], d1 in [0, 255]> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index + +// The full loop without bounds checks: +// CHECK: %[[FULL_TILES:.*]] = scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] +// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK-NOT: scf.if +// CHECK: xla_gpu.apply_indexing #[[MAP1]](%[[J]], %thread_id_x)[%[[I]]] + +// The tail loop: +// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%{{.*}} = %[[FULL_TILES]]) +// CHECK: scf.if +// CHECK: xla_gpu.apply_indexing #[[MAP2]](%[[J]], %thread_id_x) diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo new file mode 100644 index 00000000000000..a7e64151affdda --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/reduction_groups.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s + +%add_f32 { + %x = f32[] parameter(0) + %y = f32[] parameter(1) + ROOT %add = f32[] add(%x, %y) +} + +%fused_computation { + %param0 = f32[1024] parameter(0) + %param1 = f32[1024] parameter(1) + %constant0 = f32[] constant(0) + %reduce1 = f32[] reduce(%param0, %constant0), dimensions={0}, to_apply=%add_f32 + %reduce2 = f32[] reduce(%param1, %constant0), dimensions={0}, to_apply=%add_f32 + ROOT %tuple = (f32[], f32[]) tuple(%reduce1, %reduce2) +} + +// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y +// CHECK: scf.index_switch %[[BLOCK_ID_Y]] +// CHECK: case 1 { +// CHECK: default { diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo new file mode 100644 index 00000000000000..e950e3cbdf8d83 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output.hlo @@ -0,0 +1,24 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce --bijection_outputs=exp + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +fused_computation { + param_0 = f32[8,2048] parameter(0) + param_1 = f32[] parameter(1) + exp = f32[8,2048] exponential(param_0) + reduce = f32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add + ROOT t = (f32[8], f32[8,2048]) tuple(reduce, exp) +} + +// CHECK: @fused_computation +// CHECK: scf.for +// CHECK: scf.for +// CHECK: %[[SIDE_OUTPUT:.*]] = xla_gpu.pure_call @fused_computation_exp +// CHECK-NEXT: tensor.insert %[[SIDE_OUTPUT]] +// CHECK: scf.yield +// CHECK: scf.yield diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo new file mode 100644 index 00000000000000..0db1901a532501 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_broadcast.hlo @@ -0,0 +1,15 @@ +// RUN: test_correctness %s + +%add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +%fusion { + %p0 = f32[6,6] parameter(0) + %c0 = f32[] constant(0) + %reduce = f32[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add + %broadcast = f32[6,6] broadcast(%reduce), dimensions={} + ROOT %tuple = (f32[6,6], f32[]) tuple(%broadcast, %reduce) +} \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo new file mode 100644 index 00000000000000..5371b80532bad7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/side_output_unsigned.hlo @@ -0,0 +1,15 @@ +// RUN: test_correctness %s + +add { + lhs = u32[] parameter(0) + rhs = u32[] parameter(1) + ROOT add = u32[] add(lhs, rhs) +} + +fused_computation { + param_0 = u32[8,2048] parameter(0) + param_1 = u32[] parameter(1) + add = u32[8,2048] add(param_0, param_0) + reduce = u32[8] reduce(param_0, param_1), dimensions={1}, to_apply=add + ROOT t = (u32[8], u32[8,2048]) tuple(reduce, add) +} diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo new file mode 100644 index 00000000000000..56e326608a0826 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/small_f64.hlo @@ -0,0 +1,17 @@ +// RUN: fusion_to_mlir %s | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0 --bijection_outputs=reduce + +add { + lhs = f64[] parameter(0) + rhs = f64[] parameter(1) + ROOT add = f64[] add(lhs, rhs) +} + +fused_computation { + param_0 = f64[100,128] parameter(0) + param_1 = f64[] parameter(1) + ROOT reduce = f64[100] reduce(param_0, param_1), dimensions={1}, to_apply=add +} + +// This reduction is small enough to not require any shared memory. +// CHECK-NOT: allocate_shared diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo new file mode 100644 index 00000000000000..b28bff49d8245d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_row/variadic_f32.hlo @@ -0,0 +1,23 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -cse -xla-gpu-simplify-arith -canonicalize | FileCheck %s +// RUN: test_correctness %s --bijection_inputs=reduce:0,1 --bijection_outputs=reduce + +add { + scalar_lhs.0 = f32[] parameter(0) + scalar_rhs.0 = f32[] parameter(1) + scalar_lhs.1 = f32[] parameter(2) + scalar_rhs.1 = f32[] parameter(3) + add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) + add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) + ROOT t = (f32[], f32[]) tuple(add.0, add.1) +} + +fused_computation { + param_0 = f32[2, 3, 2048] parameter(0) + param_1 = f32[2, 3, 2048] parameter(1) + c0 = f32[] constant(0) + ROOT reduce = (f32[2, 3], f32[2, 3]) + reduce(param_0, param_1, c0, c0), dimensions={2}, to_apply=add +} + +// CHECK: allocate_shared +// CHECK: allocate_shared diff --git a/third_party/xla/xla/service/gpu/fusions/tools/BUILD b/third_party/xla/xla/service/gpu/fusions/tools/BUILD new file mode 100644 index 00000000000000..2886ad1f7578bf --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/BUILD @@ -0,0 +1,113 @@ +load("//xla:xla.bzl", "xla_cc_binary") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +xla_cc_binary( + name = "mlir_fusions_opt", + srcs = ["mlir_fusions_opt.cc"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], + deps = [ + "//xla/mlir_hlo", + "//xla/service/gpu/fusions/ir:xla_gpu", + "//xla/service/gpu/fusions/transforms:passes", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorDialect", + ], +) + +cc_library( + name = "test_lib", + testonly = 1, + srcs = ["test_lib.cc"], + hdrs = ["test_lib.h"], + deps = [ + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/fusions", + "//xla/service/gpu/fusions/ir:xla_gpu", + "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", + "//xla/stream_executor:device_description", + "//xla/tools:hlo_module_loader", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:DLTIDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:VectorDialect", + ], +) + +xla_cc_binary( + name = "fusion_to_mlir", + testonly = 1, + srcs = ["fusion_to_mlir.cc"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], + deps = [ + ":test_lib", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_binary( + name = "test_correctness", + testonly = 1, + srcs = ["test_correctness.cc"], + visibility = ["//xla/service/gpu/fusions:__subpackages__"], + deps = [ + ":test_lib", + "//xla:debug_options_flags", + "//xla:error_spec", + "//xla:shape_util", + "//xla/service:gpu_plugin", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:indexing_analysis", + "//xla/service/gpu/model:indexing_test_utils", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc b/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc new file mode 100644 index 00000000000000..9fe41b6cb97a5b --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/fusion_to_mlir.cc @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "llvm/Support/raw_ostream.h" +#include "xla/service/gpu/fusions/tools/test_lib.h" +#include "tsl/platform/init_main.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +absl::Status Run(const std::string& filename) { + TF_ASSIGN_OR_RETURN(auto module, LoadTestModule(filename)); + TF_ASSIGN_OR_RETURN(auto emitter_data, GetMlirFusionEmitter(*module)); + + auto context = GetMlirContextForTest(); + TF_ASSIGN_OR_RETURN(auto mlir_module, + emitter_data->emitter->CreateMLIRModule( + context, *emitter_data->fusion, "main", + /*buffer_assignment=*/nullptr)); + llvm::outs() << *mlir_module; + return absl::OkStatus(); +} + +} // namespace gpu +} // namespace xla + +int main(int argc, char** argv) { + tsl::port::InitMain(argv[0], &argc, &argv); + CHECK_EQ(argc, 2) << "Must specify an input file"; + CHECK_OK(xla::gpu::Run(argv[1])); + return 0; +} diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc similarity index 51% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc rename to third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc index 8e9fb47eef69c0..780ede0d6d061c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/mlir_fusions_opt.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/STLFunctionalExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -26,14 +29,15 @@ limitations under the License. #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" #include "mlir/Transforms/Passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/transforms/passes.h" -int main(int argc, char **argv) { +int main(int argc, char** argv) { mlir::DialectRegistry registry; registry.insert + errorHandler) { + if (!options.empty()) return mlir::failure(); + + pm.addNestedPass( + xla::gpu::CreateSimplifyArithPass()); + pm.addPass(xla::gpu::CreateEraseDeadFunctionsPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) { + pm.addPass(mlir::createCSEPass()); + })); + return mlir::success(); + }, + [](llvm::function_ref) {}); + mlir::registerPassPipeline( + "xla-gpu-test-vectorize", + "Test pipeline for vectorization. Should run after " + "xla-gpu-test-to-inline.", + [=](mlir::OpPassManager& pm, llvm::StringRef options, + llvm::function_ref + errorHandler) { + if (!options.empty()) return mlir::failure(); + pm.addNestedPass( + xla::gpu::CreateLowerXlaGpuLoopsToScfPass()); + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addNestedPass( + xla::gpu::CreateUnswitchLoopsPass()); + pm.addPass(mlir::createLoopInvariantCodeMotionPass()); + pm.addPass(xla::gpu::CreateFlattenTensorsPass()); + pm.addNestedPass( + xla::gpu::CreateVectorizeLoadsAndStoresPass()); + return mlir::success(); + }, + [](llvm::function_ref) {}); return mlir::failed( MlirOptMain(argc, argv, "XLA MLIR Fusion Pass Driver\n", registry)); diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc new file mode 100644 index 00000000000000..72529cd6545c4d --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_correctness.cc @@ -0,0 +1,192 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" +#include "xla/debug_options_flags.h" +#include "xla/error_spec.h" +#include "xla/service/gpu/fusions/tools/test_lib.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/service/gpu/model/indexing_map.h" +#include "xla/service/gpu/model/indexing_test_utils.h" +#include "xla/shape.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +struct Flags { + std::string input_file = ""; + float abs_error_bound = 1e-4; + float rel_error_bound = 1e-4; + std::vector>> bijection_inputs; + std::vector bijection_outputs; +}; + +Flags& flags = *new Flags; + +namespace xla { +namespace gpu { +namespace { + +using CorrectnessTest = HloTestBase; + +const Shape& GetFirstArrayShape(const Shape& shape) { + if (shape.IsArray()) { + return shape; + } + CHECK(shape.IsTuple()); + return GetFirstArrayShape(shape.tuple_shapes(0)); +} + +absl::Status TestBijection(const IndexingMap& map, + absl::Span shape) { + std::vector intervals; + for (int64_t size : shape) { + intervals.push_back({0, size - 1}); + } + auto status = VerifyBijection(map, intervals); + if (status.ok()) return status; + return absl::FailedPreconditionError( + absl::StrCat(status.message(), " in map ", map.ToString())); +} + +TEST_F(CorrectnessTest, RunAndCompare) { + TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file)); + EXPECT_TRUE(RunAndCompareNoHloPasses( + std::move(module), + ErrorSpec{flags.abs_error_bound, flags.rel_error_bound})); +} + +absl::StatusOr GetHeroIndex(absl::string_view name, + const HloFusionAnalysis& analysis) { + for (auto [index, hero] : llvm::enumerate(analysis.fusion_heroes())) { + if (hero.name() == name) { + return index; + } + } + return absl::NotFoundError(absl::StrCat("Hero ", name, " not found")); +} + +std::pair> ParseHeroAndIds( + absl::string_view hero_and_ids) { + std::pair hero_and_ids_pair = + absl::StrSplit(hero_and_ids, ':'); + std::vector ids; + for (absl::string_view id : absl::StrSplit(hero_and_ids_pair.second, ',')) { + ids.push_back(std::stoi(std::string(absl::StripAsciiWhitespace(id)))); + } + return {std::string(absl::StripAsciiWhitespace(hero_and_ids_pair.first)), + ids}; +} + +TEST_F(CorrectnessTest, InputIndexingIsBijection) { + auto context = GetMlirContextForTest(); + TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file)); + TF_ASSERT_OK_AND_ASSIGN(auto emitter_data, GetMlirFusionEmitter(*module)); + for (const auto& [hero_name, ids] : flags.bijection_inputs) { + TF_ASSERT_OK_AND_ASSIGN(int64_t hero_index, + GetHeroIndex(hero_name, *emitter_data->analysis)); + for (int64_t id : ids) { + auto indexing = emitter_data->emitter->ComputeThreadIdToInputIndexing( + hero_index, id, &context); + ASSERT_TRUE(indexing.has_value()); + TF_ASSERT_OK(TestBijection(*indexing, + emitter_data->analysis->fusion_hero(hero_index) + .GetOperand(id) + .shape() + .dimensions())) + << "Expected operand " << id << " of " << hero_name << " (root index " + << hero_index << ") to be read exactly once."; + } + } +} + +TEST_F(CorrectnessTest, OutputIndexingIsBijection) { + auto context = GetMlirContextForTest(); + TF_ASSERT_OK_AND_ASSIGN(auto module, LoadTestModule(flags.input_file)); + TF_ASSERT_OK_AND_ASSIGN(auto emitter_data, GetMlirFusionEmitter(*module)); + for (const auto& hero_name : flags.bijection_outputs) { + TF_ASSERT_OK_AND_ASSIGN(int64_t hero_index, + GetHeroIndex(hero_name, *emitter_data->analysis)); + auto indexing = emitter_data->emitter->ComputeThreadIdToOutputIndexing( + hero_index, &context); + ASSERT_TRUE(indexing.has_value()); + TF_ASSERT_OK(TestBijection( + *indexing, GetFirstArrayShape( + emitter_data->analysis->fusion_root(hero_index).shape()) + .dimensions())) + << "Expected output of " << hero_name << " (root index " << hero_index + << ") to be written exactly once."; + } +} + +} // namespace +} // namespace gpu +} // namespace xla + +int main(int argc, char* argv[]) { + std::vector flag_list = { + tsl::Flag("abs_error_bound", &flags.abs_error_bound, + "Absolute error bound."), + tsl::Flag("rel_error_bound", &flags.rel_error_bound, + "Relative error bound."), + tsl::Flag( + "bijection_inputs", + [](std::string name_and_ids) { + if (name_and_ids.empty()) return false; + flags.bijection_inputs.push_back( + xla::gpu::ParseHeroAndIds(name_and_ids)); + return true; + }, + "", + "The name of a hero followed by operand ids that should be read " + "exactly once, i.e. there's a bijection between a subset of threads " + "and the input shape. Example: 'reduction0: 0, 1'."), + tsl::Flag( + "bijection_outputs", + [](std::string name) { + if (name.empty()) return false; + flags.bijection_outputs.push_back(name); + return true; + }, + "", + "The name of a hero whose outputs should be written exactly once, " + "i.e. there's a bijection between a subset of threads and the output " + "shape.")}; + + xla::AppendDebugOptionsFlags(&flag_list); + std::string usage = tsl::Flags::Usage(argv[0], flag_list); + bool parseResult = tsl::Flags::Parse(&argc, argv, flag_list); + if (!parseResult || argc != 2) { + LOG(ERROR) << "\n" << usage; + return 1; + } + + flags.input_file = argv[1]; + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc new file mode 100644 index 00000000000000..11b82ddd517072 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc @@ -0,0 +1,118 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "xla/service/gpu/fusions/tools/test_lib.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/service/gpu/fusions/fusions.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/status_macros.h" +#include "xla/tools/hlo_module_loader.h" + +namespace xla { +namespace gpu { + +absl::StatusOr> LoadTestModule( + absl::string_view filename) { + auto module = *xla::LoadModuleFromFile(std::string(filename)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_mlir_emitter_level(4); + + int num_fusions = absl::c_count_if( + module->entry_computation()->instructions(), + [](const HloInstruction* instruction) { + return instruction->opcode() == xla::HloOpcode::kFusion; + }); + TF_RET_CHECK(num_fusions <= 1) << "HLO must contain at most one fusion"; + + if (num_fusions == 0) { + // Generate a fusion from the entry computation. + HloComputation::Builder builder("generated_main"); + std::vector params; + for (const auto* param : + module->entry_computation()->parameter_instructions()) { + params.push_back(*builder.AddParameter(param->Clone(/*suffix=*/""))); + } + builder.AddInstruction(HloInstruction::CreateFusion( + module->entry_computation()->root_instruction()->shape(), + HloInstruction::FusionKind::kLoop /* irrelevant */, params, + module->entry_computation())); + + auto* new_entry = module->AddComputationAndUnifyNamesAndIds( + builder.Build(), /*is_entry=*/false); + module->ReplaceEntryComputation(new_entry); + } + + return module; +} + +absl::StatusOr> GetMlirFusionEmitter( + const HloModule& module) { + auto data = std::make_unique(); + data->fusion = DynCast( + module.entry_computation()->root_instruction()); + TF_RET_CHECK(data->fusion != nullptr) << "Root instruction must be a fusion"; + data->device.emplace(TestGpuDeviceInfo::RTXA6000DeviceInfo()); + data->analysis.emplace( + HloFusionAnalysis::Create(*data->fusion, data->device.value())); + PreBufferAssignmentFusionInfo info(data->analysis.value()); + auto emitter = GetFusionEmitter(info); + + auto mlir_emitter = dynamic_cast(emitter.get()); + TF_RET_CHECK(mlir_emitter != nullptr) + << "Expected emitter to be an MlirFusionEmitter"; + + emitter.release(); + data->emitter.reset(mlir_emitter); + return data; +} + +mlir::MLIRContext GetMlirContextForTest() { + mlir::DialectRegistry registry; + registry.insert(); + return mlir::MLIRContext(registry); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h new file mode 100644 index 00000000000000..5dfa3009f71c40 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_ +#define XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" +#include "xla/stream_executor/device_description.h" + +namespace xla { + +namespace gpu { + +// Loads a test module from the given filename, ensuring it has a single fusion. +// If the file contains more than one fusion, the function fails. If the file +// contains no fusions, the function generates a fusion from the entry +// computation. +absl::StatusOr> LoadTestModule( + absl::string_view filename); + +// Returns the MLIR fusion emitter for the given module, which should have been +// loaded using LoadTestModule. +struct EmitterData { + HloFusionInstruction* fusion; + std::optional device; + std::optional analysis; + std::unique_ptr emitter; +}; +absl::StatusOr> GetMlirFusionEmitter( + const HloModule& module); + +// Returns an MLIR context with all the dialects needed for testing. +mlir::MLIRContext GetMlirContextForTest(); + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_TOOLS_TEST_LIB_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD new file mode 100644 index 00000000000000..24fb1963afccaa --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD @@ -0,0 +1,106 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [":friends"], + licenses = ["notice"], +) + +package_group( + name = "friends", + includes = [ + "//xla:friends", + ], +) + +gentbl_cc_library( + name = "passes_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=GpuFusionTransforms", + ], + "passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + visibility = ["//visibility:private"], + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "passes", + srcs = [ + "convert_xla_gpu_pure_call_ops.cc", + "erase_dead_functions.cc", + "expand_float_ops.cc", + "flatten_tensors.cc", + "lower_tensors.cc", + "lower_to_llvm.cc", + "lower_xla_gpu_to_scf.cc", + "merge_pointers_to_same_slice.cc", + "optimize_loops.cc", + "peel_loops.cc", + "propagate_slice_indices.cc", + "simplify_affine.cc", + "simplify_arith.cc", + "unswitch_loops.cc", + "vectorize_loads_stores.cc", + ], + hdrs = ["passes.h"], + deps = [ + ":passes_inc_gen", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/mlir_hlo", + "//xla/mlir_hlo:map_mhlo_to_scalar_op", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/ir:xla_gpu", + "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + "//xla/service/gpu/model:indexing_analysis", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AffineDialect", + "@llvm-project//mlir:AffineToStandard", + "@llvm-project//mlir:AffineUtils", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:ArithToLLVM", + "@llvm-project//mlir:ArithTransforms", + "@llvm-project//mlir:CallOpInterfaces", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:ComplexToLLVM", + "@llvm-project//mlir:ControlFlowToLLVM", + "@llvm-project//mlir:DataLayoutInterfaces", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncToLLVM", + "@llvm-project//mlir:GPUDialect", + "@llvm-project//mlir:GPUToNVVMTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMCommonConversion", + "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:MathDialect", + "@llvm-project//mlir:MathToLLVM", + "@llvm-project//mlir:MathTransforms", + "@llvm-project//mlir:NVVMDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:SCFToControlFlow", + "@llvm-project//mlir:SCFUtils", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:VectorDialect", + "@llvm-project//mlir:VectorToLLVM", + "@llvm-project//mlir:VectorTransforms", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc similarity index 94% rename from third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc index bb1270e98495c9..0c9053a5570654 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/convert_xla_gpu_pure_call_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/convert_xla_gpu_pure_call_ops.cc @@ -17,14 +17,14 @@ limitations under the License. #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { namespace { #define GEN_PASS_DEF_CONVERTPURECALLOPSPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" struct RewriteCall : mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc b/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc index 012201a76fe9c9..3918a191fee3cb 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/erase_dead_functions.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/erase_dead_functions.cc @@ -21,13 +21,13 @@ limitations under the License. #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_ERASEDEADFUNCTIONSPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc index 001df2cc4bff91..66ff74413ef25c 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/expand_float_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/expand_float_ops.cc @@ -34,13 +34,14 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h" -#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/xla_data.pb.h" namespace xla { @@ -52,7 +53,7 @@ using ma::SelectOp; using mlir::Value; #define GEN_PASS_DEF_EXPANDFLOATOPSPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc similarity index 62% rename from third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc index 99a7ecb7c57113..c854507003c44f 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/flatten_tensors.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/flatten_tensors.cc @@ -21,6 +21,9 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/Utils/Utils.h" @@ -43,16 +46,17 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/layout_util.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/model/indexing_analysis.h" #include "xla/shape_util.h" +#include "xla/xla_data.pb.h" namespace xla { namespace gpu { namespace { #define GEN_PASS_DEF_FLATTENTENSORSPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" using mlir::Location; using mlir::LogicalResult; @@ -71,6 +75,7 @@ using mlir::func::FuncOp; using mlir::func::ReturnOp; using mlir::scf::ForOp; using mlir::scf::IfOp; +using mlir::scf::IndexSwitchOp; using mlir::tensor::ExtractOp; using mlir::tensor::InsertOp; @@ -79,12 +84,25 @@ RankedTensorType GetFlattenedType(RankedTensorType tensor_type) { tensor_type.getElementType()); } +bool IsScalarOrFlat(Type type) { + auto tensor_type = mlir::dyn_cast(type); + if (!tensor_type) return true; + return tensor_type.getRank() < 2; +} + bool HasOnlyFlatTensorsOrScalars(TypeRange types) { - return llvm::all_of(types, [](Type ty) { - auto tensor_type = mlir::dyn_cast(ty); - if (!tensor_type) return true; - return tensor_type.getRank() < 2; - }); + return llvm::all_of(types, IsScalarOrFlat); +} + +Value Flatten(Value value, PatternRewriter& rewriter) { + auto tensor_type = mlir::dyn_cast(value.getType()); + if (!tensor_type || tensor_type.getRank() < 2) { + return value; + } + auto flat_type = GetFlattenedType(tensor_type); + return rewriter + .create(value.getLoc(), flat_type, value) + .getResult(0); } struct RewriteFunctionSignatures : OpRewritePattern { @@ -109,20 +127,9 @@ struct RewriteFunctionSignatures : OpRewritePattern { rewriter.setInsertionPoint(terminator); for (Value result : terminator->getOperands()) { - auto tensor_type = mlir::dyn_cast(result.getType()); - if (!tensor_type) { - new_result_types.push_back(result.getType()); - new_results.push_back(result); - continue; - } - auto new_result_type = GetFlattenedType(tensor_type); - new_result_types.push_back(new_result_type); - - Value result_1d = - rewriter - .create(loc, new_result_type, result) - .getResult(0); - new_results.push_back(result_1d); + Value flattened = Flatten(result, rewriter); + new_results.push_back(flattened); + new_result_types.push_back(flattened.getType()); } rewriter.replaceOpWithNewOp(terminator, new_results); @@ -130,16 +137,14 @@ struct RewriteFunctionSignatures : OpRewritePattern { SmallVector new_operand_types(input_types); rewriter.setInsertionPointToStart(entry_block); for (auto&& [index, operand_type] : llvm::enumerate(new_operand_types)) { - if (auto tensor_type = mlir::dyn_cast(operand_type)) { - if (tensor_type.getRank() > 1) { - mlir::BlockArgument func_argument = op.getArgument(index); - auto cast_to_orig_type = rewriter.create( - loc, operand_type, func_argument); - func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0), - cast_to_orig_type); - operand_type = GetFlattenedType(tensor_type); - } - } + if (IsScalarOrFlat(operand_type)) continue; + mlir::BlockArgument func_argument = op.getArgument(index); + auto cast_to_orig_type = rewriter.create( + loc, operand_type, func_argument); + func_argument.replaceAllUsesExcept(cast_to_orig_type.getResult(0), + cast_to_orig_type); + operand_type = + GetFlattenedType(mlir::cast(operand_type)); } // Replace the function arguments with the new types. for (auto [arg, arg_type] : @@ -152,6 +157,51 @@ struct RewriteFunctionSignatures : OpRewritePattern { } }; +struct RewritePureCall : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PureCallOp op, + PatternRewriter& rewriter) const override { + if (HasOnlyFlatTensorsOrScalars(op.getOperandTypes()) && + HasOnlyFlatTensorsOrScalars(op.getResultTypes())) { + return rewriter.notifyMatchFailure(op, "nothing to flatten"); + } + SmallVector flat_operands; + flat_operands.reserve(op.getNumOperands()); + for (Value operand : op.getOperands()) { + flat_operands.push_back(Flatten(operand, rewriter)); + } + SmallVector flat_result_types; + flat_result_types.reserve(op.getNumResults()); + llvm::SmallBitVector results_to_update(op.getNumResults(), false); + for (auto [index, result_type] : llvm::enumerate(op.getResultTypes())) { + if (IsScalarOrFlat(result_type)) { + flat_result_types.push_back(result_type); + continue; + } + results_to_update.set(index); + flat_result_types.push_back( + GetFlattenedType(mlir::cast(result_type))); + } + Location loc = op.getLoc(); + auto new_call_op = rewriter.create( + loc, flat_result_types, op.getCalleeAttr(), flat_operands); + SmallVector new_results; + new_results.reserve(op.getNumResults()); + for (auto [index, new_result] : llvm::enumerate(new_call_op.getResults())) { + if (results_to_update.test(index)) { + new_results.push_back(new_result); + continue; + } + auto cast_to_orig_type = rewriter.create( + loc, op.getResult(index).getType(), new_result); + new_results.push_back(cast_to_orig_type.getResult(0)); + } + rewriter.replaceOp(op, new_results); + return mlir::success(); + } +}; + // Returns the linearized index, if the rank is greater than 1. Otherwise, // returns nullptr. Value LinearizeIndex(TypedValue tensor, @@ -174,6 +224,43 @@ Value LinearizeIndex(TypedValue tensor, return result.front(); } +struct RewriteAllocateShared : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocateSharedOp op, + PatternRewriter& rewriter) const override { + auto tensor_type = op.getResult().getType(); + if (IsScalarOrFlat(tensor_type)) { + return rewriter.notifyMatchFailure(op, "the tensor is already flat"); + } + auto flat_type = GetFlattenedType(tensor_type); + Location loc = op.getLoc(); + Value new_op = rewriter.create(op.getLoc(), flat_type); + auto cast_to_orig_type = + rewriter.create(loc, tensor_type, new_op); + rewriter.replaceOp(op, cast_to_orig_type.getResult(0)); + return mlir::success(); + } +}; + +struct RewriteTensorConstant : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::arith::ConstantOp op, + PatternRewriter& rewriter) const override { + if (IsScalarOrFlat(op.getType())) { + return rewriter.notifyMatchFailure(op, "the tensor is already flat"); + } + auto tensor_type = mlir::cast(op.getType()); + auto dense_attr = mlir::dyn_cast(op.getValue()); + Value new_constant = rewriter.create( + op.getLoc(), dense_attr.reshape(GetFlattenedType(tensor_type))); + rewriter.replaceOpWithNewOp(op, tensor_type, + new_constant); + return mlir::success(); + } +}; + struct RewriteTensorExtract : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -262,7 +349,7 @@ std::optional GetDelinearizedTensor(Value value) { return cast->getOperand(0); } -struct RewriteForOp : public OpRewritePattern { +struct RewriteFor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ForOp op, @@ -337,7 +424,7 @@ struct RewriteForOp : public OpRewritePattern { } }; -struct RewriteIfOp : public OpRewritePattern { +struct RewriteIf : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IfOp op, @@ -405,6 +492,113 @@ struct RewriteIfOp : public OpRewritePattern { } }; +struct RewriteIndexSwitch : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IndexSwitchOp op, + PatternRewriter& rewriter) const override { + auto result_types = op.getResultTypes(); + if (HasOnlyFlatTensorsOrScalars(result_types)) { + return rewriter.notifyMatchFailure(op, "nothing to flatten"); + } + auto default_yield = + mlir::cast(op.getDefaultBlock().getTerminator()); + SmallVector new_result_types; + new_result_types.reserve(default_yield.getNumOperands()); + bool found_cast = false; + for (auto& result : default_yield->getOpOperands()) { + auto delinearized_tensor = GetDelinearizedTensor(result.get()); + if (!delinearized_tensor.has_value()) { + new_result_types.push_back(result.get().getType()); + continue; + } + new_result_types.push_back(delinearized_tensor->getType()); + result.set(*delinearized_tensor); + found_cast = true; + } + if (!found_cast) { + return rewriter.notifyMatchFailure(op, "no cast found"); + } + Location loc = op.getLoc(); + // Update the "case" regions. + for (auto& case_region : op.getCaseRegions()) { + auto yield = mlir::cast( + case_region.getBlocks().front().getTerminator()); + mlir::OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(yield); + for (auto&& [result, type] : + llvm::zip(yield->getOpOperands(), new_result_types)) { + if (result.get().getType() == type) continue; + result.set( + rewriter.create(loc, type, result.get()) + .getResult(0)); + } + } + // Create new IndexSwitchOp and move the old op's regions to the new one. + auto new_index_switch = rewriter.create( + loc, new_result_types, op.getArg(), op.getCases(), op.getNumCases()); + for (auto&& [old_region, new_region] : + llvm::zip(op.getRegions(), new_index_switch.getRegions())) { + rewriter.inlineRegionBefore(*old_region, *new_region, new_region->end()); + } + // Update the results. + rewriter.setInsertionPointAfter(new_index_switch); + SmallVector new_results(new_index_switch.getResults()); + for (auto&& [index, result] : llvm::enumerate(new_results)) { + Type old_type = op->getResult(index).getType(); + if (result.getType() == old_type) continue; + result = + rewriter.create(loc, old_type, result) + .getResult(0); + } + rewriter.replaceOp(op, new_results); + return mlir::success(); + } +}; + +struct RewriteSyncThreads : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SyncThreadsOp op, + PatternRewriter& rewriter) const override { + auto types = op.getResultTypes(); + if (HasOnlyFlatTensorsOrScalars(types)) { + return rewriter.notifyMatchFailure(op, "nothing to flatten"); + } + + auto loc = op.getLoc(); + + SmallVector new_operands; + new_operands.reserve(op.getNumOperands()); + llvm::SmallBitVector results_to_update(op.getNumResults(), false); + for (auto& operand : op->getOpOperands()) { + auto tensor_type = mlir::cast(operand.get().getType()); + if (tensor_type.getRank() < 2) continue; + results_to_update.set(operand.getOperandNumber()); + new_operands.push_back( + rewriter + .create( + loc, GetFlattenedType(tensor_type), operand.get()) + .getResult(0)); + } + auto new_op = rewriter.create(loc, TypeRange(new_operands), + new_operands); + SmallVector new_results; + new_results.reserve(op.getNumResults()); + for (auto [index, result] : llvm::enumerate(new_op.getResults())) { + if (!results_to_update.test(index)) { + new_results.push_back(result); + continue; + } + auto cast_to_orig_type = rewriter.create( + loc, result.getType(), result); + new_results.push_back(cast_to_orig_type.getResult(0)); + } + rewriter.replaceOp(op, new_results); + return mlir::success(); + } +}; + class FlattenTensorsPass : public impl::FlattenTensorsPassBase { public: @@ -414,10 +608,15 @@ class FlattenTensorsPass mlir::RewritePatternSet patterns(mlir_context); // clang-format off patterns.add< + RewriteAllocateShared, RewriteAtomicRMW, - RewriteForOp, + RewriteFor, RewriteFunctionSignatures, - RewriteIfOp, + RewriteIf, + RewriteIndexSwitch, + RewritePureCall, + RewriteSyncThreads, + RewriteTensorConstant, RewriteTensorExtract, RewriteTensorInsert >(mlir_context); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc similarity index 93% rename from third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc index 929ee0ee5744d8..63e5c75f56c03a 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_tensors.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_tensors.cc @@ -24,7 +24,6 @@ limitations under the License. #include "absl/log/check.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/LogicalResult.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -57,10 +56,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/layout_util.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/shape_util.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" @@ -70,7 +66,7 @@ namespace { #define GEN_PASS_DECL_LOWERTENSORSPASS #define GEN_PASS_DEF_LOWERTENSORSPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" using mlir::failure; using mlir::Location; @@ -172,27 +168,14 @@ struct RewriteFunctionSignatures : mlir::OpRewritePattern { } }; -Value GetLinearIndex(TypedValue tensor, - ValueRange indices, mlir::PatternRewriter& rewriter) { - auto byte_shape = ShapeUtil::MakeShape(U8, tensor.getType().getShape()); - if (auto encoding = tensor.getType().getEncoding()) { - *byte_shape.mutable_layout() = LayoutUtil::MakeLayout(llvm::to_vector( - mlir::cast(encoding).getValues())); - } - auto linear_shape = - ShapeUtil::MakeShape(U8, {ShapeUtil::ElementsIn(byte_shape)}); - auto linearize_map = - GetBitcastMap(byte_shape, linear_shape, tensor.getContext()); - mlir::SmallVector result; - rewriter.createOrFold(result, tensor.getLoc(), indices, - ValueRange{}, linearize_map); - CHECK_EQ(result.size(), 1); - auto index = result.front(); - auto index_ty = rewriter.getIntegerType( - mlir::DataLayout::closest(rewriter.getInsertionBlock()->getParentOp()) +Value GetLinearIndex(ValueRange indices, mlir::ImplicitLocOpBuilder& b) { + CHECK_LE(indices.size(), 1) << "Only 0D and 1D tensors are supported"; + auto index = indices.empty() ? b.create(0) + : indices.front(); + auto index_ty = b.getIntegerType( + mlir::DataLayout::closest(b.getInsertionBlock()->getParentOp()) .getTypeSizeInBits(index.getType())); - return rewriter.create(tensor.getLoc(), index_ty, - index); + return b.create(index_ty, index); } std::tuple GetI4IndexAndNibble(Value linear_index, @@ -206,28 +189,25 @@ std::tuple GetI4IndexAndNibble(Value linear_index, } mlir::LLVM::GEPOp CreateGep(TypedValue tensor, - Value linear_index, mlir::PatternRewriter& rewriter, + Value linear_index, mlir::ImplicitLocOpBuilder& b, Type element_type = nullptr) { if (!element_type) { element_type = tensor.getType().getElementType(); } - auto ptr = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); - auto tensor_ptr = rewriter - .create( - tensor.getLoc(), ptr, tensor) - .getResult(0); - mlir::LLVMTypeConverter converter(rewriter.getContext()); + auto ptr = mlir::LLVM::LLVMPointerType::get(b.getContext()); + auto tensor_ptr = + b.create(ptr, tensor).getResult(0); + mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_element_type = converter.convertType(element_type); - auto gep = rewriter.create( - tensor.getLoc(), ptr, llvm_element_type, tensor_ptr, linear_index); + auto gep = b.create(ptr, llvm_element_type, tensor_ptr, + linear_index); gep.setInbounds(true); return gep; } mlir::LLVM::GEPOp CreateGep(TypedValue tensor, - ValueRange indices, - mlir::PatternRewriter& rewriter) { - return CreateGep(tensor, GetLinearIndex(tensor, indices, rewriter), rewriter); + ValueRange indices, mlir::ImplicitLocOpBuilder& b) { + return CreateGep(tensor, GetLinearIndex(indices, b), b); } struct RewriteTensorExtract : mlir::OpRewritePattern { @@ -237,8 +217,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { mlir::tensor::ExtractOp op, mlir::PatternRewriter& rewriter) const override { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto linear_index = - GetLinearIndex(op.getTensor(), op.getIndices(), rewriter); + auto linear_index = GetLinearIndex(op.getIndices(), b); Type element_type = op.getTensor().getType().getElementType(); Value is_low_nibble = nullptr; if (element_type == rewriter.getI4Type()) { @@ -247,7 +226,7 @@ struct RewriteTensorExtract : mlir::OpRewritePattern { GetI4IndexAndNibble(linear_index, b); } - auto gep = CreateGep(op.getTensor(), linear_index, rewriter, element_type); + auto gep = CreateGep(op.getTensor(), linear_index, b, element_type); auto load = rewriter .create(gep.getLoc(), gep.getElemType(), gep) @@ -296,7 +275,7 @@ struct RewriteTransferRead op.getSource()); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto linear_index = GetLinearIndex(source, op.getIndices(), rewriter); + auto linear_index = GetLinearIndex(op.getIndices(), b); mlir::VectorType vector_type = op.getVectorType(); if (vector_type.getElementType().isInteger(1)) { @@ -309,7 +288,7 @@ struct RewriteTransferRead b.create(1, linear_index.getType())); gep_element_type = b.getI8Type(); } - auto gep = CreateGep(source, linear_index, rewriter, gep_element_type); + auto gep = CreateGep(source, linear_index, b, gep_element_type); mlir::LLVMTypeConverter converter(b.getContext()); auto llvm_vector_type = converter.convertType(vector_type); @@ -345,7 +324,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto tensor_dest = mlir::cast>(dest); - auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter); + auto linear_index = GetLinearIndex(op.getIndices(), b); auto element_type = tensor_dest.getType().getElementType(); Value is_low_nibble = nullptr; @@ -355,7 +334,7 @@ struct RewriteTensorInsert : mlir::OpRewritePattern { GetI4IndexAndNibble(linear_index, b); } - auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type); + auto gep = CreateGep(tensor_dest, linear_index, b, element_type); auto scalar_value = op.getScalar(); if (is_low_nibble) { @@ -402,7 +381,7 @@ struct RewriteTransferWrite mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto tensor_dest = mlir::cast>(dest); - auto linear_index = GetLinearIndex(tensor_dest, op.getIndices(), rewriter); + auto linear_index = GetLinearIndex(op.getIndices(), b); auto element_type = tensor_dest.getType().getElementType(); mlir::Value vector_value = op.getVector(); @@ -420,7 +399,7 @@ struct RewriteTransferWrite // elements. vector_value = PermutePairsInVector(vector_value, b); } - auto gep = CreateGep(tensor_dest, linear_index, rewriter, element_type); + auto gep = CreateGep(tensor_dest, linear_index, b, element_type); mlir::LLVMTypeConverter converter(getContext()); auto llvm_type = converter.convertType(vector_value.getType()); @@ -724,7 +703,8 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { Location loc = op.getLoc(); llvm::StringRef sync_scope = is_amd_ ? "agent" : ""; - Value addr = CreateGep(op.getInput(), op.getIndices(), rewriter); + mlir::ImplicitLocOpBuilder b(loc, rewriter); + Value addr = CreateGep(op.getInput(), op.getIndices(), b); switch (atomic_bin_op) { case ml::AtomicBinOp::xchg: { @@ -932,7 +912,8 @@ class RewriteAtomicRMW : public mlir::OpRewritePattern { mlir::IntegerType::get(op.getContext(), small_type ? 32 : result_size); // Calculate load address for the input. - Value addr = CreateGep(input, op.getIndices(), rewriter); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value addr = CreateGep(input, op.getIndices(), b); Value shift, mask; if (small_type) { // Update input pointer by discarding the last two bits - i.e. align to diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc index 6e05eadbb3a796..28762d03c42a1a 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_to_llvm.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_to_llvm.cc @@ -44,7 +44,7 @@ namespace xla { namespace gpu { #define GEN_PASS_DEF_LOWERTOLLVMPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc similarity index 66% rename from third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc index 9028480b1d52b9..cbd64b870e83d1 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/lower_xla_gpu_to_scf.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc @@ -38,19 +38,28 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/model/indexing_map.h" #include "xla/util.h" namespace xla { namespace gpu { +namespace { #define GEN_PASS_DEF_LOWERXLAGPUTOSCFPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" - -namespace { +#define GEN_PASS_DEF_LOWERXLAGPULOOPSTOSCFPASS +#include "xla/service/gpu/fusions/transforms/passes.h.inc" +using mlir::ImplicitLocOpBuilder; +using mlir::Location; +using mlir::OpBuilder; +using mlir::SmallVector; using mlir::success; +using mlir::Value; +using mlir::ValueRange; +using mlir::scf::IfOp; struct RewritePredicatedInsert : mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -104,18 +113,18 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { return op->emitOpError("max_distance must be a power of 2 < WarpSize()"); } - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - mlir::ValueRange values = op.getOperands(); + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + ValueRange values = op.getOperands(); for (int distance = max_distance; distance > 0; distance /= 2) { namespace ml = mlir::LLVM; - auto shuffle_32 = [&](mlir::Value v) { + auto shuffle_32 = [&](Value v) { return b .create(v, distance, WarpSize(), mlir::gpu::ShuffleMode::DOWN) .getShuffleResult(); }; - auto shuffle_int_or_float = [&](mlir::Value value) { + auto shuffle_int_or_float = [&](Value value) { auto ty = value.getType(); int bit_width = ty.getIntOrFloatBitWidth(); if (bit_width == 32) { @@ -130,7 +139,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { // Don't generate vectors if the size is 1. auto vector_type = ml::getVectorType(b.getI32Type(), n_shuffles); value = b.create(vector_type, value); - mlir::Value result_vec = b.create(vector_type); + Value result_vec = b.create(vector_type); for (int i = 0; i < n_shuffles; ++i) { auto idx = b.create(i, 32); result_vec = b.create( @@ -146,7 +155,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { return value; }; - auto shuffle = [&](mlir::Value value) -> mlir::Value { + auto shuffle = [&](Value value) -> Value { if (mlir::isa(value.getType())) { return b.create( value.getType(), @@ -168,7 +177,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { return shuffle_int_or_float(value); }; - llvm::SmallVector args = values; + SmallVector args = values; for (auto value : values) { args.push_back(shuffle(value)); } @@ -181,13 +190,73 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { } }; +struct RewriteXlaGpuLoop : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + LoopOp op, mlir::PatternRewriter& rewriter) const override { + Location loc = op.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + + IndexingMap indexing_map = op.getIndexingMap(); + SmallVector lbs, ubs, steps; + mlir_converter::GetLoopBoundsFromIndexingMap(b, indexing_map, &lbs, &ubs, + &steps); + mlir::scf::LoopNest loop_nest = mlir::scf::buildLoopNest( + b, loc, lbs, ubs, steps, op.getInits(), + [&](OpBuilder& nested_builder, Location loc, ValueRange symbol_values, + ValueRange iter_args) -> mlir::scf::ValueVector { + mlir::ImplicitLocOpBuilder nested_b(loc, nested_builder); + auto is_in_bounds = mlir_converter::CheckConstraints( + indexing_map, op.getDims(), symbol_values, nested_b); + auto if_op = nested_b.create( + is_in_bounds, + [&](OpBuilder& then_builder, Location then_loc) -> void { + SmallVector bb_args(symbol_values); + bb_args.append(iter_args.begin(), iter_args.end()); + + mlir::Block* then_block = then_builder.getInsertionBlock(); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(then_block); + rewriter.mergeBlocks(op.getBody(), then_block, bb_args); + + auto old_terminator = then_block->getTerminator(); + then_builder.create( + then_loc, old_terminator->getOperands()); + old_terminator->erase(); + }, + [&](OpBuilder& else_b, Location else_loc) { + else_b.create(loc, iter_args); + }); + return if_op.getResults(); + }); + rewriter.replaceOp(op, loop_nest.results); + return mlir::success(); + } +}; + class LowerXlaGpuToScfPass : public impl::LowerXlaGpuToScfPassBase { public: void runOnOperation() override { - mlir::RewritePatternSet patterns(&getContext()); + auto* ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); patterns.add(&getContext()); + RewriteShuffleReduce>(ctx); + if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +class LowerXlaGpuLoopsToScfPass + : public impl::LowerXlaGpuLoopsToScfPassBase { + public: + void runOnOperation() override { + auto* ctx = &getContext(); + mlir::RewritePatternSet patterns(ctx); + patterns.add(ctx); if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); @@ -201,5 +270,9 @@ std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass() { return std::make_unique(); } +std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuLoopsToScfPass() { + return std::make_unique(); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc b/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc index c1899d27c68874..50193e3a2a29f4 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/merge_pointers_to_same_slice.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/merge_pointers_to_same_slice.cc @@ -30,7 +30,7 @@ namespace xla { namespace gpu { #define GEN_PASS_DEF_MERGEPOINTERSTOSAMESLICEPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc similarity index 92% rename from third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc index 6d5456f0150323..e483bfebedb979 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/optimize_loops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/optimize_loops.cc @@ -41,13 +41,14 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/model/indexing_map.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_OPTIMIZELOOPSPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" namespace { @@ -66,8 +67,17 @@ bool DoIndicesDependOnInductionVar(mlir::ValueRange indices, bool CanReplaceInductionVar(mlir::ValueRange indices) { return absl::c_all_of(indices, [&](mlir::Value v) { - if (mlir::isa(v)) { - return true; + if (auto bbarg = mlir::dyn_cast(v)) { + auto for_op = mlir::dyn_cast_or_null( + v.getParentRegion()->getParentOp()); + // This is a bbarg that is defined outside of the loop, so it doesn't + // affect pipelining. + if (!for_op) { + return true; + } + // We can only replace the induction variable, not other loop-carried + // values. + return v == for_op.getInductionVar(); } auto* op = v.getDefiningOp(); return op && @@ -191,9 +201,10 @@ struct PipelineLoad : mlir::OpRewritePattern { auto plus_one_map = mlir::AffineMap::get( 1, 0, mlir::getAffineDimExpr(0, this->getContext()) + 1); b.setInsertionPoint(next_value); + IndexingMap indexing_map(plus_one_map, {DimVar{0, ub.getSExtValue() - 1}}, + /*range_vars=*/{}, /*rt_vars=*/{}); auto induction_plus_one = - b.create(new_for.getInductionVar(), plus_one_map, 0, - ub.getSExtValue() - 1) + b.create(new_for.getInductionVar(), indexing_map) ->getResult(0); // Create the new apply_indexing ops outside the if, to improve CSE. diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/passes.h b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h similarity index 84% rename from third_party/xla/xla/service/gpu/fusions/mlir/passes.h rename to third_party/xla/xla/service/gpu/fusions/transforms/passes.h index bb0f1d44380018..e70af753e87a3f 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/passes.h +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ -#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ #include #include @@ -27,7 +27,7 @@ namespace xla { namespace gpu { #define GEN_PASS_DECL -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" // Returns the range of a given value, if it can be statically determined. std::optional GetRange(mlir::Value value); @@ -44,8 +44,10 @@ std::unique_ptr CreateLowerTensorsPass( bool is_amd_gpu = false, const std::string& gpu_arch = "6.0"); std::unique_ptr CreateLowerToLLVMPass(); std::unique_ptr CreateLowerXlaGpuToScfPass(); +std::unique_ptr CreateLowerXlaGpuLoopsToScfPass(); std::unique_ptr CreateMergePointersToSameSlicePass(); std::unique_ptr CreateOptimizeLoopsPass(); +std::unique_ptr CreatePeelLoopsPass(); std::unique_ptr CreatePropagateSliceIndicesPass(); std::unique_ptr CreateSimplifyAffinePass(); std::unique_ptr CreateSimplifyArithPass(); @@ -53,9 +55,9 @@ std::unique_ptr CreateUnswitchLoopsPass(); std::unique_ptr CreateVectorizeLoadsAndStoresPass(); #define GEN_PASS_REGISTRATION -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/passes.td b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td similarity index 87% rename from third_party/xla/xla/service/gpu/fusions/mlir/passes.td rename to third_party/xla/xla/service/gpu/fusions/transforms/passes.td index 6785670581d68e..af27b3692766ee 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/passes.td +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ -#define XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ +#ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_ +#define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_ include "mlir/Pass/PassBase.td" @@ -162,7 +162,7 @@ def ExpandFloatOpsPass : Pass<"xla-gpu-expand-float-ops", "mlir::ModuleOp"> { } def LowerXlaGpuToScfPass : - Pass<"xla-gpu-lower-xla-gpu-to-scf", "mlir::ModuleOp"> { + Pass<"xla-gpu-lower-xla-gpu-to-scf", "mlir::func::FuncOp"> { let summary = "Lowers xla_gpu to SCF."; let dependentDialects = [ @@ -173,6 +173,24 @@ def LowerXlaGpuToScfPass : let constructor = "CreateLowerXlaGpuToScfPass()"; } +def LowerXlaGpuLoopsToScfPass : Pass< + "xla-gpu-lower-xla-gpu-loops-to-scf", "mlir::func::FuncOp"> { + let summary = "Lowers xla_gpu.loop to SCF."; + + let description = [{ + This pass is separate from lower-xla-gpu-to-scf because + lower-xla-gpu-to-scf, inliner, peeling and lower-xla-gpu-loops-to-scf + have to run in that order. + }]; + + let dependentDialects = [ + "mlir::scf::SCFDialect", + "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", + ]; + + let constructor = "CreateLowerXlaGpuLoopsToScfPass()"; +} + def EraseDeadFunctionsPass : Pass<"xla-erase-dead-functions", "mlir::ModuleOp"> { let summary = "Deletes unused functions"; @@ -222,6 +240,17 @@ def VectorizeLoadsAndStoresPass : let constructor = "CreateVectorizeLoadsAndStoresPass()"; } +def PeelLoopsPass : Pass<"xla-gpu-peel-loops", "mlir::func::FuncOp"> { + let summary = "Peels xla_gpu.loop."; + let description = [{ + Attempts to split each loop dimension [0, NUM_ITERATIONS) + as [0, NUM_ITERATIONS - 1) and [NUM_ITERATIONS - 1, NUM_ITERATIONS) + if it removes a constraint. + }]; + let dependentDialects = ["xla::gpu::XlaGpuDialect"]; + let constructor = "CreatePeelLoopsPass()"; +} + def OptimizeLoopsPass : Pass<"xla-gpu-optimize-loops", "mlir::func::FuncOp"> { let summary = "Unrolls and pipelines loops."; @@ -287,4 +316,4 @@ def UnswitchLoopsPass : let constructor = "CreateUnswitchLoopsPass()"; } -#endif // XLA_SERVICE_GPU_FUSIONS_MLIR_PASSES_TD_ +#endif // XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_TD_ diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc new file mode 100644 index 00000000000000..7c0845fff2011c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/peel_loops.cc @@ -0,0 +1,149 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "absl/log/log.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/model/indexing_map.h" + +namespace xla { +namespace gpu { +namespace { + +#define GEN_PASS_DEF_PEELLOOPSPASS +#include "xla/service/gpu/fusions/transforms/passes.h.inc" + +using mlir::Location; +using mlir::OpBuilder; +using mlir::OpRewritePattern; +using mlir::PatternRewriter; +using mlir::SmallVector; +using mlir::Value; +using mlir::ValueRange; + +struct PeelLoop : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + LoopOp loop_op, PatternRewriter& rewriter) const override { + int64_t cumulative_loop_size = 1; + + // Compute the list of indexing maps. The last element is the "peeled" or + // "main" loop. Everything else is a "tail" loop. + auto indexing_map = loop_op.getIndexingMap(); + // TODO(b/358274367): Remove the simplify call once we have `is_simplified` + // field and a canonicalization pattern to simplify indexing map in + // xla_gpu.loop. + indexing_map.Simplify(); + SmallVector indexing_maps{indexing_map}; + for (int sym_index = indexing_map.GetSymbolCount() - 1; + sym_index >= 0 && cumulative_loop_size < 64; --sym_index) { + IndexingMap indexing_map = indexing_maps.back(); + auto& bound = indexing_map.GetSymbolBound(sym_index); + cumulative_loop_size *= bound.GetLoopTripCount(); + if (!indexing_map.IsSymbolConstrained(sym_index) || + bound.upper == bound.lower) { + continue; + } + // Create peeled indexing map. + IndexingMap peeled_map = indexing_map; + --peeled_map.GetMutableSymbolBound(sym_index).upper; + peeled_map.Simplify(); + + // If the symbol is still constrained, peeling does not help. + if (peeled_map.IsSymbolConstrained(sym_index)) continue; + + // Create remainder indexing map. + IndexingMap tail_map = indexing_map; + tail_map.GetMutableSymbolBound(sym_index).lower = bound.upper; + tail_map.Simplify(); + + VLOG(5) << "Peeled indexing map\n" + << indexing_map.ToString() << "into\n" + << peeled_map.ToString() << "and\n" + << tail_map.ToString() << "\n"; + indexing_maps.pop_back(); + indexing_maps.push_back(tail_map); + indexing_maps.push_back(peeled_map); + } + + if (indexing_maps.size() == 1) { + return rewriter.notifyMatchFailure(loop_op, + "No range variables to peel."); + } + + // Create chained loops from the list of indexing maps. + Location loc = loop_op.getLoc(); + SmallVector inits = loop_op.getInits(); + for (const auto& indexing_map : llvm::reverse(indexing_maps)) { + if (indexing_map.IsKnownEmpty()) continue; + auto tail_loop = rewriter.create( + loc, indexing_map, loop_op.getDims(), inits, + [&](OpBuilder& nested_b, Location nested_loc, ValueRange ivs, + ValueRange iter_args) { + OpBuilder::InsertionGuard guard(nested_b); + mlir::IRMapping mapping; + mapping.map(loop_op.getInductionVars(), ivs); + mapping.map(loop_op.getRegionIterArgs(), iter_args); + for (auto& op : loop_op.getBody()->getOperations()) { + nested_b.clone(op, mapping); + } + }); + inits = tail_loop.getResults(); + } + rewriter.replaceOp(loop_op, inits); + return mlir::success(); + } +}; + +struct PeelLoopsPass : public impl::PeelLoopsPassBase { + void runOnOperation() override { + auto func = getOperation(); + mlir::MLIRContext* mlir_context = &getContext(); + mlir::RewritePatternSet patterns(mlir_context); + patterns.add(mlir_context); + if (mlir::failed( + mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +std::unique_ptr CreatePeelLoopsPass() { + return std::make_unique(); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc b/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc similarity index 95% rename from third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc index 218b432f795190..31a637900c8a7a 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/propagate_slice_indices.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/propagate_slice_indices.cc @@ -19,13 +19,13 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" -#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/fusions/transforms/passes.h" namespace xla { namespace gpu { #define GEN_PASS_DEF_PROPAGATESLICEINDICESPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc index 7b234998860d37..acbd9d3735ea46 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_affine.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_affine.cc @@ -41,8 +41,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -69,7 +69,7 @@ using mlir::affine::AffineApplyOp; namespace arith = mlir::arith; #define GEN_PASS_DEF_SIMPLIFYAFFINEPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" int Distance(ImplicitLocOpBuilder& builder, Value a) { auto* block = builder.getInsertionBlock(); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc index 77b1d7c40ea290..f3d67e24ee3248 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/simplify_arith.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/simplify_arith.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -29,8 +30,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/model/indexing_map.h" namespace xla { @@ -38,7 +39,7 @@ namespace gpu { namespace { #define GEN_PASS_DEF_SIMPLIFYARITHPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" using mlir::LogicalResult; using mlir::OpRewritePattern; diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD new file mode 100644 index 00000000000000..381d5a3220b1df --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/BUILD @@ -0,0 +1,16 @@ +load("//xla:lit.bzl", "lit_test_suite") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + licenses = ["notice"], +) + +lit_test_suite( + name = "tests", + srcs = glob(["*.mlir"]), + cfg = "//xla:lit.cfg.py", + tools = [ + "//xla/service/gpu/fusions/tools:mlir_fusions_opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/convert_xla_gpu_pure_calls.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/convert_xla_gpu_pure_calls.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/convert_xla_gpu_pure_calls.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/expand_float_ops.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/expand_float_ops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir similarity index 57% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir index ee2c2ae9e9553d..18e3e30bc309c2 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/flatten_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/flatten_tensors.mlir @@ -8,13 +8,12 @@ func.func @tensor_extract( : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> func.return %v : f32 } -// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1 * 2 + d0)> +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d1 * 2 + d0), domain: d0 in [0, 1], d1 in [0, 2]> // CHECK-LABEL: func.func @tensor_extract( // CHECK-SAME: %[[SRC:.*]]: tensor<6xf32>, // CHECK-SAME: %[[I:.*]]: index, %[[J:.*]]: index) -> f32 { -// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] -// CHECK-SAME: in [0, 1], %[[J]] in [0, 2]) +// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]], %[[J]]) // CHECK: tensor.extract %[[SRC]][%[[INDEX]]] : tensor<6xf32> // ----- @@ -37,6 +36,26 @@ func.func @tensor_insert( // ----- +func.func @update(%arg0: tensor<10x24xf32>) -> tensor<10x24xf32> { + %c1 = arith.constant 1 : index + %c42_f32 = arith.constant 42.0 : f32 + %out = tensor.insert %c42_f32 into %arg0[%c1, %c1] : tensor<10x24xf32> + func.return %out : tensor<10x24xf32> +} + +func.func @pure_call(%arg0: tensor<10x24xf32>) -> tensor<10x24xf32> { + %updated_tensor = xla_gpu.pure_call @update(%arg0) + : (tensor<10x24xf32>) -> (tensor<10x24xf32>) + func.return %updated_tensor : tensor<10x24xf32> +} +// CHECK-LABEL: func.func @pure_call( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<240xf32>) -> tensor<240xf32> { +// CHECK-NEXT: xla_gpu.pure_call @update(%[[TENSOR]]) +// CHECK-SAME: : (tensor<240xf32>) -> tensor<240xf32> +// CHECK-NEXT: return + +// ----- + func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) -> (tensor<2x4xf32>) { %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { @@ -47,13 +66,12 @@ func.func @atomic_rmw(%in: tensor<2x4xf32>, %i: index, %j: index) } return %ret : tensor<2x4xf32> } -// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0 * 4 + d1)> - +// CHECK: #[[$MAP:.+]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 4 + d1), domain: d0 in [0, 1], d1 in [0, 3]> // CHECK-LABEL: func.func @atomic_rmw( // CHECK-SAME: %[[TENSOR:.*]]: tensor<8xf32>, %[[I:.*]]: index, // CHECK-SAME: %[[J:.*]]: index) -> tensor<8xf32> { // CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[I]] in [0, 1], %[[J]] in [0, 3]) +// CHECK-SAME: (%[[I]], %[[J]]) // CHECK: xla_gpu.atomic_rmw %[[TENSOR]][%[[INDEX]]] : tensor<8xf32> // ----- @@ -73,9 +91,8 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) } {some_attr} return %for#0, %for#1, %c0_f32 : tensor<32x1024xf32>, tensor<64x8x4xf32>, f32 } - -// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 + 1024)> -// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 32 + 5)> +// CHECK: #[[$MAP0:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1024) +// CHECK: #[[$MAP1:.+]] = #xla_gpu.indexing_map<(d0) -> (d0 * 32 + 5) // CHECK-LABEL: func.func @for_loop( // CHECK-SAME: %[[T0:.*]]: tensor<32768xf32>, // CHECK-SAME: %[[T1:.*]]: tensor<2048xf32>) -> (tensor<32768xf32>, tensor<2048xf32>, f32) { @@ -87,17 +104,17 @@ func.func @for_loop(%t0: tensor<32x1024xf32>, %t1: tensor<64x8x4xf32>) // CHECK: %[[FOR:.*]]:2 = scf.for %[[I:.*]] = %[[C0]] to %[[C64]] // CHECK-SAME: step %[[C32]] // CHECK-SAME: iter_args(%[[T0_:.*]] = %[[T0]], %[[T1_:.*]] = %[[T1]]) -// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]] in [0, 1023]) +// CHECK: %[[IND0:.*]] = xla_gpu.apply_indexing #[[$MAP0]](%[[I]]) // CHECK: %[[UPD0:.*]] = tensor.insert %[[F32]] into %[[T0_]][%[[IND0]]] -// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]] in [0, 63]) +// CHECK: %[[IND1:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]]) // CHECK: %[[UPD1:.*]] = tensor.insert %[[F32]] into %[[T1_]][%[[IND1]]] // CHECK: scf.yield %[[UPD0]], %[[UPD1]] : tensor<32768xf32>, tensor<2048xf32> // ----- -#map = affine_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36)> -#map1 = affine_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4)> -#map2 = affine_map<(d0, d1) -> ((d1 * 128 + d0) mod 9)> +#map = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) floordiv 36), domain: d0 in [0, 127], d1 in [0, 393749]> +#map1 = #xla_gpu.indexing_map<(d0, d1) -> (((d1 * 128 + d0) floordiv 9) mod 4), domain: d0 in [0, 127], d1 in [0, 393749]> +#map2 = #xla_gpu.indexing_map<(d0, d1) -> ((d1 * 128 + d0) mod 9), domain: d0 in [0, 127], d1 in [0, 393749]> func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %arg2: tensor<1400x1x4x9xf32>, %arg3: tensor<4000x4x9xf32>) -> tensor<4000x4x9xf32> { @@ -105,13 +122,13 @@ func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, %c3999 = arith.constant 3999 : index %th_x = gpu.thread_id x {xla.range = [0 : index, 127 : index]} %bl_x = gpu.block_id x {xla.range = [0 : index, 393749 : index]} - %0 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 393749]) + %0 = xla_gpu.apply_indexing #map(%th_x, %bl_x) %extracted = tensor.extract %arg1[%0, %c0] : tensor<1400x1xi32> %1 = arith.index_cast %extracted : i32 to index %2 = arith.cmpi ule, %1, %c3999 : index %3 = scf.if %2 -> (tensor<4000x4x9xf32>) { - %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 393749]) - %5 = xla_gpu.apply_indexing #map2(%th_x in [0, 127], %bl_x in [0, 393749]) + %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x) + %5 = xla_gpu.apply_indexing #map2(%th_x, %bl_x) %elem = tensor.extract %arg2[%0, %c0, %4, %5] : tensor<1400x1x4x9xf32> %atomic_rmw = xla_gpu.atomic_rmw %arg3[%1, %4, %5] : tensor<4000x4x9xf32> { ^bb0(%arg4: f32): @@ -132,9 +149,67 @@ func.func @if_op(%arg0: tensor<4000x4x9xf32>, %arg1: tensor<1400x1xi32>, // ----- +func.func @allocate_shared() -> tensor<10x15xf32> { + %shmem = xla_gpu.allocate_shared : tensor<10x15xf32> + func.return %shmem : tensor<10x15xf32> +} +// CHECK-LABEL: func.func @allocate_shared() -> tensor<150xf32> +// CHECK: xla_gpu.allocate_shared : tensor<150xf32> +// CHECK-NOT: builtin.unrealized_conversion_cast + +// ----- + +func.func @sync() -> (tensor<8x4xf32>, tensor<8x4xf32>) { + %shared1 = xla_gpu.allocate_shared : tensor<8x4xf32> + %shared2 = xla_gpu.allocate_shared : tensor<8x4xf32> + %sync:2 = xla_gpu.sync_threads %shared1, %shared2 + : tensor<8x4xf32>, tensor<8x4xf32> + return %sync#0, %sync#1 : tensor<8x4xf32>, tensor<8x4xf32> +} +// CHECK-LABEL: func.func @sync() -> (tensor<32xf32>, tensor<32xf32>) { +// CHECK: %[[SHARED1:.*]] = xla_gpu.allocate_shared : tensor<32xf32> +// CHECK: %[[SHARED2:.*]] = xla_gpu.allocate_shared : tensor<32xf32> +// CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHARED1]], %[[SHARED2]] +// CHECK-SAME: : tensor<32xf32>, tensor<32xf32> +// CHECK-NEXT: return + +// ----- + +func.func @index_switch(%arg0: tensor<2x3xf32>, %arg1: tensor<2x3xf32>, + %arg2: tensor<2x3xf32>, %arg3: tensor<2x3xf32> + ) -> (tensor<2x3xf32>, tensor<2x3xf32>) { + %block_id_y = gpu.block_id y {xla.range = [0 : index, 1 : index]} + %0:2 = scf.index_switch %block_id_y -> tensor<2x3xf32>, tensor<2x3xf32> + case 1 { + scf.yield %arg0, %arg3 : tensor<2x3xf32>, tensor<2x3xf32> + } + default { + scf.yield %arg1, %arg2 : tensor<2x3xf32>, tensor<2x3xf32> + } + return %0#0, %0#1: tensor<2x3xf32>, tensor<2x3xf32> +} +// CHECK-LABEL: func.func @index_switch +// CHECK-SAME: -> (tensor<6xf32>, tensor<6xf32>) +// CHECK-NOT: builtin.unrealized_conversion_cast + +// ----- + +func.func @constant() -> tensor<2x3xf32> { + %cst = arith.constant dense<[ + [-3.000000e+00, 2.000000e+00, 1.000000e+00], + [0.000000e+00, -3.000000e+00, 1.000000e+00] + ]> : tensor<2x3xf32> + return %cst : tensor<2x3xf32> +} +// CHECK-LABEL: func.func @constant +// CHECK-SAME: -> tensor<6xf32> +// CHECK-NOT: builtin.unrealized_conversion_cast + +// ----- + func.func @dangling_cast(%arg0: tensor<6xf32>, %arg1: index) -> i32 { %v = tensor.extract %arg0[%arg1] : tensor<6xf32> %cast = builtin.unrealized_conversion_cast %v : f32 to i32 func.return %cast : i32 } -// CHECK: FlattenTensorsPass failed to converge +// CHECK: FlattenTensorsPass failed to converge \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir new file mode 100644 index 00000000000000..f15b37b040b848 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/inlining.mlir @@ -0,0 +1,295 @@ +// RUN: mlir_fusions_opt %s -split-input-file -xla-erase-dead-functions -inline | FileCheck %s + +module { + func.func private @mul(%a: f32, %b: f32) -> f32 { + %ret = arith.mulf %a, %b : f32 + return %ret : f32 + } + + func.func private @add(%a: f32, %b: f32) -> f32 { + %add = arith.addf %a, %b : f32 + %ret = xla_gpu.pure_call @mul(%add, %add) : (f32, f32) -> (f32) + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %ret = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + return %ret : f32 + } +} + +// CHECK-LABEL: module { +// CHECK: @caller +// CHECK-NOT: xla_gpu.pure_call @add +// CHECK: arith.addf +// CHECK-NOT: xla_gpu.pure_call @mul +// CHECK: arith.mulf + +// ----- + +module { + func.func @fused_computation(%arg0: tensor<2xf32> {xla.slice_index = 0 : index}, %arg1: tensor<2xf32> {xla.slice_index = 1 : index}, %arg2: tensor<2xf32> {xla.slice_index = 2 : index}) -> tensor<2xf32> attributes {xla.entry} { + %0 = gpu.thread_id x {xla.range = [0 : index, 1 : index]} + %1 = xla_gpu.pure_call @fused_computation_atan2(%arg0, %arg1, %0) : (tensor<2xf32>, tensor<2xf32>, index) -> f32 + %inserted = tensor.insert %1 into %arg2[%0] : tensor<2xf32> + return %inserted : tensor<2xf32> + } + func.func private @fused_computation_atan2(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>, %arg2: index {xla.range = [0 : index, 1 : index]}) -> f32 attributes {llvm.linkage = #llvm.linkage} { + %extracted = tensor.extract %arg0[%arg2] : tensor<2xf32> + %extracted_0 = tensor.extract %arg1[%arg2] : tensor<2xf32> + %0 = arith.addf %extracted, %extracted_0 : f32 + %1 = arith.subf %extracted, %extracted_0 : f32 + %2 = arith.mulf %0, %1 : f32 + %3 = arith.divf %0, %1 : f32 + %4 = math.atan2 %2, %3 : f32 + return %4 : f32 + } +} + +// CHECK-LABEL: module { +// CHECK: @fused_computation +// CHECK-NOT: xla_gpu.pure_call @add +// CHECK: gpu.thread_id +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: tensor.extract +// CHECK-NEXT: arith.addf +// CHECK-NEXT: arith.subf +// CHECK-NEXT: arith.mulf +// CHECK-NEXT: arith.divf +// CHECK-NEXT: math.atan2 +// CHECK-NEXT: tensor.insert + +// ----- + +module { + // Do not inline this function as it has two callers. Even if the callers are + // in different functions at the start, after inlining the two callers are in + // the same function. + func.func private @large(%a: f32, %b: f32) -> f32 { + %mul = arith.mulf %a, %b : f32 + %add = arith.addf %a, %mul : f32 + %div = arith.divf %add, %b : f32 + %sub = arith.subf %div, %a : f32 + %atan2 = math.atan2 %b, %sub : f32 + %neg = arith.negf %atan2 : f32 + %zero = arith.constant 0.0 : f32 + %comp = arith.cmpf olt, %neg, %zero : f32 + %ret = arith.select %comp, %zero, %neg : f32 + return %ret : f32 + } + + func.func private @add(%a: f32, %b: f32) -> f32 { + %add = arith.addf %a, %b : f32 + %ret = xla_gpu.pure_call @large(%add, %add) : (f32, f32) -> (f32) + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %add = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla_gpu.pure_call @large(%add, %add) : (f32, f32) -> (f32) + return %ret : f32 + } +} + +// CHECK-LABEL: module { +// CHECK: @caller +// CHECK: arith.addf +// CHECK: xla_gpu.pure_call @large +// CHECK: xla_gpu.pure_call @large + +// ----- + +module { + func.func private @add(%a: f32, %b: f32) -> f32 { + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %add = xla_gpu.pure_call @add(%a, %b) : (f32, f32) -> (f32) + %ret = xla_gpu.pure_call @add(%add, %add) : (f32, f32) -> (f32) + return %ret : f32 + } +} + +// CHECK-LABEL: module { +// CHECK: @caller +// CHECK-NOT: xla_gpu.pure_call +// CHECK: arith.addf +// CHECK: arith.addf + +// ----- + +module { + func.func private @fib0(%start : f32) -> f32 { + %zero = arith.constant 0.0 : f32 + return %zero : f32 + } + func.func private @fib1(%start : f32) -> f32 { + return %start : f32 + } + func.func private @fib2(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib0(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib1(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + func.func private @fib3(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib1(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib2(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + func.func private @fib4(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib2(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib3(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + // When inlining the other functions into @fib5, this function exceeds the + // threshold for inlining. + func.func private @fib5(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib3(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib4(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + // As we do not inline @fib5 into @fib6, this function stays below the + // threshold for inlining. + func.func private @fib6(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib4(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib5(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + func.func private @fib7(%start : f32) -> f32 { + %a = xla_gpu.pure_call @fib5(%start) : (f32) -> (f32) + %b = xla_gpu.pure_call @fib6(%start) : (f32) -> (f32) + %ret = arith.addf %a, %b : f32 + return %ret : f32 + } + + func.func @caller(%a: f32) -> f32 { + %ret = xla_gpu.pure_call @fib7(%a) : (f32) -> (f32) + return %ret : f32 + } +} + +// CHECK-LABEL: module { +// CHECK: @caller +// CHECK: arith.constant 0.000000e+00 +// CHECK: xla_gpu.pure_call @fib5 +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK: arith.addf +// CHECK: xla_gpu.pure_call @fib5 +// CHECK: arith.addf +// CHECK: arith.addf + +// ----- + +module { + func.func private @complex(%a: f32, %b: f32) -> complex { + %ret = complex.create %a, %b : complex + return %ret : complex + } + + func.func @caller(%a: f32, %b: f32) -> complex { + %ret = xla_gpu.pure_call @complex(%a, %b) : (f32, f32) -> (complex) + return %ret : complex + } +} + +// CHECK-LABEL: module { +// CHECK: @caller +// CHECK-NEXT: complex.create + +// ----- + +module { + func.func private @callee2(%a: f32) -> f32 { + %ret = arith.addf %a, %a : f32 + return %ret : f32 + } + + func.func private @callee1(%a: f32) -> f32 { + %c1 = xla_gpu.pure_call @callee2(%a) : (f32) -> (f32) + %b0 = arith.addf %a, %a : f32 + %b1 = arith.addf %b0, %a : f32 + %b2 = arith.addf %b1, %a : f32 + %b3 = arith.addf %b2, %a : f32 + %b4 = arith.addf %b3, %a : f32 + %b5 = arith.addf %b4, %a : f32 + %b6 = arith.addf %b5, %a : f32 + %b7 = arith.addf %b6, %a : f32 + %c2 = xla_gpu.pure_call @callee2(%b7) : (f32) -> (f32) + %ret = arith.addf %c1, %c2 : f32 + return %ret : f32 + } + + func.func private @dead(%a: f32) -> f32 { + %ret = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %ret = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + return %ret : f32 + } +} + +// CHECK-LABEL: module { +// CHECK-NOT: func.func +// CHECK: func.func @caller +// CHECK-NOT: xla_gpu.pure_call +// CHECK-NOT: func.func + +// ----- + +module { + func.func private @callee1(%a: f32) -> f32 { + %b0 = arith.addf %a, %a : f32 + %b1 = arith.addf %b0, %a : f32 + %b2 = arith.addf %b1, %a : f32 + %b3 = arith.addf %b2, %a : f32 + %b4 = arith.addf %b3, %a : f32 + %b5 = arith.addf %b4, %a : f32 + %b6 = arith.addf %b5, %a : f32 + %b7 = arith.addf %b6, %a : f32 + %b8 = arith.addf %b7, %a : f32 + %b9 = arith.addf %b8, %a : f32 + %b10 = arith.addf %b9, %a : f32 + %b11 = arith.addf %b10, %a : f32 + return %b11 : f32 + } + + func.func private @callee2(%a: f32) -> f32 { + %call = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %b0 = arith.addf %a, %a : f32 + %b1 = arith.addf %b0, %a : f32 + %b2 = arith.addf %b1, %a : f32 + %b3 = arith.addf %b2, %a : f32 + %b4 = arith.addf %b3, %a : f32 + %b5 = arith.addf %b4, %a : f32 + %b6 = arith.addf %b5, %a : f32 + %b7 = arith.addf %b6, %a : f32 + %b8 = arith.addf %b7, %a : f32 + %b9 = arith.addf %b8, %a : f32 + %ret = arith.addf %call, %b9 : f32 + return %ret : f32 + } + + func.func @caller(%a: f32, %b: f32) -> f32 { + %call1 = xla_gpu.pure_call @callee2(%a) : (f32) -> (f32) + %call2 = xla_gpu.pure_call @callee1(%a) : (f32) -> (f32) + %ret = arith.addf %call1, %call2 : f32 + return %ret : f32 + } +} + +// CHECK-LABEL: module { +// CHECK: func.func private @callee1 +// CHECK-NOT: callee2 +// CHECK: func.func @caller +// CHECK-COUNT-2: pure_call @callee1 diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir similarity index 54% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir index 2125e6f4d70c8f..822c3a85c9a2a0 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_tensors.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_tensors.mlir @@ -80,55 +80,28 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry> // ----- -module { - func.func @layout( - %arg0: tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>>, - %arg1: index, %arg2: index) -> f32 { - %v = tensor.extract %arg0[%arg1, %arg2] - : tensor<2x3xf32, dense<[0, 1]> : tensor<2xi64>> - func.return %v : f32 - } -} - -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d1 * 2 + d0)> -// CHECK-LABEL: @layout( -// CHECK-SAME: %[[ARG0:.*]]: !llvm.ptr, -// CHECK-SAME: %[[X:.*]]: index, %[[Y:.*]]: index -// CHECK: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP]] -// CHECK-SAME: (%[[X]] in [0, 1], %[[Y]] in [0, 2]) -// CHECK: %[[IDX_CAST:.*]] = arith.index_castui %[[IDX]] : index to i64 -// CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %[[ARG0]][%[[IDX_CAST]]] -// CHECK: llvm.load %[[PTR]] +func.func @store_control_flow( %arg0: tensor<2xf32>, %arg1: index) + -> tensor<2xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %cst2 = arith.constant 1.0 : f32 -// ----- + %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> { + %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32> + scf.yield %new_out : tensor<2xf32> + } -module { - func.func @store_control_flow( - %arg0: tensor<2xf32>, - %arg1: index - ) -> tensor<2xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %cst = arith.constant 0.0 : f32 - %cst2 = arith.constant 1.0 : f32 - - %for = scf.for %i = %c0 to %c2 step %c1 iter_args(%arg2 = %arg0) -> tensor<2xf32> { - %new_out = tensor.insert %cst into %arg2[%i] : tensor<2xf32> - scf.yield %new_out : tensor<2xf32> - } - - %inbounds = arith.cmpi sle, %arg1, %c1 : index - %result = scf.if %inbounds -> tensor<2xf32> { - %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32> - scf.yield %if : tensor<2xf32> - } else { - scf.yield %for : tensor<2xf32> - } - func.return %result : tensor<2xf32> + %inbounds = arith.cmpi sle, %arg1, %c1 : index + %result = scf.if %inbounds -> tensor<2xf32> { + %if = tensor.insert %cst2 into %for[%arg1] : tensor<2xf32> + scf.yield %if : tensor<2xf32> + } else { + scf.yield %for : tensor<2xf32> } + func.return %result : tensor<2xf32> } - // CHECK: @store_control_flow(%[[ARG0:.*]]: !llvm.ptr, %[[X:.*]]: index) { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index @@ -145,33 +118,25 @@ module { // ----- -module { - func.func @large_tensor( - %arg0: tensor<1024x1024x1024x6xf32>, - %arg1: index) -> f32 { - %v = tensor.extract %arg0[%arg1, %arg1, %arg1, %arg1] : tensor<1024x1024x1024x6xf32> - func.return %v : f32 - } +func.func @large_tensor(%arg0: tensor<8000000000xf32>, %arg1: index) -> f32 { + %v = tensor.extract %arg0[%arg1] : tensor<8000000000xf32> + func.return %v : f32 } - -// CHECK: @large_tensor +// CHECK-LABEL: @large_tensor // CHECK: arith.index_castui {{.*}} : index to i64 // ----- -module { - func.func @extract_from_constant(%arg0: tensor<2x1xf32>, - %arg1: index, %arg2: index) -> f32 { - %cst = arith.constant dense<[[1.000000e+00], [2.000000e+00]]> : tensor<2x1xf32> - %extracted = tensor.extract %arg0[%arg1, %arg2] : tensor<2x1xf32> - %extracted_0 = tensor.extract %cst[%arg1, %arg2] : tensor<2x1xf32> - %0 = arith.addf %extracted, %extracted_0 : f32 - return %0 : f32 - } +func.func @extract_from_constant(%arg0: tensor<2xf32>, %arg1: index) -> f32 { + %cst = arith.constant dense<[1.000000e+00, 2.000000e+00]> : tensor<2xf32> + %extracted = tensor.extract %arg0[%arg1] : tensor<2xf32> + %extracted_0 = tensor.extract %cst[%arg1] : tensor<2xf32> + %0 = arith.addf %extracted, %extracted_0 : f32 + return %0 : f32 } // CHECK: llvm.mlir.global private constant @global_cst_0(dense< // CHECK-SAME: [1.000000e+00, 2.000000e+00]> : tensor<2xf32>) {addr_space = 0 : i32} : !llvm.array<2 x f32> -// CHECK: @extract_from_constant +// CHECK-LABEL: @extract_from_constant // CHECK: %[[ADDR_OF:.*]] = llvm.mlir.addressof @global_cst_0 : !llvm.ptr // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ADDR_OF]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f32 @@ -180,31 +145,25 @@ module { // ----- -module { - func.func @vector_constant() -> vector<2xindex> { - %c1 = arith.constant dense<[1, 2]> : vector<2xindex> - func.return %c1 : vector<2xindex> - } +func.func @vector_constant() -> vector<2xindex> { + %c1 = arith.constant dense<[1, 2]> : vector<2xindex> + func.return %c1 : vector<2xindex> } - // vector constants should not be rewritten. // CHECK: @vector_constant // CHECK-NEXT: arith.constant // ----- -module { - func.func @complex_tensor_insert( - %arg0: tensor<10xcomplex>) -> tensor<10xcomplex> { - %c1 = arith.constant 1 : index - %real = arith.constant 3.0 : f32 - %imag = arith.constant 2.0 : f32 - %complex = complex.create %real, %imag : complex - %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex> - func.return %out : tensor<10xcomplex> - } +func.func @complex_tensor_insert( + %arg0: tensor<10xcomplex>) -> tensor<10xcomplex> { + %c1 = arith.constant 1 : index + %real = arith.constant 3.0 : f32 + %imag = arith.constant 2.0 : f32 + %complex = complex.create %real, %imag : complex + %out = tensor.insert %complex into %arg0[%c1] : tensor<10xcomplex> + func.return %out : tensor<10xcomplex> } - // CHECK: @complex_tensor_insert(%[[ARG0:.*]]: !llvm.ptr // CHECK: %[[C:.*]] = complex.create // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)> @@ -213,15 +172,12 @@ module { // ----- -module { - func.func @complex_tensor_extract( - %arg0: tensor<10xcomplex>) -> complex { - %c1 = arith.constant 1 : index - %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex> - func.return %v2 : complex - } +func.func @complex_tensor_extract( + %arg0: tensor<10xcomplex>) -> complex { + %c1 = arith.constant 1 : index + %v2 = tensor.extract %arg0[%c1] : tensor<10xcomplex> + func.return %v2 : complex } - // CHECK: @complex_tensor_extract(%[[ARG0:.*]]: !llvm.ptr // CHECK: %[[GEP:.*]] = llvm.getelementptr inbounds %[[ARG0]][1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, f32)> // CHECK: %[[LOAD:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> !llvm.struct<(f32, f32)> @@ -229,46 +185,33 @@ module { // ----- -module { - // This example is a bit silly, in real life there wouldn't be a loop (the - // loop body would be executed by different threads). We're just doing it this - // way so control flow with shared memory is tested as well. - func.func @transpose_shared(%in: tensor<32x32xf32>, - %out: tensor<32x32xf32>) -> tensor<32x32xf32> { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - - %shared = xla_gpu.allocate_shared : tensor<32x32xf32> - %loaded_tile = scf.for %i = %c0 to %c32 step %c1 - iter_args(%tile = %shared) -> tensor<32x32xf32> { - %inner_loaded_tile = scf.for %j = %c0 to %c32 step %c1 - iter_args(%inner_tile = %tile) -> tensor<32x32xf32> { - %v = tensor.extract %in[%i, %j] : tensor<32x32xf32> - %inserted = tensor.insert %v into %inner_tile[%i, %j] - : tensor<32x32xf32> - scf.yield %inserted : tensor<32x32xf32> - } - scf.yield %inner_loaded_tile : tensor<32x32xf32> - } - - %synced = xla_gpu.sync_threads %shared : tensor<32x32xf32> - %written_tile = scf.for %i = %c0 to %c32 step %c1 - iter_args(%written = %out) -> tensor<32x32xf32> { - %inner_written_tile = scf.for %j = %c0 to %c32 step %c1 - iter_args(%inner_written = %written) -> tensor<32x32xf32> { - %v = tensor.extract %shared[%j, %i] : tensor<32x32xf32> - %inserted = tensor.insert %v into %inner_written[%i, %j] - : tensor<32x32xf32> - scf.yield %inserted : tensor<32x32xf32> - } - scf.yield %inner_written_tile : tensor<32x32xf32> - } - - return %written_tile : tensor<32x32xf32> +// This example is a bit silly, in real life there wouldn't be a loop (the +// loop body would be executed by different threads). We're just doing it this +// way so control flow with shared memory is tested as well. +func.func @transpose_shared(%in: tensor<1024xf32>, + %out: tensor<1024xf32>) -> tensor<1024xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c1024 = arith.constant 1024 : index + + %shared = xla_gpu.allocate_shared : tensor<1024xf32> + %loaded_tile = scf.for %i = %c0 to %c1024 step %c1 + iter_args(%tile = %shared) -> tensor<1024xf32> { + %v = tensor.extract %in[%i] : tensor<1024xf32> + %inserted = tensor.insert %v into %tile[%i] : tensor<1024xf32> + scf.yield %inserted : tensor<1024xf32> } -} + %synced = xla_gpu.sync_threads %shared : tensor<1024xf32> + %written_tile = scf.for %i = %c0 to %c1024 step %c1 + iter_args(%written = %out) -> tensor<1024xf32> { + %v = tensor.extract %shared[%i] : tensor<1024xf32> + %inserted = tensor.insert %v into %written[%i] : tensor<1024xf32> + scf.yield %inserted : tensor<1024xf32> + } + + return %written_tile : tensor<1024xf32> +} // CHECK: llvm.mlir.global private @[[SHARED:shared_.*]]() // CHECK-SAME: {addr_space = 3 : i32} : !llvm.array<1024 x f32> // CHECK: @transpose_shared @@ -276,30 +219,24 @@ module { // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[ADDR]] // CHECK-SAME: : !llvm.ptr<3> to !llvm.ptr // CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] -// CHECK: llvm.store {{.*}} %[[ELEM_ADDR]] +// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] +// CHECK: llvm.store {{.*}} %[[ELEM_ADDR]] // CHECK: gpu.barrier // CHECK: scf.for -// CHECK: scf.for -// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] -// CHECK: llvm.load %[[ELEM_ADDR]] +// CHECK: %[[ELEM_ADDR:.*]] = llvm.getelementptr inbounds %[[CAST]] +// CHECK: llvm.load %[[ELEM_ADDR]] // ----- -module { - func.func @atomic_rmw_f32(%in: tensor<2x4xf32>, %i: index, %j: index) - -> (tensor<2x4xf32>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { - ^bb0(%current : f32): - %c42 = arith.constant 1.0 : f32 - %add = arith.minimumf %current, %c42 : f32 - xla_gpu.yield %add : f32 - } - return %ret : tensor<2x4xf32> +func.func @atomic_rmw_f32(%in: tensor<8xf32>, %i: index) -> (tensor<8xf32>) { + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + ^bb0(%current : f32): + %c42 = arith.constant 1.0 : f32 + %add = arith.minimumf %current, %c42 : f32 + xla_gpu.yield %add : f32 } + return %ret : tensor<8xf32> } - // CHECK: @atomic_rmw_f32 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK-NEXT: %[[INIT:.*]] = llvm.load %[[ADDR]] @@ -309,19 +246,16 @@ module { // ----- -module { - func.func @atomic_rmw_f16(%in: tensor<2x4xf16>, %i: index, %j: index) - -> (tensor<2x4xf16>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> { - ^bb0(%current : f16): - %c1 = arith.constant 1.0 : f16 - %add = arith.addf %current, %c1 : f16 - xla_gpu.yield %add : f16 - } - return %ret : tensor<2x4xf16> +func.func @atomic_rmw_f16(%in: tensor<8xf16>, %i: index) + -> (tensor<8xf16>) { + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + ^bb0(%current : f16): + %c1 = arith.constant 1.0 : f16 + %add = arith.addf %current, %c1 : f16 + xla_gpu.yield %add : f16 } + return %ret : tensor<8xf16> } - // CHECK: @atomic_rmw_f16 // CHECK: %[[ADDR:.*]] = llvm.getelementptr // CHECK-NEXT: %[[ADDR_INT:.*]] = llvm.ptrtoint %[[ADDR]] @@ -342,16 +276,14 @@ module { // ----- -module { - func.func @atomic_rmw_overwrite(%in: tensor<2x4xf16>, %i: index, %j: index) - -> (tensor<2x4xf16>) { - %c1 = arith.constant 1.0 : f16 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> { - ^bb0(%current : f16): - xla_gpu.yield %c1 : f16 - } - return %ret : tensor<2x4xf16> +func.func @atomic_rmw_overwrite(%in: tensor<8xf16>, %i: index) + -> (tensor<8xf16>) { + %c1 = arith.constant 1.0 : f16 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + ^bb0(%current : f16): + xla_gpu.yield %c1 : f16 } + return %ret : tensor<8xf16> } // CHECK: @atomic_rmw_overwrite // CHECK: %[[ADDR:.*]] = llvm.getelementptr @@ -370,26 +302,21 @@ module { // ----- -module { - func.func @shared_complex() -> tensor<10xcomplex> { - %shared = xla_gpu.allocate_shared : tensor<10xcomplex> - return %shared : tensor<10xcomplex> - } +func.func @shared_complex() -> tensor<10xcomplex> { + %shared = xla_gpu.allocate_shared : tensor<10xcomplex> + return %shared : tensor<10xcomplex> } - // CHECK: llvm.mlir.global private @{{.*}}() {addr_space = 3 : i32} : !llvm.array<10 x struct<(f32, f32)>> // CHECK: @shared_complex // ----- -module { - func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) -> tensor<10xi4> { - %v = tensor.extract %arg[%i] : tensor<10xi4> - %r = tensor.insert %v into %arg[%j] : tensor<10xi4> - return %r : tensor<10xi4> - } +func.func @i4_load_store(%arg: tensor<10xi4>, %i: index, %j: index) + -> tensor<10xi4> { + %v = tensor.extract %arg[%i] : tensor<10xi4> + %r = tensor.insert %v into %arg[%j] : tensor<10xi4> + return %r : tensor<10xi4> } - // CHECK: @i4_load_store // CHECK: llvm.getelementptr // CHECK-SAME: -> !llvm.ptr, i8 @@ -401,16 +328,14 @@ module { // ----- -module { - func.func @direct_atomic_rmw_overwrite(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_overwrite(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_overwrite // CHECK: %[[C2:.*]] = arith.constant 2 @@ -419,17 +344,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_addi(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.addi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_addi(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.addi %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_addi // CHECK: %[[C2:.*]] = arith.constant 2 @@ -438,17 +361,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_maxsi(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.maxsi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_maxsi(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.maxsi %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_maxsi // CHECK: %[[C2:.*]] = arith.constant 2 @@ -457,17 +378,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_maxui(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.maxui %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_maxui(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.maxui %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_maxui // CHECK: %[[C2:.*]] = arith.constant 2 @@ -476,17 +395,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_minsi(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.minsi %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_minsi(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.minsi %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_minsi // CHECK: %[[C2:.*]] = arith.constant 2 @@ -495,17 +412,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_minui(%in: tensor<2x4xi32>, - %i: index, %j: index) -> (tensor<2x4xi32>) { - %c2 = arith.constant 2 : i32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xi32> { - ^bb0(%current : i32): - %min = arith.minui %current, %c2 : i32 - xla_gpu.yield %c2 : i32 - } - return %ret : tensor<2x4xi32> +func.func @direct_atomic_rmw_minui(%in: tensor<8xi32>, + %i: index) -> (tensor<8xi32>) { + %c2 = arith.constant 2 : i32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xi32> { + ^bb0(%current : i32): + %min = arith.minui %current, %c2 : i32 + xla_gpu.yield %c2 : i32 } + return %ret : tensor<8xi32> } // CHECK: @direct_atomic_rmw_minui // CHECK: %[[C2:.*]] = arith.constant 2 @@ -514,17 +429,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_fadd_f32(%in: tensor<2x4xf32>, - %i: index, %j: index) -> (tensor<2x4xf32>) { - %c2 = arith.constant 2.0 : f32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { - ^bb0(%current : f32): - %min = arith.addf %current, %c2 : f32 - xla_gpu.yield %c2 : f32 - } - return %ret : tensor<2x4xf32> +func.func @direct_atomic_rmw_fadd_f32(%in: tensor<8xf32>, + %i: index) -> (tensor<8xf32>) { + %c2 = arith.constant 2.0 : f32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + ^bb0(%current : f32): + %min = arith.addf %current, %c2 : f32 + xla_gpu.yield %c2 : f32 } + return %ret : tensor<8xf32> } // CHECK-LABEL: @direct_atomic_rmw_fadd_f32 // CHECK: %[[C2:.*]] = arith.constant 2 @@ -555,17 +468,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_fadd_f16(%in: tensor<2x4xf16>, - %i: index, %j: index) -> (tensor<2x4xf16>) { - %c2 = arith.constant 2.0 : f16 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf16> { - ^bb0(%current : f16): - %min = arith.addf %current, %c2 : f16 - xla_gpu.yield %c2 : f16 - } - return %ret : tensor<2x4xf16> +func.func @direct_atomic_rmw_fadd_f16(%in: tensor<8xf16>, + %i: index) -> (tensor<8xf16>) { + %c2 = arith.constant 2.0 : f16 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf16> { + ^bb0(%current : f16): + %min = arith.addf %current, %c2 : f16 + xla_gpu.yield %c2 : f16 } + return %ret : tensor<8xf16> } // CHECK-LABEL: @direct_atomic_rmw_fadd_f16 // CHECK-NOT: llvm.atomicrmw fadd @@ -591,17 +502,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<2x4xbf16>, - %i: index, %j: index) -> (tensor<2x4xbf16>) { - %c2 = arith.constant 2.0 : bf16 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xbf16> { - ^bb0(%current : bf16): - %min = arith.addf %current, %c2 : bf16 - xla_gpu.yield %c2 : bf16 - } - return %ret : tensor<2x4xbf16> +func.func @direct_atomic_rmw_fadd_bf16(%in: tensor<8xbf16>, + %i: index) -> (tensor<8xbf16>) { + %c2 = arith.constant 2.0 : bf16 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xbf16> { + ^bb0(%current : bf16): + %min = arith.addf %current, %c2 : bf16 + xla_gpu.yield %c2 : bf16 } + return %ret : tensor<8xbf16> } // CHECK-LABEL: @direct_atomic_rmw_fadd_bf16 // CHECK-NOT: llvm.atomicrmw fadd @@ -613,17 +522,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_fadd_f64(%in: tensor<2x4xf64>, - %i: index, %j: index) -> (tensor<2x4xf64>) { - %c2 = arith.constant 2.0 : f64 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf64> { - ^bb0(%current : f64): - %min = arith.addf %current, %c2 : f64 - xla_gpu.yield %c2 : f64 - } - return %ret : tensor<2x4xf64> +func.func @direct_atomic_rmw_fadd_f64(%in: tensor<8xf64>, + %i: index) -> (tensor<8xf64>) { + %c2 = arith.constant 2.0 : f64 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf64> { + ^bb0(%current : f64): + %min = arith.addf %current, %c2 : f64 + xla_gpu.yield %c2 : f64 } + return %ret : tensor<8xf64> } // CHECK-LABEL: @direct_atomic_rmw_fadd_f64 // CHECK: %[[C2:.*]] = arith.constant 2 @@ -648,17 +555,15 @@ module { // ----- -module { - func.func @direct_atomic_rmw_maximumf(%in: tensor<2x4xf32>, - %i: index, %j: index) -> (tensor<2x4xf32>) { - %c2 = arith.constant 2.0 : f32 - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xf32> { - ^bb0(%current : f32): - %min = arith.maximumf %current, %c2 : f32 - xla_gpu.yield %c2 : f32 - } - return %ret : tensor<2x4xf32> +func.func @direct_atomic_rmw_maximumf(%in: tensor<8xf32>, + %i: index) -> (tensor<8xf32>) { + %c2 = arith.constant 2.0 : f32 + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xf32> { + ^bb0(%current : f32): + %min = arith.maximumf %current, %c2 : f32 + xla_gpu.yield %c2 : f32 } + return %ret : tensor<8xf32> } // CHECK-LABEL: @direct_atomic_rmw_maximumf @@ -687,18 +592,15 @@ module { // ----- -module { - func.func @atomic_rmw_c32(%in: tensor<2x4xcomplex>, %i: index, %j: index) - -> (tensor<2x4xcomplex>) { - %ret = xla_gpu.atomic_rmw %in[%i, %j] : tensor<2x4xcomplex> { - ^bb0(%current : complex): - %a = complex.add %current, %current : complex - xla_gpu.yield %a : complex - } - return %ret : tensor<2x4xcomplex> +func.func @atomic_rmw_c32(%in: tensor<8xcomplex>, %i: index) + -> (tensor<8xcomplex>) { + %ret = xla_gpu.atomic_rmw %in[%i] : tensor<8xcomplex> { + ^bb0(%current : complex): + %a = complex.add %current, %current : complex + xla_gpu.yield %a : complex } + return %ret : tensor<8xcomplex> } - // CHECK-LABEL: @atomic_rmw_c32 // CHECK: scf.while (%[[ITER_ARG:.*]] = %{{.*}}) : (i64) -> i64 @@ -709,21 +611,18 @@ module { // ----- -module { - func.func @unused_index_switch_results(%i: index) -> index { - %ret, %ret2 = scf.index_switch %i -> tensor<2x4xi32>, tensor<3xf32> - case 0 { - %x, %y = "dummy.op1"() : () -> (tensor<2x4xi32>, tensor<3xf32>) - scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32> - } - default { - %x, %y = "dummy.op2"() : () -> (tensor<2x4xi32>, tensor<3xf32>) - scf.yield %x, %y : tensor<2x4xi32>, tensor<3xf32> - } - return %i : index +func.func @unused_index_switch_results(%i: index) -> index { + %ret, %ret2 = scf.index_switch %i -> tensor<8xi32>, tensor<3xf32> + case 0 { + %x, %y = "dummy.op1"() : () -> (tensor<8xi32>, tensor<3xf32>) + scf.yield %x, %y : tensor<8xi32>, tensor<3xf32> + } + default { + %x, %y = "dummy.op2"() : () -> (tensor<8xi32>, tensor<3xf32>) + scf.yield %x, %y : tensor<8xi32>, tensor<3xf32> } + return %i : index } - // CHECK-LABEL: func.func @unused_index_switch_results // CHECK-SAME: (%[[I:.*]]: index) // CHECK-NEXT: scf.index_switch %[[I]] @@ -738,17 +637,14 @@ module { // ----- -module { - func.func @transfer_write(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> { - %c16 = arith.constant 16 : index - %c22 = arith.constant 22 : index - %cst = arith.constant dense<[1.0, 2.0]> : vector<2xf32> - %out = vector.transfer_write %cst, %arg0[%c16] : vector<2xf32>, tensor<43xf32> - %out2 = vector.transfer_write %cst, %out[%c22] : vector<2xf32>, tensor<43xf32> - func.return %out2 : tensor<43xf32> - } +func.func @transfer_write(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> tensor<43xf32> { + %c16 = arith.constant 16 : index + %c22 = arith.constant 22 : index + %cst = arith.constant dense<[1.0, 2.0]> : vector<2xf32> + %out = vector.transfer_write %cst, %arg0[%c16] : vector<2xf32>, tensor<43xf32> + %out2 = vector.transfer_write %cst, %out[%c22] : vector<2xf32>, tensor<43xf32> + func.return %out2 : tensor<43xf32> } - // CHECK-LABEL: @transfer_write // CHECK: %[[PTR1:.*]] = llvm.getelementptr inbounds %[[BUF:.*]][16] // CHECK-NEXT: llvm.store %[[CST:.*]], %[[PTR1]] @@ -757,32 +653,26 @@ module { // ----- -module { - func.func @transfer_read(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> vector<2xf32> { - %c16 = arith.constant 16 : index - %c0 = arith.constant 0.0 : f32 - %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf32>, vector<2xf32> - func.return %out : vector<2xf32> - } +func.func @transfer_read(%arg0: tensor<43xf32> {xla.slice_index = 1}) -> vector<2xf32> { + %c16 = arith.constant 16 : index + %c0 = arith.constant 0.0 : f32 + %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xf32>, vector<2xf32> + func.return %out : vector<2xf32> } - // CHECK-LABEL: @transfer_read // CHECK: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16] // CHECK-NEXT: llvm.load %[[PTR]] : !llvm.ptr -> vector<2xf32> // ----- -module { - func.func @transfer_write_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}, - %v1: vector<2xi1>, %v2: vector<2xi1>) -> tensor<43xi1> { - %c16 = arith.constant 16 : index - %c22 = arith.constant 22 : index - %out = vector.transfer_write %v1, %arg0[%c16] : vector<2xi1>, tensor<43xi1> - %out2 = vector.transfer_write %v2, %out[%c22] : vector<2xi1>, tensor<43xi1> - func.return %out2 : tensor<43xi1> - } +func.func @transfer_write_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}, + %v1: vector<2xi1>, %v2: vector<2xi1>) -> tensor<43xi1> { + %c16 = arith.constant 16 : index + %c22 = arith.constant 22 : index + %out = vector.transfer_write %v1, %arg0[%c16] : vector<2xi1>, tensor<43xi1> + %out2 = vector.transfer_write %v2, %out[%c22] : vector<2xi1>, tensor<43xi1> + func.return %out2 : tensor<43xi1> } - // CHECK-LABEL: @transfer_write_i1 // CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr // CHECK-SAME: %[[V1:.*]]: vector<2xi1>, %[[V2:.*]]: vector<2xi1>) @@ -795,15 +685,12 @@ module { // ----- -module { - func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vector<2xi1> { - %c16 = arith.constant 16 : index - %false = arith.constant false - %out = vector.transfer_read %arg0[%c16], %false : tensor<43xi1>, vector<2xi1> - func.return %out : vector<2xi1> - } +func.func @transfer_read_i1(%arg0: tensor<43xi1> {xla.slice_index = 1}) -> vector<2xi1> { + %c16 = arith.constant 16 : index + %false = arith.constant false + %out = vector.transfer_read %arg0[%c16], %false : tensor<43xi1>, vector<2xi1> + func.return %out : vector<2xi1> } - // CHECK-LABEL: @transfer_read_i1 // CHECK-DAG: %[[C0:.*]] = arith.constant dense<0> : vector<2xi8> // CHECK-DAG: %[[PTR:.*]] = llvm.getelementptr inbounds %{{.*}}[16] @@ -811,44 +698,3 @@ module { // CHECK: %[[CAST:.*]] = arith.cmpi ne, %[[LOADED]], %[[C0]] // CHECK: return %[[CAST]] : vector<2xi1> -// ----- - -module { - func.func @transfer_write_i4(%arg0: tensor<43xi4> {xla.slice_index = 1}, - %v1: vector<4xi4>) -> tensor<43xi4> { - %c16 = arith.constant 16 : index - %out = vector.transfer_write %v1, %arg0[%c16] : vector<4xi4>, tensor<43xi4> - func.return %out : tensor<43xi4> - } -} - -// CHECK-LABEL: @transfer_write_i4 -// CHECK-SAME: , %[[V1:.*]]: vector<4xi4> -// CHECK-DAG: %[[A0:.*]] = vector.extract %[[V1]][0] -// CHECK-DAG: %[[A1:.*]] = vector.extract %[[V1]][1] -// CHECK-DAG: %[[A2:.*]] = vector.extract %[[V1]][2] -// CHECK-DAG: %[[A3:.*]] = vector.extract %[[V1]][3] -// CHECK-DAG: vector.insert %[[A0]], {{.*}}[1] -// CHECK-DAG: vector.insert %[[A1]], {{.*}}[0] -// CHECK-DAG: vector.insert %[[A2]], {{.*}}[3] -// CHECK-DAG: vector.insert %[[A3]], {{.*}}[2] - -module { - func.func @transfer_read_i4(%arg0: tensor<43xi4> {xla.slice_index = 1}) -> vector<4xi4> { - %c16 = arith.constant 16 : index - %c0 = arith.constant 0 : i4 - %out = vector.transfer_read %arg0[%c16], %c0 : tensor<43xi4>, vector<4xi4> - func.return %out : vector<4xi4> - } -} - -// CHECK-LABEL: @transfer_read_i4 -// CHECK: %[[LOADED:.*]] = llvm.load -// CHECK-DAG: %[[A0:.*]] = vector.extract %[[LOADED]][0] -// CHECK-DAG: %[[A1:.*]] = vector.extract %[[LOADED]][1] -// CHECK-DAG: %[[A2:.*]] = vector.extract %[[LOADED]][2] -// CHECK-DAG: %[[A3:.*]] = vector.extract %[[LOADED]][3] -// CHECK-DAG: vector.insert %[[A0]], {{.*}}[1] -// CHECK-DAG: vector.insert %[[A1]], {{.*}}[0] -// CHECK-DAG: vector.insert %[[A2]], {{.*}}[3] -// CHECK-DAG: vector.insert %[[A3]], {{.*}}[2] diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir new file mode 100644 index 00000000000000..f0de25a74b7f8c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_loops_to_scf.mlir @@ -0,0 +1,52 @@ +// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-loops-to-scf | FileCheck %s + +#map = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), + domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32], s0 + s1 in [0, 90]> +func.func @loop_op(%input: tensor<1024x32xf32>, %init: f32, %dim: index) -> (f32) { + %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) { + %t = tensor.extract %input[%i, %j] : tensor<1024x32xf32> + %add = arith.addf %sum_, %t : f32 + xla_gpu.yield %add : f32 + } {xla.range = [0 : index, 42 : index]} + func.return %sum : f32 +} + +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0 + s1), +// CHECK-SAME: domain: d0 in [0, 3], s0 in [0, 1024], s1 in [0, 32]> + +// CHECK-LABEL: func.func @loop_op( +// CHECK-SAME: %[[IN:.*]]: tensor<1024x32xf32>, +// CHECK-SAME: %[[INIT:.*]]: f32, %[[DIM:.*]]: index) -> f32 { + +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C33:.*]] = arith.constant 33 : index +// CHECK-DAG: %[[C90:.*]] = arith.constant 90 : index +// CHECK-DAG: %[[C1025:.*]] = arith.constant 1025 : index + +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C1025]] step %[[C1]] +// CHECK-SAME: iter_args(%[[INIT_:.*]] = %[[INIT]]) -> (f32) { + +// CHECK: %[[INNER_FOR:.*]] = scf.for %[[J:.*]] = %[[C0]] to %[[C33]] +// CHECK-SAME: step %[[C1]] iter_args(%[[INIT__:.*]] = %[[INIT_]]) -> (f32) { + +// CHECK: %[[INDEX:.*]] = xla_gpu.apply_indexing +// CHECK-SAME: #[[$MAP]](%[[DIM]])[%[[I]], %[[J]]] +// CHECK: %[[VAL1:.*]] = arith.cmpi sge, %[[INDEX]], %[[C0]] : index +// CHECK: %[[VAL2:.*]] = arith.cmpi sle, %[[INDEX]], %[[C90]] : index +// CHECK: %[[VAL3:.*]] = arith.andi %[[VAL1]], %[[VAL2]] : i1 +// CHECK: %[[VAL4:.*]] = arith.cmpi sge, %[[DIM]], %[[C0]] : index +// CHECK: %[[VAL5:.*]] = arith.cmpi sle, %[[DIM]], %[[C3]] : index +// CHECK: %[[VAL6:.*]] = arith.andi %[[VAL4]], %[[VAL5]] : i1 +// CHECK: %[[INBOUNDS:.*]] = arith.andi %[[VAL3]], %[[VAL6]] : i1 +// CHECK: %[[IF_RESULT:.*]] = scf.if %[[INBOUNDS]] -> (f32) { +// CHECK: %[[ELEM:.*]] = tensor.extract %[[IN]][%[[I]], %[[J]]] +// CHECK: %[[SUM:.*]] = arith.addf %[[INIT__]], %[[ELEM]] : f32 +// CHECK: scf.yield %[[SUM]] : f32 +// CHECK: } else { +// CHECK: scf.yield %[[INIT__]] : f32 +// CHECK: } +// CHECK: scf.yield %[[IF_RESULT]] : f32 +// CHECK: } +// CHECK: scf.yield %[[INNER_FOR]] : f32 diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir similarity index 57% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir index 645430ae0d1bcc..2f9494aeb4d6e6 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/lower_xla_gpu_to_scf.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/lower_xla_gpu_to_scf.mlir @@ -1,16 +1,14 @@ -// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf | FileCheck %s +// RUN: mlir_fusions_opt %s -xla-gpu-lower-xla-gpu-to-scf --split-input-file \ +// RUN: | FileCheck %s -module { - func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) { - return %a, %b : f32, i32 - } - - func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { - %ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32 - return %ret#0, %ret#1 : f32, i32 - } +func.func @reducer(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) { + return %a, %b : f32, i32 } +func.func @shuffler(%a: f32, %b: i32) -> (f32, i32) { + %ret:2 = xla_gpu.shuffle_reduce @reducer(%a, %b) to 4 : f32, i32 + return %ret#0, %ret#1 : f32, i32 +} // CHECK: @shuffler(%[[A:.*]]: f32, %[[B:.*]]: i32) // CHECK-DAG: %[[C1:.*]] = arith.constant 1 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 @@ -29,83 +27,68 @@ module { // ----- -module { - func.func @reducer(%a: f64, %b: f64) -> f64 { - return %a : f64 - } - - func.func @shuffler(%a: f64) -> f64 { - %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64 - return %ret : f64 - } +func.func @reducer(%a: f64, %b: f64) -> f64 { + return %a : f64 } +func.func @shuffler(%a: f64) -> f64 { + %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : f64 + return %ret : f64 +} // CHECK: @shuffler(%[[A:.*]]: f64 // CHECK: gpu.shuffle down {{.*}}, %[[C1]] // CHECK: gpu.shuffle down {{.*}}, %[[C1]] // ----- -module { - func.func @reducer(%a: complex, %b: complex) -> complex { - return %a : complex - } - - func.func @shuffler(%a: complex) -> complex { - %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex - return %ret : complex - } +func.func @reducer(%a: complex, %b: complex) -> complex { + return %a : complex } +func.func @shuffler(%a: complex) -> complex { + %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : complex + return %ret : complex +} // CHECK: @shuffler // CHECK-COUNT-4: gpu.shuffle down {{.*}}, %[[C1]] // ----- -module { - func.func @reducer(%a: ui64, %b: ui64) -> ui64 { - return %a : ui64 - } - - func.func @shuffler(%a: ui64) -> ui64 { - %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64 - return %ret : ui64 - } +func.func @reducer(%a: ui64, %b: ui64) -> ui64 { + return %a : ui64 } +func.func @shuffler(%a: ui64) -> ui64 { + %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : ui64 + return %ret : ui64 +} // CHECK: @shuffler // CHECK: unrealized_conversion_cast // CHECK-COUNT-2: gpu.shuffle down {{.*}}, %[[C1]] // ----- -module { - func.func @reducer(%a: i8, %b: i8) -> i8 { - return %a : i8 - } - - func.func @shuffler_i8(%a: i8) -> i8 { - %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8 - return %ret : i8 - } +func.func @reducer(%a: i8, %b: i8) -> i8 { + return %a : i8 } +func.func @shuffler_i8(%a: i8) -> i8 { + %ret = xla_gpu.shuffle_reduce @reducer(%a) to 1 : i8 + return %ret : i8 +} // CHECK: @shuffler_i8( // CHECK-NOT: vector // CHECK-COUNT-1: gpu.shuffle down {{.*}}, %[[C1]] // ----- -module { - func.func @predicated_insert( - %v: i32, %tensor: tensor<2xi32>, %index: index, - %cond: i1) -> tensor<2xi32> { - %ret = xla_gpu.predicated_insert %v into %tensor[%index] if %cond - : tensor<2xi32> - return %ret : tensor<2xi32> - } +func.func @predicated_insert( + %v: i32, %tensor: tensor<2xi32>, %index: index, + %cond: i1) -> tensor<2xi32> { + %ret = xla_gpu.predicated_insert %v into %tensor[%index] if %cond + : tensor<2xi32> + return %ret : tensor<2xi32> } - // CHECK: @predicated_insert // CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>, // CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1 @@ -119,16 +102,13 @@ module { // ----- -module { - func.func @predicated_extract( - %v: i32, %tensor: tensor<2xi32>, %index: index, - %cond: i1) -> i32 { - %ret = xla_gpu.predicated_extract %tensor[%index] if %cond else %v - : tensor<2xi32> - return %ret : i32 - } +func.func @predicated_extract( + %v: i32, %tensor: tensor<2xi32>, %index: index, + %cond: i1) -> i32 { + %ret = xla_gpu.predicated_extract %tensor[%index] if %cond else %v + : tensor<2xi32> + return %ret : i32 } - // CHECK: @predicated_extract // CHECK-SAME: %[[V:.*]]: i32, %[[TENSOR:.*]]: tensor<2xi32>, // CHECK-SAME: %[[INDEX:.*]]: index, %[[COND:.*]]: i1 diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/merge_pointers_to_same_slice.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/merge_pointers_to_same_slice.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir similarity index 77% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir index 6f903f3ace4748..cb6c0486aa8a55 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/optimize_loops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/optimize_loops.mlir @@ -1,8 +1,11 @@ // RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-optimize-loops | FileCheck %s -#map = affine_map<(d0) -> (d0 floordiv 8)> -#map1 = affine_map<(d0) -> (d0 mod 8)> -#map2 = affine_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512)> +#map = #xla_gpu.indexing_map<(d0) -> (d0 floordiv 8), + domain: d0 in [0, 31]> +#map1 = #xla_gpu.indexing_map<(d0) -> (d0 mod 8), + domain: d0 in [0, 31]> +#map2 = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), + domain: d0 in [0, 1], d1 in [0, 255], s0 in [0, 7]> module { func.func @fully_unroll(%arg0: tensor<4x8x4096xf32>, %arg1: tensor<4096xbf16>, %arg2: tensor<4x8xf32>, %arg3: tensor<4096xbf16>, @@ -21,23 +24,23 @@ module { %1 = arith.cmpi eq, %0, %c0 : index %2 = arith.divui %thread_id_x, %c32 : index %3 = arith.cmpi ult, %thread_id_x, %c8 : index - %4 = xla_gpu.apply_indexing #map(%block_id_x in [0, 31]) - %5 = xla_gpu.apply_indexing #map1(%block_id_x in [0, 31]) + %4 = xla_gpu.apply_indexing #map(%block_id_x) + %5 = xla_gpu.apply_indexing #map1(%block_id_x) %extracted = tensor.extract %arg2[%4, %5] : tensor<4x8xf32> %6 = arith.mulf %extracted, %cst : f32 %7 = arith.addf %6, %cst : f32 %8 = math.rsqrt %7 : f32 %9:2 = scf.for %arg7 = %c0 to %c8 step %c1 iter_args(%arg8 = %arg6, %arg9 = %cst) -> (tensor<4x8x4096xf32>, f32) { - %18 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %18 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %19 = vector.transfer_read %arg1[%18], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %20 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %20 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %21 = vector.transfer_read %arg3[%20], %cst_1 {in_bounds = [true]} : tensor<4096xbf16>, vector<2xbf16> - %22 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %22 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %23 = vector.transfer_read %arg4[%4, %5, %22], %cst_1 {in_bounds = [true]} : tensor<4x8x4096xbf16>, vector<2xbf16> - %24 = xla_gpu.apply_indexing #map2(%c0 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %24 = xla_gpu.apply_indexing #map2(%c0, %thread_id_x)[%arg7] %25 = vector.transfer_read %arg0[%4, %5, %24], %cst {in_bounds = [true]} : tensor<4x8x4096xf32>, vector<2xf32> %26:2 = scf.for %arg10 = %c0 to %c2 step %c1 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<4x8x4096xf32>, f32) { - %27 = xla_gpu.apply_indexing #map2(%arg10 in [0, 1], %thread_id_x in [0, 255])[%arg7 in [0, 7]] + %27 = xla_gpu.apply_indexing #map2(%arg10, %thread_id_x)[%arg7] %28 = vector.extract %25[%arg10] : f32 from vector<2xf32> %29 = vector.extract %23[%arg10] : bf16 from vector<2xbf16> %30 = arith.extf %29 : bf16 to f32 @@ -124,7 +127,7 @@ module { } } -// CHECK: #[[$MAP:.*]] = affine_map<(d0) -> (d0 + 1)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_extract // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C30:.*]] = arith.constant 30 : index @@ -151,7 +154,7 @@ module { %cst = arith.constant dense<[0.0, 0.0]> : vector<2xf32> %cst0 = arith.constant 0.0 : f32 %ret = scf.for %i = %c0 to %c17 step %c1 iter_args (%iter = %cst) -> (vector<2xf32>) { - %base = xla_gpu.apply_indexing affine_map<(d0) -> (d0 * 2)>(%i in [0, 15]) + %base = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 15]>(%i) %val = vector.transfer_read %arg[%base], %cst0 : tensor<34xf32>, vector<2xf32> %log = math.log %val : vector<2xf32> %add = arith.addf %log, %iter : vector<2xf32> @@ -161,8 +164,8 @@ module { } } -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 2)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 + 1)> +// CHECK-DAG: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 + 1), // CHECK-LABEL: @pipeline_transfer // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index @@ -180,3 +183,29 @@ module { // CHECK: math.log %[[VAL]] // CHECK: %[[ADD:.*]] = arith.addf // CHECK: yield %[[ADD]], %[[NEXT_VAL]] + +// ----- + +module { + func.func @sequential_extract(%arg0: tensor<6xindex>, %arg1: tensor<22xindex>) -> (index) { + %c1 = arith.constant 1 : index + %c733 = arith.constant 733 : index + %c0 = arith.constant 0 : index + %2 = scf.for %i = %c0 to %c733 step %c1 iter_args(%x = %c1) -> (index) { + %extracted = tensor.extract %arg0[%i] : tensor<6xindex> + %extracted_1 = tensor.extract %arg1[%extracted] : tensor<22xindex> + scf.yield %extracted_1 : index + } + return %2 : index + } +} + +// Once `extracted` is pipelined, it becomes an iter arg, so `extracted_1` is +// extract %arg1[%arg]. While it is possible to pipeline this in principle, we +// do not currently do this. + +// CHECK-LABEL: @sequential_extract +// CHECK-SAME: (%[[ARG0:.*]]: tensor<6xindex>, %[[ARG1:.*]]: tensor<22xindex>) +// CHECK: tensor.extract %[[ARG0]] +// CHECK-NOT: tensor.extract +// CHECK: scf.for diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir new file mode 100644 index 00000000000000..20442544b890db --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/peel_loops.mlir @@ -0,0 +1,88 @@ +// RUN: mlir_fusions_opt -split-input-file %s -xla-gpu-peel-loops \ +// RUN: | FileCheck %s + +#map = #xla_gpu.indexing_map< + (d0)[s0, s1] -> (s0, s1), + domain: + d0 in [0, 3], + s0 in [0, 7], + s1 in [0, 10], + d0 + s0 in [0, 9], + d0 + s1 in [0, 12] +> +func.func @peel_both_loops(%input: tensor<16x32xf32>, + %init: f32, %dim: index) -> (f32) { + %sum = xla_gpu.loop (%dim)[%i, %j] in #map iter_args(%sum_ = %init) -> (f32) { + %t = tensor.extract %input[%i, %j] : tensor<16x32xf32> + %add = arith.addf %sum_, %t : f32 + xla_gpu.yield %add : f32 + } + func.return %sum : f32 +} +// CHECK: #[[$PEELED_MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, s1), domain: d0 in [0, 3], s0 in [0, 6], s1 in [0, 9]> +// CHECK: #[[$TAIL_MAP0:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (7, s1), domain: d0 in [0, 2], s0 in [7, 7], s1 in [0, 9]> +// CHECK: #[[$TAIL_MAP1:.*]] = #xla_gpu.indexing_map<(d0)[s0, s1] -> (s0, 10), domain: d0 in [0, 2], s0 in [0, 7], s1 in [10, 10]> + +// CHECK-LABEL: func.func @peel_both_loops( +// CHECK-SAME: %[[INPUT:.*]]: tensor<16x32xf32>, +// CHECK-SAME: %[[INIT:.*]]: f32, %[[DIM:.*]]: index) + +// CHECK: %[[PEELED:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] +// CHECK-SAME: in #[[$PEELED_MAP]] iter_args(%[[INIT_:.*]] = %[[INIT]]) +// CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]] : tensor<16x32xf32> +// CHECK: arith.addf %[[INIT_]] + +// CHECK: %[[TAIL0:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] +// CHECK-SAME: in #[[$TAIL_MAP0]] iter_args(%[[INIT_:.*]] = %[[PEELED]]) +// CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]] +// CHECK: arith.addf %[[INIT_]] + +// CHECK: %[[TAIL1:.*]] = xla_gpu.loop (%[[DIM]])[%[[I:.*]], %[[J:.*]]] +// CHECK-SAME: in #[[$TAIL_MAP1]] iter_args(%[[INIT_:.*]] = %[[TAIL0]]) +// CHECK: tensor.extract %[[INPUT]][%[[I]], %[[J]]] +// CHECK: arith.addf %[[INIT_]] + +// CHECK: return %[[TAIL1]] : f32 + +// ----- + +#map = #xla_gpu.indexing_map< + (d0)[s0] -> (s0), + domain: + d0 in [0, 3], + s0 in [0, 7] +> +func.func @not_constrained_symbol(%input: tensor<16xf32>, %init: f32, + %dim: index) -> (f32) { + %sum = xla_gpu.loop (%dim)[%i] in #map iter_args(%sum_ = %init) -> (f32) { + %t = tensor.extract %input[%i] : tensor<16xf32> + %add = arith.addf %sum_, %t : f32 + xla_gpu.yield %add : f32 + } + func.return %sum : f32 +} +// CHECK-LABEL: func.func @not_constrained_symbol +// CHECK: xla_gpu.loop +// CHECK-NOT: xla_gpu.loop + +// ----- + +#map = #xla_gpu.indexing_map< + (d0)[s0] -> (s0), + domain: + d0 in [0, 3], + s0 in [0, 7], + s0 mod 5 in [0, 1] +> +func.func @constraint_exists_after_peeling(%input: tensor<16xf32>, %init: f32, + %dim: index) -> (f32) { + %sum = xla_gpu.loop (%dim)[%i] in #map iter_args(%sum_ = %init) -> (f32) { + %t = tensor.extract %input[%i] : tensor<16xf32> + %add = arith.addf %sum_, %t : f32 + xla_gpu.yield %add : f32 + } + func.return %sum : f32 +} +// CHECK-LABEL: func.func @constraint_exists_after_peeling +// CHECK: xla_gpu.loop +// CHECK-NOT: xla_gpu.loop \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/propagate_slice_indices.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/propagate_slice_indices.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir similarity index 87% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir index ec1a726da9db13..d51566a5b3dace 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_affine.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_affine.mlir @@ -62,8 +62,9 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt %0 = gpu.thread_id x %1 = gpu.block_id x scf.for %i = %c0 to %c4 step %c1 { - %2 = xla_gpu.apply_indexing affine_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4))> - [%1 in [0, 3071], %0 in [0, 127], %i in [0, 3]] + %2 = xla_gpu.apply_indexing + #xla_gpu.indexing_map<()[s0, s1, s2] -> (s0 * 512 + s1 * 4 + s2 + (s1 floordiv 128) + (s2 floordiv 4)), + domain: s0 in [0, 3071], s1 in [0, 127], s2 in [0, 3]>[%1, %0, %i] %3 = arith.index_castui %2 : index to i64 %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i64) -> !llvm.ptr, f32 %5 = llvm.load %4 invariant : !llvm.ptr -> f32 @@ -91,8 +92,8 @@ func.func @op_and_for_ranges(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.pt func.func @arg_ranges(%arg0: index, %arg1: index) -> index { %0 = xla_gpu.apply_indexing - affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100)> - [%arg0 in [0, 42], %arg1 in [0, 1000]] + #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100), + domain: s0 in [0, 42], s1 in [0, 1000]>[%arg0, %arg1] return %0 : index } @@ -105,8 +106,8 @@ func.func @arg_ranges(%arg0: index, %arg1: index) -> index { func.func @cant_lower(%arg0: index, %arg1: index) -> (index, index) { %0:2 = xla_gpu.apply_indexing - affine_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1)> - [%arg0 in [-10, 42], %arg1 in [0, 1000]] + #xla_gpu.indexing_map<()[s0, s1] -> (s0 floordiv 100 + s1 floordiv 100, s0 + s1), + domain: s0 in [-10, 42], s1 in [0, 1000]>[%arg0, %arg1] return %0#0, %0#1 : index, index } @@ -123,8 +124,8 @@ func.func @order_summands(%arg1: index) { scf.for %arg2 = %c0 to %c4 step %c1 { scf.for %arg3 = %c0 to %c4 step %c1 { %0 = xla_gpu.apply_indexing - affine_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10)> - [%arg2 in [0, 3], %arg1 in [0, 3], %arg3 in [0, 3]] + #xla_gpu.indexing_map<()[s0, s1, s2] -> ((s0 + s1) floordiv 3 + s0 * 512 + s1 * 4 + s2 * 10), + domain: s0 in [0, 3], s1 in [0, 3], s2 in [0, 3]>[%arg2, %arg1, %arg3] "dummy.op"(%0) : (index) -> () } } diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir similarity index 92% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir index ee2e0ddbe29035..09c8901fab6000 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/tests/simplify_arith.mlir +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/simplify_arith.mlir @@ -247,7 +247,7 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { %c42_f32 = arith.constant 42.0 : f32 %loop = scf.for %i = %c0 to %c3 step %c1 iter_args(%in_ = %tensor) -> (tensor<100xf32>) { - %0 = xla_gpu.apply_indexing affine_map<(d0) -> (d0 mod 4)> (%i in [0, 9]) + %0 = xla_gpu.apply_indexing #xla_gpu.indexing_map<(d0) -> (d0 mod 4), domain: d0 in [0, 9]>(%i) %updated = tensor.insert %c42_f32 into %in_[%0] : tensor<100xf32> scf.yield %updated :tensor<100xf32> } @@ -261,8 +261,10 @@ func.func @refine_constraints(%tensor: tensor<100xf32>) -> tensor<100xf32> { // ----- -#map = affine_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000)> -#map1 = affine_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9)> +#map = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> (((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768) mod 2400000), + domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 73], s1 in [0, 3]> +#map1 = #xla_gpu.indexing_map<(d0, d1)[s0] -> ((d0 * 4 + d1 * 512 + s0) mod 9), + domain: d0 in [0, 127], d1 in [0, 575], s0 in [0, 3]> func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, %arg1: tensor<2400000x9xf32>) -> tensor<2400000x9xf32> { %c0 = arith.constant 0 : index @@ -276,10 +278,8 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, -> (tensor<2400000x9xf32>) { %2 = scf.for %j = %c0 to %c4 step %c1 iter_args(%arg5 = %arg3) -> (tensor<2400000x9xf32>) { - %3 = xla_gpu.apply_indexing #map(%th_x in [0, 127], %bl_x in [0, 575]) - [%i in [0, 73], %j in [0, 3]] - %4 = xla_gpu.apply_indexing #map1(%th_x in [0, 127], %bl_x in [0, 575]) - [%j in [0, 3]] + %3 = xla_gpu.apply_indexing #map(%th_x, %bl_x)[%i, %j] + %4 = xla_gpu.apply_indexing #map1(%th_x, %bl_x)[%j] %inserted = tensor.insert %c42_f32 into %arg5[%3, %4] : tensor<2400000x9xf32> scf.yield %inserted : tensor<2400000x9xf32> @@ -288,5 +288,5 @@ func.func @refine_constraints_for_symbol(%arg0: tensor<2400000x9xf32>, } return %0 : tensor<2400000x9xf32> } -// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768)> +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0, s1] -> ((d0 * 4 + d1 * 512 + s1) floordiv 9 + s0 * 32768), // CHECK-LABEL: func.func @refine_constraints_for_symbol diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/tests/unswitch_loops.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir similarity index 100% rename from third_party/xla/xla/service/gpu/fusions/mlir/tests/unswitch_loops.mlir rename to third_party/xla/xla/service/gpu/fusions/transforms/tests/unswitch_loops.mlir diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir new file mode 100644 index 00000000000000..09769fc382bd58 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/transforms/tests/vectorize_loads_stores.mlir @@ -0,0 +1,421 @@ +// RUN: mlir_fusions_opt -allow-unregistered-dialect %s -split-input-file \ +// RUN: -xla-gpu-vectorize-loads-stores -canonicalize -cse \ +// RUN: | FileCheck %s +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @simple_read(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2), domain: d0 in [0, 63]> +// CHECK-LABEL: @simple_read +// CHECK-SAME: (%[[ARG0:.*]]: tensor +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[ITER:.*]] = +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]]) +// CHECK-NEXT: %[[V:.*]] = vector.transfer_read %[[ARG0]][%[[BASE]]] +// CHECK-NEXT: scf.for %[[J:.*]] = %[[C0]] +// CHECK-NEXT: vector.extract %[[V]][%[[J]]] +// CHECK-NEXT: addf + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 + 1), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @misaligned_indexing_map(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK-LABEL: @misaligned_indexing_map +// CHECK-NOT: vector.transfer_read + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 3 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @misaligned_indexing_map_2(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK-LABEL: @misaligned_indexing_map_2 +// CHECK-NOT: vector.transfer_read + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0] -> (3 * d0 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @misaligned_shape(%arg0: tensor<192xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<192xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK-LABEL: @misaligned_shape +// CHECK-NOT: vector.transfer_read + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0] -> (d0 + s0 * 2), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @wrong_stride(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK-LABEL: @wrong_stride +// CHECK-NOT: vector.transfer_read + +// ----- + +// We could vectorize this as a float vector load of double the size, but we +// don't currently. +#map = #xla_gpu.indexing_map<(d0)[s0] -> (2 * d0 + s0), + domain: d0 in [0, 127], s0 in [0, 1]> +func.func @simple_read_complex(%arg0: tensor<128xcomplex>, %i: index) -> (complex) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst = complex.constant [0.0 : f32, 0.0 : f32] : complex + %loop = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter = %cst) -> complex { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xcomplex> + %added = complex.add %iter, %extracted : complex + scf.yield %added : complex + } + return %loop : complex +} + +// CHECK-LABEL: @simple_read_complex +// CHECK-NOT: vector.transfer_read + +// ----- + +// This is vectorizable, but not currently supported. +func.func @layout(%arg0: tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c64 = arith.constant 64 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c64 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %extracted = tensor.extract %arg0[%j, %i] + : tensor<2x64xf32, dense<[0, 1]> : tensor<2xi64>> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK-LABEL: @layout +// CHECK-NOT: vector.transfer_read + +// ----- + +func.func @simple_write(%arg0: tensor<64xf32>) -> tensor<64xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { + %inserted = tensor.insert %cst into %iter[%j] : tensor<64xf32> + scf.yield %inserted : tensor<64xf32> + } + return %loop : tensor<64xf32> +} +// CHECK-LABEL: @simple_write +// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}) +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[V:.*]] = scf.for +// CHECK-NEXT: vector.insert +// CHECK-NEXT: scf.yield +// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[V]], %[[ARG0]][%[[C0]]] +// CHECK-NEXT: return %[[WRITTEN]] + +// ----- + +func.func @write_with_use(%arg0: tensor<64xf32>) -> tensor<64xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { + %inserted = tensor.insert %cst into %iter[%j] : tensor<64xf32> + "dummy.op1"(%inserted) : (tensor<64xf32>) -> () + scf.yield %inserted : tensor<64xf32> + } + return %loop : tensor<64xf32> +} +// CHECK-LABEL: @write_with_use +// CHECK-NOT: transfer_write + +// ----- + + func.func @write_not_to_iter_arg(%arg0: tensor<64xf32>) -> tensor<64xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { + %inserted = tensor.insert %cst into %arg0[%j] : tensor<64xf32> + scf.yield %inserted : tensor<64xf32> + } + return %loop : tensor<64xf32> + } + +// CHECK-LABEL: @write_not_to_iter_arg +// CHECK-NOT: transfer_write + +// ----- + +func.func @write_not_yielded(%arg0: tensor<64xf32>) -> tensor<64xf32> { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 2 : index + %cst = arith.constant 0.0 : f32 + %loop = scf.for %j = %c0 to %c4 step %c1 iter_args(%iter = %arg0) -> tensor<64xf32> { + %inserted = tensor.insert %cst into %arg0[%j] : tensor<64xf32> + scf.yield %arg0 : tensor<64xf32> + } + return %loop : tensor<64xf32> +} +// CHECK-LABEL: @write_not_yielded +// CHECK-NOT: transfer_write + +// ----- + +#map = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d1 * 2 + d0 + s0 * 512), + domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7]> +#map1 = #xla_gpu.indexing_map<(d0, d1, d2)[s0] -> (d0 * 32 + d2 * 2 + d1 + s0 * 512), + domain: d0 in [0, 7], d1 in [0, 1], d2 in [0, 255], s0 in [0, 7]> +func.func @multiple(%arg0: tensor<131072xf32>, %arg1: tensor<4096xbf16>, + %arg2: tensor<32xf32>, %arg3: tensor<131072xf32>, + %arg4: index) -> (tensor<131072xf32>, f32) { + %cst = arith.constant 1.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c8 = arith.constant 8 : index + %extracted1 = tensor.extract %arg2[%arg4] : tensor<32xf32> + %0:2 = scf.for %i = %c0 to %c8 step %c1 iter_args(%iter0 = %arg3, %iter1 = %cst) -> (tensor<131072xf32>, f32) { + %1:2 = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter2 = %iter0, %iter3 = %iter1) -> (tensor<131072xf32>, f32) { + %2 = xla_gpu.apply_indexing #map(%j, %arg4)[%i] + %idx = xla_gpu.apply_indexing #map1(%i, %j, %arg4)[%i] + %extracted2 = tensor.extract %arg0[%idx] : tensor<131072xf32> + %extracted3 = tensor.extract %arg1[%2] : tensor<4096xbf16> + %3 = arith.extf %extracted3 : bf16 to f32 + %4 = arith.addf %extracted2, %3 : f32 + %5 = arith.addf %extracted1, %4 : f32 + %6 = arith.addf %iter3, %5 : f32 + %inserted = tensor.insert %5 into %iter2[%idx] : tensor<131072xf32> + scf.yield %inserted, %6 : tensor<131072xf32>, f32 + } + scf.yield %1#0, %1#1 : tensor<131072xf32>, f32 + } + return %0#0, %0#1 : tensor<131072xf32>, f32 +} +// CHECK-DAG: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0 * 512), domain: d0 in [0, 255], s0 in [0, 7]> +// CHECK-DAG: #[[$MAP1:.*]] = #xla_gpu.indexing_map<(d0, d1)[s0] -> (d0 * 32 + d1 * 2 + s0 * 512), domain: d0 in [0, 7], d1 in [0, 255], s0 in [0, 7]> +// CHECK-LABEL: @multiple +// CHECK-SAME: (%[[ARG0:.*]]: tensor{{.*}}, %[[ARG1:.*]]: tensor{{.*}}, %[[ARG2:.*]]: tensor{{.*}}, %[[ARG3:.*]]: tensor{{.*}}, %[[ARG4:.*]]: index) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] +// CHECK-DAG: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[ARG4]])[%[[I]]] +// CHECK-DAG: %[[IDX:.*]] = xla_gpu.apply_indexing #[[$MAP1]](%[[I]], %[[ARG4]])[%[[I]]] +// CHECK: %[[READ1:.*]] = vector.transfer_read %[[ARG1]][%[[BASE]]] +// CHECK: %[[READ2:.*]] = vector.transfer_read %[[ARG0]][%[[IDX]]] +// CHECK: %[[INNER:.*]]:2 = scf.for %[[J:.*]] = %[[C0]] {{.*}} iter_args(%[[F:.*]] = {{.*}}, %[[V:.*]] = {{.*}}) -> (f32, vector<2xf32>) +// CHECK-DAG: vector.extract %[[READ1]][%[[J]]] +// CHECK-DAG: vector.extract %[[READ2]][%[[J]]] +// CHECK: extf +// CHECK-NEXT: addf +// CHECK-NEXT: %[[TO_INSERT:.*]] = arith.addf +// CHECK-NEXT: %[[TO_YIELD:.*]] = arith.addf +// CHECK-NEXT: %[[V_NEXT:.*]] = vector.insert %[[TO_INSERT]], %[[V]] [%[[J]]] +// CHECK-NEXT: scf.yield %[[TO_YIELD]], %[[V_NEXT]] +// CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[INNER]]#1, %{{.*}}[%[[IDX]]] +// CHECK: scf.yield %[[WRITTEN]], %[[INNER]]#0 + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 64 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @remainder_with_modulo(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0) -> ((d0 mod 16) * 4), +// CHECK-LABEL: @remainder_with_modulo +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: scf.for %[[I:.*]] = %[[C0]] +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP]](%[[I]] +// CHECK: vector.transfer_read {{.*}}[%[[BASE]]] + +// ----- + +#map = #xla_gpu.indexing_map<(d0)[s0] -> ((d0 * 4) mod 65 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +func.func @remainder_with_modulo_misaligned(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map(%i)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 +} +// CHECK-LABEL: @remainder_with_modulo_misaligned +// CHECK-NOT: vector.transfer_read + +// ----- + +#map0 = #xla_gpu.indexing_map<(d0) -> (d0 + 5), + domain: d0 in [0, 63]> +#map1 = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +module { + func.func @apply_indexing_sequence(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %offset = xla_gpu.apply_indexing #map0(%i) + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + %idx = xla_gpu.apply_indexing #map1(%offset)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 + } +} + +// CHECK: #[[$MAP0:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 2 + 10), +// CHECK-SAME: domain: d0 in [0, 63]> +// CHECK-LABEL: @apply_indexing_sequence +// CHECK: %[[BASE:.*]] = xla_gpu.apply_indexing #[[$MAP0]] +// CHECK: vector.transfer_read {{.*}}[%[[BASE]]] + +// ----- + + +#map0 = #xla_gpu.indexing_map<(d0) -> (d0 + 5), + domain: d0 in [0, 63]> +#map1 = #xla_gpu.indexing_map<(d0)[s0] -> (d0 * 2 + s0), + domain: d0 in [0, 63], s0 in [0, 1]> +module { + func.func @apply_indexing_sequence_same_block(%arg0: tensor<128xf32>) -> (f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c63 = arith.constant 63 : index + %cst = arith.constant 0.0 : f32 + %outer = scf.for %i = %c0 to %c63 step %c1 iter_args(%iter = %cst) -> f32 { + %inner = scf.for %j = %c0 to %c2 step %c1 iter_args(%iter1 = %iter) -> f32 { + // Usually, this will be hoisted by LICM or folded, so we do not detect + // this pattern. + %offset = xla_gpu.apply_indexing #map0(%i) + %idx = xla_gpu.apply_indexing #map1(%offset)[%j] + %extracted = tensor.extract %arg0[%idx] : tensor<128xf32> + %added = arith.addf %iter1, %extracted : f32 + scf.yield %added : f32 + } + scf.yield %inner : f32 + } + return %outer : f32 + } +} + +// CHECK-LABEL: @apply_indexing_sequence_same_block +// CHECK-NOT: vector.transfer_read diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc b/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc index 7d963f31292b4d..d514a678624162 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/unswitch_loops.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/unswitch_loops.cc @@ -30,7 +30,7 @@ namespace xla { namespace gpu { #define GEN_PASS_DEF_UNSWITCHLOOPSPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" namespace { diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc similarity index 87% rename from third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc rename to third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc index 00079845867fb0..9795a96e387f53 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/vectorize_loads_stores.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/vectorize_loads_stores.cc @@ -40,15 +40,16 @@ limitations under the License. #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" namespace xla { namespace gpu { +namespace { #define GEN_PASS_DEF_VECTORIZELOADSANDSTORESPASS -#include "xla/service/gpu/fusions/mlir/passes.h.inc" +#include "xla/service/gpu/fusions/transforms/passes.h.inc" -namespace { +using mlir::Value; // Tries to find the stride of a symbol or dimension in an affine expression. // Returns std::nullopt if the stride could not be determined. @@ -120,12 +121,15 @@ int64_t GetAlignmentOfRemainder(mlir::AffineExpr expr, // - checks that the upper bound is 2 or 4. // Returns a vector type with the given upper bound and the tensor's element // type. +// All tensors are 1D after flatten-tensors pass. mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type, mlir::scf::ForOp loop) { - // TODO(jreiffers): Support layouts. if (tensor_type.getEncoding()) { return nullptr; } + if (tensor_type.getRank() != 1) { + return nullptr; + } if (!mlir::VectorType::isValidElementType(tensor_type.getElementType())) { return nullptr; } @@ -138,37 +142,22 @@ mlir::VectorType GetVectorType(mlir::RankedTensorType tensor_type, if (vector_size != 2 && vector_size != 4) { return nullptr; // Unsupported vector size. } - if (tensor_type.getRank() > 1 && - tensor_type.getShape().back() % *vector_size) { + if (tensor_type.getShape().back() % *vector_size) { return nullptr; // Misaligned start indices. } return mlir::VectorType::get({*vector_size}, tensor_type.getElementType()); } -std::optional> GetVectorBaseIndices( - mlir::ValueRange indices, mlir::scf::ForOp loop, - mlir::VectorType vector_type, mlir::ImplicitLocOpBuilder& b) { - if (indices.empty()) { - return std::nullopt; - } - - // The major dimensions' indices must all be defined outside the loop. - for (int i = 0; i < indices.size() - 1; ++i) { - if (!indices[i].getParentRegion()->isProperAncestor( - &loop.getBodyRegion())) { - return std::nullopt; - } - } - - mlir::Value induction_var = loop.getInductionVar(); - if (indices.back() == induction_var) { - llvm::SmallVector ret = indices; - ret.back() = b.create(0); - return ret; +std::optional GetVectorBaseIndices(Value index, mlir::scf::ForOp loop, + mlir::VectorType vector_type, + mlir::ImplicitLocOpBuilder& b) { + Value induction_var = loop.getInductionVar(); + if (index == induction_var) { + return b.create(0); } auto apply_indexing = - mlir::dyn_cast_or_null(indices.back().getDefiningOp()); + mlir::dyn_cast_or_null(index.getDefiningOp()); if (!apply_indexing) { return std::nullopt; } @@ -192,6 +181,11 @@ std::optional> GetVectorBaseIndices( ? mlir::getAffineDimExpr(index, b.getContext()) : mlir::getAffineSymbolExpr( index - map.getNumDims(), b.getContext()); + } else if (!operand.getParentRegion()->isProperAncestor( + &loop.getBodyRegion())) { + // If the operand is defined inside the loop, we can't hoist the + // apply_indexing outside the loop. + return std::nullopt; } } if (!induction_var_expr) { @@ -212,12 +206,8 @@ std::optional> GetVectorBaseIndices( operands[induction_var_operand_index] = b.create(0); - llvm::SmallVector ret = indices; - ret.back() = - b.create(operands, map, apply_indexing.getLowerBounds(), - apply_indexing.getUpperBounds()) - ->getResult(0); - return ret; + return b.create(operands, apply_indexing.getIndexingMap()) + ->getResult(0); } bool IsConflictFree(mlir::tensor::ExtractOp op) { @@ -247,16 +237,14 @@ struct VectorizeLoad : mlir::OpRewritePattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); b.setInsertionPoint(loop); - auto vector_indices = - GetVectorBaseIndices(op.getIndices(), loop, vector_type, b); - if (!vector_indices) { + auto vector_index = + GetVectorBaseIndices(op.getIndices().front(), loop, vector_type, b); + if (!vector_index) { return rewriter.notifyMatchFailure( op, "the instruction does not access contiguous elements"); } - auto loaded_vector = b.create( - vector_type, op.getTensor(), *vector_indices, - llvm::ArrayRef{true}); + vector_type, op.getTensor(), *vector_index, llvm::ArrayRef{true}); rewriter.replaceOpWithNewOp( op, loaded_vector, loop.getInductionVar()); return mlir::success(); @@ -297,9 +285,9 @@ struct VectorizeStore : mlir::OpRewritePattern { mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); b.setInsertionPoint(loop); - auto vector_indices = - GetVectorBaseIndices(op.getIndices(), loop, vector_type, b); - if (!vector_indices) { + auto vector_index = + GetVectorBaseIndices(op.getIndices().front(), loop, vector_type, b); + if (!vector_index) { return rewriter.notifyMatchFailure( op, "the instruction does not access contiguous elements"); } @@ -314,7 +302,7 @@ struct VectorizeStore : mlir::OpRewritePattern { .getInductionVar(); auto insert_op = yield_b.create( yield_loc, op.getScalar(), bbarg.front(), induction_var); - return llvm::SmallVector{insert_op.getResult()}; + return llvm::SmallVector{insert_op.getResult()}; }; int result_index = op->use_begin()->getOperandNumber(); auto new_for = *loop.replaceWithAdditionalYields( @@ -326,7 +314,7 @@ struct VectorizeStore : mlir::OpRewritePattern { auto filled_vector = new_for->getResults().back(); auto written = b.create( - filled_vector, new_for.getInits()[result_index], *vector_indices, + filled_vector, new_for.getInits()[result_index], *vector_index, llvm::ArrayRef{true}); new_for->getResult(result_index).replaceAllUsesWith(written.getResult()); diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index 8053a8972f43ba..a06b9d9dc052de 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/log/check.h" +#include "absl/status/status.h" #include "absl/types/span.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" @@ -41,14 +42,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/mlir/utils/type_util.h" #include "xla/permutation_util.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/fusion_emitter.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/computation_partitioner.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/type_util.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -72,8 +71,6 @@ using mlir::Value; using mlir::ValueRange; using mlir::func::FuncOp; using mlir::func::ReturnOp; -using mlir::tensor::ExtractOp; -using mlir::tensor::InsertOp; using mlir_converter::ApplyIndexing; constexpr int kNumRows = 4; @@ -87,7 +84,8 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) : analysis_(analysis), transpose_(analysis.tiled_transpose()), permutation_(transpose_.permutation), - input_shape_(Permute(transpose_.dimensions, permutation_)) { + input_shape_( + Permute(transpose_.dimensions, InversePermutation(permutation_))) { ConstHloInstructionSet transposes_to_tile; int index = 0; int64_t shmem_usage = 0; @@ -99,6 +97,11 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) transposes_to_tile.insert(&hero.instruction()); shmem_transpose_roots_.push_back(&root.instruction()); int size = primitive_util::ByteWidth(hero.shape().element_type()); + // If the last dimension stays the same, we need to make it part of the + // shared memory tile. + if (MostMinorDimensionUnchanged()) { + size *= input_shape_.back(); + } max_element_bytes = std::max(max_element_bytes, size); shmem_usage += kBaseBlockSize * (kBaseBlockSize + 1) * size; shmem_transpose_root_indices_.push_back(index); @@ -113,11 +116,20 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) auto compute_block_sizes = [this](int vector_size) { vector_size_ = vector_size; block_size_ = kBaseBlockSize * vector_size_; - block_sizes_ = {1, 1, block_size_}; - block_sizes_[permutation_[2]] = block_size_; - block_counts_ = {CeilOfRatio(input_shape_[0], block_sizes_[0]), - CeilOfRatio(input_shape_[1], block_sizes_[1]), - CeilOfRatio(input_shape_[2], block_sizes_[2])}; + block_sizes_.assign(input_shape_.size(), 1); + if (MostMinorDimensionUnchanged()) { + block_sizes_.back() = input_shape_.back(); + block_sizes_[block_sizes_.size() - 2] = block_size_; + block_sizes_[permutation_[block_sizes_.size() - 2]] = block_size_; + } else { + block_sizes_.back() = block_size_; + block_sizes_[permutation_.back()] = block_size_; + } + output_block_sizes_ = Permute(block_sizes_, permutation_); + block_counts_.resize(block_sizes_.size()); + for (int64_t i = 0; i < block_sizes_.size(); ++i) { + block_counts_[i] = CeilOfRatio(input_shape_[i], block_sizes_[i]); + } }; // Compute initial block sizes without vectorization. We use the result to // determine whether we can vectorize. @@ -135,8 +147,13 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) device.threads_per_core_limit(); bool enough_shmem = shmem_usage * elems_per_thread <= device.shared_memory_per_block(); - bool aligned_dims = (input_shape_[2] % vec_size == 0) && - (input_shape_[permutation_[2]] % vec_size == 0); + bool aligned_dims = (input_shape_.back() % vec_size == 0) && + (input_shape_[permutation_.back()] % vec_size == 0); + if (MostMinorDimensionUnchanged()) { + aligned_dims = + input_shape_[input_shape_.size() - 2] % vec_size == 0 && + input_shape_[permutation_[input_shape_.size() - 2]] % vec_size == 0; + } if (enough_work && enough_shmem && aligned_dims) { compute_block_sizes(vec_size); break; @@ -151,12 +168,8 @@ std::optional MlirTransposeFusion::ComputeThreadIdToOutputIndexing( analysis_.fusion_root(root_index).instruction(), hero)) { // The shape of non-transpose roots are bitcast compatible with the input // shape of transpose heroes. - auto map = ComposeIndexingMaps( - GetIndexing(/*input=*/true, hero.shape(), mlir_context), - GetBitcastMap(hero.shape(), analysis_.fusion_root(root_index).shape(), - mlir_context)); - map.Simplify(); - return map; + return GetIndexing(/*input=*/true, + analysis_.fusion_root(root_index).shape(), mlir_context); } return GetIndexing(/*input=*/false, hero.shape(), mlir_context); } @@ -186,10 +199,29 @@ LaunchDimensions MlirTransposeFusion::launch_dimensions() const { IndexingMap MlirTransposeFusion::GetSharedMemoryIndexing( bool read, mlir::MLIRContext* ctx) const { - auto thread_offsets = - Permute(GetThreadOffsets(ctx), read ? Vector3{0, 1, 2} : permutation_); + auto thread_offsets = GetThreadOffsets(/*read=*/true, ctx); + if (!read) { + // Regarding shared memory indexing, the permutation we need to apply is + // just a swap of the two dimensions that are tiled. + if (MostMinorDimensionUnchanged()) { + std::swap(thread_offsets[thread_offsets.size() - 2], + thread_offsets[permutation_[permutation_.size() - 2]]); + } else { + std::swap(thread_offsets.back(), thread_offsets[permutation_.back()]); + } + } + std::vector dim_var_sizes(6, 1); + dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] = + kNumThreadsPerBlock; + if (MostMinorDimensionUnchanged()) { + return {mlir::AffineMap::get(6, 3, thread_offsets, ctx), + DimVarsFromTensorSizes(dim_var_sizes), + RangeVarsFromTensorSizes( + {block_size_ / kNumRows, vector_size_, input_shape_.back()}), + {}}; + } return {mlir::AffineMap::get(6, 2, thread_offsets, ctx), - DimVarsFromTensorSizes({kNumThreadsPerBlock, 1, 1, 1, 1, 1}), + DimVarsFromTensorSizes(dim_var_sizes), RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}), {}}; } @@ -203,7 +235,13 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( MLIRContext* ctx = builder.getContext(); auto shmem_tensor_size = block_sizes_; // Avoid bank conflicts. - ++shmem_tensor_size.back(); + if (MostMinorDimensionUnchanged()) { + // Increase the dimension that is actually iterated over. The most minor + // dimension is always completely loaded into the shared memory tile. + ++shmem_tensor_size[shmem_tensor_size.size() - 2]; + } else { + ++shmem_tensor_size.back(); + } // Allocate shared memory. SmallVector inits; @@ -237,8 +275,8 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( root_computation, transpose, /*operand_index=*/0, input_indices(transpose->operand(0)), call_target_provider, entry_function, builder)[0]; - result_tensors.push_back( - builder.create(result_scalar, output, shmem_indices)); + result_tensors.push_back(builder.create( + result_scalar, output, shmem_indices)); } // Produce all side outputs and then write them. @@ -258,7 +296,7 @@ MlirTransposeFusion::WriteResult MlirTransposeFusion::EmitWriteToShMemMlir( llvm::zip(side_outputs, side_output_indices, output_tensors.take_back(side_output_roots_.size()))) { result_tensors.push_back( - builder.create(value, output, indices)); + builder.create(value, output, indices)); } return result_tensors; @@ -306,7 +344,7 @@ void MlirTransposeFusion::EmitReadFromShMemMlir( for (auto [transpose, shmem] : llvm::zip(shmem_transposes_, written.shmem_tensors)) { transpose_values[transpose].push_back( - builder.create(shmem, shmem_indices)); + builder.create(shmem, shmem_indices)); } llvm::SmallVector epilogue_indices = dim_values; absl::c_copy(symbol_values, std::back_inserter(epilogue_indices)); @@ -320,7 +358,7 @@ void MlirTransposeFusion::EmitReadFromShMemMlir( shmem_transpose_root_indices_)) { llvm::SmallVector indices = ApplyIndexing(indexing, dim_values, symbol_values, builder); - results[root_index] = builder.create( + results[root_index] = builder.create( result_scalars.at(root).front(), results[root_index], indices); } return results; @@ -365,14 +403,19 @@ absl::Status MlirTransposeFusion::EmitEntryFunction( } llvm::SmallVector MlirTransposeFusion::GetThreadOffsets( - mlir::MLIRContext* ctx) const { + bool read, mlir::MLIRContext* ctx) const { auto thread = mlir::getAffineDimExpr( KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx); auto loop = mlir::getAffineSymbolExpr(0, ctx); auto vector = mlir::getAffineSymbolExpr(1, ctx); int loop_stride = block_size_ * kNumRows; auto linear_index = loop * loop_stride + thread * vector_size_ + vector; - return DelinearizeInBoundsIndex(linear_index, block_sizes_); + if (MostMinorDimensionUnchanged()) { + auto minor_dim = mlir::getAffineSymbolExpr(2, ctx); + linear_index = linear_index * input_shape_.back() + minor_dim; + } + return DelinearizeInBoundsIndex(linear_index, + read ? block_sizes_ : output_block_sizes_); } IndexingMap MlirTransposeFusion::GetIndexing(bool input, @@ -380,19 +423,31 @@ IndexingMap MlirTransposeFusion::GetIndexing(bool input, mlir::MLIRContext* ctx) const { auto raw_id = mlir::getAffineDimExpr( KernelFusionInterface::kIndexingMapBlockIdxDims[0], ctx); - auto block_ids = Permute(DelinearizeInBoundsIndex(raw_id, block_counts_), - input ? Vector3{0, 1, 2} : permutation_); - auto thread_offsets = GetThreadOffsets(ctx); + auto block_ids = DelinearizeInBoundsIndex(raw_id, block_counts_); + if (!input) { + absl::c_copy(Permute(block_ids, permutation_), block_ids.begin()); + } + auto thread_offsets = GetThreadOffsets(input, ctx); + const auto& permuted_block_sizes = input ? block_sizes_ : output_block_sizes_; llvm::SmallVector offsets; for (auto [block_id, block_size, thread] : - llvm::zip(block_ids, block_sizes_, thread_offsets)) { + llvm::zip(block_ids, permuted_block_sizes, thread_offsets)) { offsets.push_back(block_id * block_size + thread); } + std::vector dim_var_sizes(6, 1); + dim_var_sizes[KernelFusionInterface::kIndexingMapThreadIdxDims[0]] = + kNumThreadsPerBlock; + dim_var_sizes[KernelFusionInterface::kIndexingMapBlockIdxDims[0]] = + Product(block_counts_); + auto range_var_sizes = + std::vector{block_size_ / kNumRows, vector_size_}; + if (MostMinorDimensionUnchanged()) { + range_var_sizes.push_back(input_shape_.back()); + } IndexingMap result{ - mlir::AffineMap::get(6, 2, offsets, ctx), - DimVarsFromTensorSizes( - {kNumThreadsPerBlock, 1, 1, Product(block_counts_), 1, 1}), - RangeVarsFromTensorSizes({block_size_ / kNumRows, vector_size_}), + mlir::AffineMap::get(6, range_var_sizes.size(), offsets, ctx), + DimVarsFromTensorSizes(dim_var_sizes), + RangeVarsFromTensorSizes(range_var_sizes), {}}; auto normalized_shape = input ? ShapeUtil::MakeShape(shape.element_type(), input_shape_) @@ -407,5 +462,9 @@ IndexingMap MlirTransposeFusion::GetIndexing(bool input, return result; } +bool MlirTransposeFusion::MostMinorDimensionUnchanged() const { + return permutation_.back() == permutation_.size() - 1; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index 07d1e99d382b51..afb2777967220e 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -35,7 +36,6 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/model/indexing_map.h" -#include "xla/util.h" namespace xla { namespace gpu { @@ -97,12 +97,14 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { mlir::MLIRContext* ctx) const; IndexingMap GetSharedMemoryIndexing(bool read, mlir::MLIRContext* ctx) const; llvm::SmallVector GetThreadOffsets( - mlir::MLIRContext* ctx) const; + bool read, mlir::MLIRContext* ctx) const; + bool MostMinorDimensionUnchanged() const; TransposeDescription transpose_; - Vector3 permutation_; + absl::InlinedVector permutation_; std::vector input_shape_; std::vector block_sizes_; // In input elements. + std::vector output_block_sizes_; std::vector block_counts_; int vector_size_; int block_size_; diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc index 1861672a82279d..d773503859d934 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir_test.cc @@ -16,11 +16,12 @@ limitations under the License. #include #include +#include "mlir/IR/MLIRContext.h" #include "xla/error_spec.h" #include "xla/service/gpu/fusions/mlir_emitter_test_base.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/model/indexing_test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -44,7 +45,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { )")); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirTransposeFusion fusion(analysis); EXPECT_THAT( @@ -87,29 +88,29 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing021) { )")); } -TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { +TEST_F(MlirTransposeFusionTest, ThreadIndexing201_SimplifiedTo021) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule module fusion { - %input = f32[100,64,32] parameter(0) - ROOT transpose = f32[32,100,64] transpose(%input), dimensions={2,0,1} + %input = f32[1,6400,32] parameter(0) + ROOT transpose = f32[1,32,6400] transpose(%input), dimensions={0,2,1} } ENTRY entry { - %input = f32[100,64,32] parameter(0) - ROOT %fusion = f32[32,100,64] fusion(%input), kind=kInput, calls=fusion + %input = f32[1,6400,32] parameter(0) + ROOT %fusion = f32[1,32,6400] fusion(%input), kind=kInput, calls=fusion })")); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirTransposeFusion fusion(analysis); EXPECT_THAT( fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( - d3 floordiv 2, - (d3 mod 2) * 32 + s0 * 4 + d0 floordiv 32, + 0, + d3 * 32 + s0 * 4 + d0 floordiv 32, d0 mod 32 ) domain: @@ -127,9 +128,9 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), MatchIndexingString(R"( (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + 0, d0 floordiv 32 + s0 * 4, - d3 floordiv 2, - (d3 mod 2) * 32 + d0 mod 32 + d3 * 32 + d0 mod 32 ) domain: d0 in [0, 127] @@ -144,6 +145,72 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexing201) { )")); } +TEST_F(MlirTransposeFusionTest, Transpose_ThreadIndexing1302) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %param_0 = f32[19, 16, 16, 144] parameter(0) + ROOT %transpose= f32[16, 144, 19, 16] transpose( %param_0), + dimensions={1,3,0,2} + } + ENTRY main { + %param = f32[19, 16, 16, 144] parameter(0) + ROOT %fusion = f32[16, 144, 19, 16] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloString)); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); + + MlirTransposeFusion fusion(analysis); + EXPECT_THAT( + fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + d3 floordiv 80, + (d3 floordiv 5) mod 16, + d0 floordiv 32 + s0 * 4, + (d3 mod 5) * 32 + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1519] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 3] + s1 in [0, 0] + (d3 mod 5) * 32 + d0 mod 32 in [0, 143] + )")); + EXPECT_THAT( + fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), + MatchIndexingString(R"( + (d0, d1, d2, d3, d4, d5)[s0, s1] -> ( + (d3 floordiv 5) mod 16, + (d3 mod 5) * 32 + s0 * 4 + d0 floordiv 32, + d3 floordiv 80, + d0 mod 32 + ) + domain: + d0 in [0, 127] + d1 in [0, 0] + d2 in [0, 0] + d3 in [0, 1519] + d4 in [0, 0] + d5 in [0, 0] + + s0 in [0, 7] + s1 in [0, 0] + (d3 mod 5) * 8 + s0 in [0, 35] + d0 mod 32 in [0, 15] + )")); +} + TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( HloModule module @@ -158,7 +225,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized021) { )")); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirTransposeFusion fusion(analysis); EXPECT_THAT( @@ -212,7 +279,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingVectorized210) { })")); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirTransposeFusion fusion(analysis); EXPECT_THAT( @@ -295,6 +362,55 @@ TEST_F(MlirTransposeFusionTest, FusedTranspose021) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } +TEST_F(MlirTransposeFusionTest, FusedTranspose102) { + auto kHloString = R"( + HloModule Transpose + + %fused_computation { + %p0 = s8[160,170,3] parameter(0) + ROOT %transpose = s8[170,160,3] transpose(%p0), dimensions={1,0,2} + } + ENTRY main { + %param = s8[160,170,3] parameter(0) + ROOT %fusion = s8[170,160,3] fusion(%param), kind=kInput, + calls=%fused_computation + } + )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, R"( + // CHECK-LABEL: func.func @fused_computation( + // CHECK-SAME: }, %[[OUT:.*]]: tensor<170x160x3xi8> + // + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index + // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index + + // CHECK: %[[SHMEM:.*]] = xla_gpu.allocate_shared : tensor<32x33x3xi8> + // CHECK: %[[SHMEM_WITH_VALS:.*]] = scf.for + // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: iter_args(%[[SHMEM_:.*]] = %[[SHMEM]]) + // CHECK: %[[SHMEM_WITH_VALS2:.*]] = scf.for + // CHECK-SAME: %[[C0]] to %[[C3]] step %[[C1]] + // CHECK-SAME: iter_args(%[[SHMEM2_:.*]] = %[[SHMEM_]]) + // CHECK: %[[P0:.*]] = xla_gpu.pure_call @fused_computation_p0 + // CHECK: tensor.insert %[[P0]] into %[[SHMEM2_]] + + // CHECK: %[[SYNC:.*]] = xla_gpu.sync_threads %[[SHMEM_WITH_VALS]] + + // CHECK: scf.for + // CHECK-SAME: %[[C0]] to %[[C8]] step %[[C1]] + // CHECK-SAME: iter_args(%[[OUT_:.*]] = %[[OUT]]) + // CHECK: scf.for + // CHECK-SAME: %[[C0]] to %[[C3]] step %[[C1]] + // CHECK-SAME: iter_args(%[[OUT2_:.*]] = %[[OUT_]]) + // CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[SYNC]] + // CHECK: %[[RES:.*]] = xla_gpu.pure_call @fused_computation__epilogue__transpose + // CHECK: tensor.insert %[[RES]] into %[[OUT2_]] + )")); + + EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); +} + TEST_F(MlirTransposeFusionTest, FusedTranspose210) { auto kHloString = R"( HloModule Transpose @@ -400,25 +516,6 @@ TEST_F(MlirTransposeFusionTest, Transpose021_NoEpilogue) { EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirTransposeFusionTest, Transpose_4D) { - auto kHloString = R"( - HloModule Transpose - - %fused_computation { - %param_0 = f64[2,24,6,4] parameter(0) - ROOT %transpose= f64[6,4,2,24] transpose(f64[2,24,6,4] %param_0), - dimensions={2,3,0,1} - } - ENTRY main { - %param = f64[2,24,6,4] parameter(0) - ROOT %fusion = f64[6,4,2,24] fusion(%param), kind=kInput, - calls=%fused_computation - } - )"; - TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - TEST_F(MlirTransposeFusionTest, Transpose_2D) { auto kHloString = R"( HloModule Transpose @@ -434,30 +531,24 @@ TEST_F(MlirTransposeFusionTest, Transpose_2D) { calls=%fused_computation } )"; + TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } -TEST_F(MlirTransposeFusionTest, Transpose_2D_2) { +TEST_F(MlirTransposeFusionTest, Transpose_4D) { auto kHloString = R"( - HloModule m + HloModule Transpose %fused_computation { - %p0 = f32[17,2820]{0,1} parameter(0) - %p1 = f32[30,17,94] parameter(1) - - %bitcast0 = f32[2,3,5,17,94] bitcast(f32[30,17,94] %p1) - %transpose = f32[2,3,5,94,17] transpose(f32[2,3,5,17,94] %bitcast0), dimensions={0,1,2,4,3} - %bitcast1 = f32[2820,17]{1,0} bitcast(f32[2,3,5,94,17] %transpose) - %bitcast2 = f32[2820,17]{1,0} bitcast(f32[17,2820]{0,1} %p0) - %neg = f32[2820,17]{1,0} negate(f32[2820,17] %bitcast2) - ROOT %add = f32[2820,17]{1,0} add(f32[2820,17] %bitcast1, f32[2820,17]{1,0} %neg) + %param_0 = f32[19, 16, 16, 144] parameter(0) + ROOT %transpose= f32[16, 144, 19, 16] transpose( %param_0), + dimensions={1,3,0,2} } - ENTRY main { - %p1 = f32[30,17,94]{2,1,0} parameter(1) - %p0 = f32[17,2820]{0,1} parameter(0) - ROOT %fusion = f32[2820,17]{1,0} fusion(%p0, %p1), kind=kInput, calls=%fused_computation + %param = f32[19, 16, 16, 144] parameter(0) + ROOT %fusion = f32[16, 144, 19, 16] fusion(%param), kind=kInput, + calls=%fused_computation } )"; TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); @@ -469,19 +560,19 @@ TEST_F(MlirTransposeFusionTest, MultipleRootsForTranspose) { HloModule m %fused_computation { - %iota.0 = s32[200,200] iota(), iota_dimension=1 - %iota.1 = s32[200,200] iota(), iota_dimension=0 - %compare = pred[200,200] compare(%iota.0, %iota.1), direction=GE - %transpose = pred[200,200] transpose(%compare), dimensions={1,0} - %copy = pred[200,200] copy(%transpose) - %copy.1 = pred[200,200] copy(%transpose) - ROOT %tuple = (pred[200,200], pred[200,200], pred[200,200]{1,0}) + %iota.0 = s32[1,200,200] iota(), iota_dimension=1 + %iota.1 = s32[1,200,200] iota(), iota_dimension=0 + %compare = pred[1,200,200] compare(%iota.0, %iota.1), direction=GE + %transpose = pred[1,200,200] transpose(%compare), dimensions={0,2,1} + %copy = pred[1,200,200] copy(%transpose) + %copy.1 = pred[1,200,200] copy(%transpose) + ROOT %tuple = (pred[1,200,200], pred[1,200,200], pred[1,200,200]) tuple(%transpose, %copy, %copy.1) } ENTRY main { ROOT %fusion = - (pred[200,200]{1,0}, pred[200,200]{1,0}, pred[200,200]{1,0}) + (pred[1,200,200], pred[1,200,200], pred[1,200,200]) fusion(), kind=kInput, calls=%fused_computation } )"; @@ -494,13 +585,13 @@ TEST_F(MlirTransposeFusionTest, PartialTile) { HloModule m fused_computation { - %p0 = f64[24,2,6,4] parameter(0) - ROOT %t = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} + %p0 = f64[24,2,24] parameter(0) + ROOT %t = f64[24,2,24] transpose(%p0), dimensions={2,1,0} } ENTRY main { - %p0 = f64[24,2,6,4] parameter(0) - ROOT %fusion = f64[6,4,2,24] fusion(%p0), kind=kInput, calls=%fused_computation + %p0 = f64[24,2,24] parameter(0) + ROOT %fusion = f64[24,2,24] fusion(%p0), kind=kInput, calls=%fused_computation } )"; TF_EXPECT_OK(EmitAndCheckIR(kHloString, "// CHECK: xla_gpu.allocate_shared")); @@ -512,20 +603,19 @@ TEST_F(MlirTransposeFusionTest, MixedIndexing) { HloModule m fused_computation { - %p0 = f64[24,2,6,4] parameter(0) - %bc = f64[24,2,24] bitcast(%p0) - %t1 = f64[6,4,2,24] transpose(%p0), dimensions={2,3,1,0} - %t2 = f64[24,2,24] transpose(%bc), dimensions={2,1,0} + %p0 = f64[24,2,24] parameter(0) + %t1 = f64[24,2,24] transpose(%p0), dimensions={2,1,0} + %b = f64[6,4,2,24] bitcast(%t1) %p1 = f64[] parameter(1) %bc1 = f64[6,4,2,24] broadcast(%p1), dimensions={} %bc2 = f64[24,2,24] broadcast(%p1), dimensions={} - %a1 = f64[6,4,2,24] add(%t1, %bc1) - %a2 = f64[24,2,24] add(%t2, %bc2) + %a1 = f64[6,4,2,24] add(%b, %bc1) + %a2 = f64[24,2,24] add(%t1, %bc2) ROOT %t = (f64[6,4,2,24], f64[24,2,24]) tuple(%a1, %a2) } ENTRY main { - %p0 = f64[24,2,6,4] parameter(0) + %p0 = f64[24,2,24] parameter(0) %p1 = f64[] parameter(1) ROOT %fusion = (f64[6,4,2,24], f64[24,2,24]) fusion(%p0, %p1), kind=kInput, calls=%fused_computation @@ -578,7 +668,7 @@ TEST_F(MlirTransposeFusionTest, SameInputIndexingForRealHeroAndSideOutput) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirTransposeFusion fusion(analysis); mlir::MLIRContext mlir_context; @@ -608,7 +698,7 @@ TEST_F(MlirTransposeFusionTest, ThreadIndexingSideOutput) { .value(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); MlirTransposeFusion fusion(analysis); mlir::MLIRContext mlir_context; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/BUILD b/third_party/xla/xla/service/gpu/fusions/triton/BUILD index 7fcb4ffcf95e1b..3b1f0e20bedef7 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/triton/BUILD @@ -1,3 +1,4 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") load("@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured") load("//xla:xla.bzl", "xla_cc_test") @@ -32,8 +33,7 @@ cc_library( ]), hdrs = ["triton_fusion_emitter.h"], deps = [ - ":prevent_mmav3_loop_unrolling", - ":sparse_extensions", + ":passes", "//xla:autotuning_proto_cc", "//xla:comparison_util", "//xla:debug_options_flags", @@ -42,6 +42,7 @@ cc_library( "//xla:status_macros", "//xla:util", "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", "//xla/mlir_hlo", @@ -58,9 +59,9 @@ cc_library( "//xla/service/gpu:target_util", "//xla/service/gpu:triton_fusion_analysis", "//xla/service/gpu:triton_tiling_propagation", + "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", - "//xla/service/gpu/fusions/mlir:passes", - "//xla/service/gpu/fusions/mlir/ir:xla_gpu", + "//xla/service/gpu/fusions/transforms:passes", "//xla/service/gpu/llvm_gpu_backend", "//xla/service/gpu/model:affine_map_printer", "//xla/service/gpu/model:indexing_analysis", @@ -68,6 +69,7 @@ cc_library( "//xla/service/gpu/model:symbolic_tiled_hlo_instruction", "//xla/service/gpu/model:tiled_hlo_computation", "//xla/service/gpu/model:tiled_hlo_instruction", + "//xla/service/gpu/model:triton_emitter_constraints", "//xla/service/llvm_ir:llvm_util", "//xla/stream_executor:device_description", "//xla/stream_executor:launch_dim", @@ -135,11 +137,34 @@ cc_library( ]), ) +gentbl_cc_library( + name = "passes_inc_gen", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=TritonFusionTransforms", + ], + "passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes.td", + visibility = ["//visibility:private"], + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + cc_library( - name = "sparse_extensions", - srcs = ["sparse_extensions.cc"], - hdrs = ["sparse_extensions.h"], + name = "passes", + srcs = [ + "generalize_kernel_signature.cc", + "passes.cc", + "prevent_mmav3_loop_unrolling.cc", + "sparse_extensions.cc", + ], + hdrs = ["passes.h"], deps = [ + ":passes_inc_gen", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:GPUCommonTransforms", @@ -151,6 +176,7 @@ cc_library( "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", @@ -165,20 +191,6 @@ cc_library( ], ) -cc_library( - name = "prevent_mmav3_loop_unrolling", - srcs = ["prevent_mmav3_loop_unrolling.cc"], - hdrs = ["prevent_mmav3_loop_unrolling.h"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:Support", - "@triton//:TritonDialects", - ], -) - xla_test( name = "triton_fusion_emitter_device_legacy_test", srcs = if_gpu_is_configured(["triton_fusion_emitter_device_legacy_test.cc"]), @@ -211,14 +223,15 @@ xla_test( "//xla/tests:filecheck", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", @@ -255,13 +268,13 @@ xla_test( "//xla/stream_executor/cuda:cublas_plugin", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", @@ -287,6 +300,7 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:matmul_utils", "//xla/service/gpu/model:tiled_hlo_computation", + "//xla/service/gpu/model:triton_emitter_constraints", "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", @@ -316,10 +330,12 @@ xla_cc_test( "//xla/service/gpu/model:symbolic_tile_analysis", "//xla/service/gpu/model:tiled_hlo_computation", "//xla/service/gpu/model:tiled_hlo_instruction", + "//xla/service/gpu/model:triton_emitter_constraints", "//xla/service/llvm_ir:llvm_util", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -329,7 +345,6 @@ xla_cc_test( "@llvm-project//mlir:IR", "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Support", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:logging", "@triton//:TritonDialects", ], @@ -391,8 +406,14 @@ xla_test( cc_library( name = "triton_support", - srcs = ["triton_support.cc"], - hdrs = ["triton_support.h"], + srcs = [ + "triton_support.cc", + "triton_support_legacy.cc", + ], + hdrs = [ + "triton_support.h", + "triton_support_legacy.h", + ], deps = [ "//xla:shape_util", "//xla:xla_data_proto_cc", @@ -433,8 +454,10 @@ xla_cc_test( "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status_matchers", @@ -465,11 +488,11 @@ xla_test( "//xla/service/gpu:triton_fusion_analysis", "//xla/service/gpu/model:tiled_hlo_computation", "//xla/stream_executor:device_description", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc index eeac6366bb2c75..46a569d265bcdd 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_cuda.cc @@ -24,8 +24,7 @@ limitations under the License. #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h" -#include "xla/service/gpu/fusions/triton/sparse_extensions.h" +#include "xla/service/gpu/fusions/triton/passes.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/hlo_module_config.h" @@ -65,9 +64,9 @@ absl::Status CreateTritonPipeline( pm.addPass(mt::createConvertTritonToTritonGPUPass( absl::StrFormat("cuda:%u", ccAsInt), block_level_parameters.num_warps, threadsPerWarp, block_level_parameters.num_ctas)); - pm.addPass(CreateAddSparseDotEncodingPass(block_level_parameters.num_warps, - threadsPerWarp, - block_level_parameters.num_ctas)); + pm.addPass(CreateSparseAddEncodingPass(block_level_parameters.num_warps, + threadsPerWarp, + block_level_parameters.num_ctas)); pm.addPass(mt::gpu::createTritonGPUCoalesce()); if (ccCuda.IsAtLeastAmpere()) { pm.addPass(mt::gpu::createTritonGPUF32DotTC()); @@ -93,6 +92,7 @@ absl::Status CreateTritonPipeline( pm.addPass( mt::gpu::createTritonGPUOptimizeDotOperands({ccCuda.IsAtLeastAmpere()})); pm.addPass(mt::gpu::createTritonGPURemoveLayoutConversions()); + pm.addPass(CreateSparseRemoveLayoutConversionPass()); pm.addPass(mt::gpu::createTritonGPUReduceDataDuplication()); pm.addPass(mt::gpu::createTritonGPUReorderInstructions()); pm.addPass(mlir::createCSEPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc index 4c130b83f71fc0..2a95ea833f4bcc 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc @@ -22,7 +22,6 @@ limitations under the License. #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "xla/service/gpu/fusions/triton/sparse_extensions.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" @@ -57,8 +56,6 @@ absl::Status CreateTritonPipeline( mlir::OpPassManager& pm, const se::GpuComputeCapability& cc, const BlockLevelParameters& block_level_parameters, mt::nvidia_gpu::ClusterInfo& out_cluster_info) { - // TODO(ROCm): Check whether value different than 0 can be used. - const int ccAsInt = 0; // TODO(ROCm): Check why some test fail when threadsPerWarp is set to 64. const int threadsPerWarp = 32; auto ccRocm = std::get(cc); @@ -107,6 +104,8 @@ absl::Status CreateTritonPipeline( pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass( ccRocm.gfx_version())); const int custom_lds_size = 0; + pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(ccRocm.gfx_version(), + custom_lds_size)); pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass(mt::gpu::createAllocateSharedMemoryPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/generalize_kernel_signature.cc b/third_party/xla/xla/service/gpu/fusions/triton/generalize_kernel_signature.cc new file mode 100644 index 00000000000000..7ce29350b2d42c --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/generalize_kernel_signature.cc @@ -0,0 +1,130 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "xla/service/gpu/fusions/triton/passes.h" + +namespace xla::gpu { +namespace { + +// Extract additional attributes from an LLVM function that are not passed +// to the builder directly. +mlir::SmallVector GetExtraAttrs( + mlir::LLVM::LLVMFuncOp func) { + llvm::StringSet<> registered_attr_names{ + func.getSymNameAttrName().getValue(), + func.getFunctionTypeAttrName().getValue(), + func.getLinkageAttrName().getValue(), + func.getDsoLocalAttrName().getValue(), + func.getCConvAttrName().getValue(), + func.getArgAttrsAttrName().getValue(), + func.getFunctionEntryCountAttrName().getValue()}; + return llvm::to_vector( + llvm::make_filter_range(func->getAttrs(), [&](mlir::NamedAttribute attr) { + return !registered_attr_names.contains(attr.getName().getValue()); + })); +} + +// Strip address spaces from function parameters. +void StripParameterAddressSpaces(mlir::RewriterBase& rewriter, + mlir::LLVM::LLVMFuncOp func) { + // Figure out what the new signature should be. + mlir::LLVM::LLVMFunctionType func_ty = func.getFunctionType(); + mlir::SmallVector generic_func_params( + llvm::map_range(func_ty.getParams(), [](mlir::Type type) -> mlir::Type { + auto ptr_ty = mlir::dyn_cast(type); + if (!ptr_ty) return type; + if (ptr_ty.getAddressSpace() != mlir::NVVM::kGlobalMemorySpace) + return type; + return mlir::LLVM::LLVMPointerType::get(ptr_ty.getContext()); + })); + mlir::LLVM::LLVMFunctionType generic_func_ty = + func_ty.clone(generic_func_params, func_ty.getReturnTypes()); + + // Create a function with the new signature. + mlir::SmallVector arg_attrs(llvm::map_range( + func.getArgAttrsAttr().getValue(), [](mlir::Attribute attr) { + return mlir::cast(attr); + })); + auto generic_func = rewriter.create( + func.getLoc(), func.getSymName(), generic_func_ty, func.getLinkage(), + func.getDsoLocal(), func.getCConv(), /*comdat=*/nullptr, + GetExtraAttrs(func), arg_attrs, func.getFunctionEntryCount()); + + // Convert generic address spaces back to original ones within the function + // body. + mlir::Block* entry = generic_func.addEntryBlock(rewriter); + rewriter.setInsertionPointToEnd(entry); + mlir::SmallVector converted_args; + for (auto [arg, type] : + llvm::zip(generic_func.getArguments(), func_ty.getParams())) { + mlir::Value converted = arg; + if (arg.getType() != type) { + converted = + rewriter.create(arg.getLoc(), type, arg); + } + converted_args.push_back(converted); + } + + // Move the rest of function body from the original function. + rewriter.cloneRegionBefore(func.getBody(), generic_func.getBody(), + generic_func.getBody().end()); + rewriter.eraseOp(func); + rewriter.mergeBlocks(entry->getNextNode(), entry, converted_args); +} + +#define GEN_PASS_DEF_GENERALIZEKERNELSIGNATUREPASS +#include "xla/service/gpu/fusions/triton/passes.h.inc" + +// Rewrite signatures of kernel functions to use generic data pointers and +// cast them to global ones within the kernel. +struct GeneralizeKernelSignaturePass + : public impl::GeneralizeKernelSignaturePassBase< + GeneralizeKernelSignaturePass> { + void runOnOperation() override { + mlir::IRRewriter rewriter(&getContext()); + getOperation()->walk([&](mlir::LLVM::LLVMFuncOp func) { + if (!func->hasAttr(mlir::NVVM::NVVMDialect::getKernelFuncAttrName())) { + return; + } + rewriter.setInsertionPointAfter(func); + StripParameterAddressSpaces(rewriter, func); + }); + } +}; + +} // namespace + +std::unique_ptr CreateGeneralizeKernelSignaturePass() { + return std::make_unique(); +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h b/third_party/xla/xla/service/gpu/fusions/triton/passes.cc similarity index 53% rename from third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h rename to third_party/xla/xla/service/gpu/fusions/triton/passes.cc index f8e1af041b226e..0d0ff381874644 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.cc @@ -13,23 +13,21 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_ -#define XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_ +#include "xla/service/gpu/fusions/triton/passes.h" -#include - -#include "mlir/Pass/Pass.h" +#include "llvm/ADT/STLFunctionalExtras.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Visitors.h" namespace xla::gpu { -// This pass is a result of b/344841434: -// PTX sometimes unrolls wgmma loops that can cause a 1000x slow down in -// compilation time. Most unrolling has already been done before PTX, -// this pragma prevents ptxas from doing more. -std::unique_ptr CreatePreventMmaV3LoopUnrollingPass(); - -void RegisterPreventMmaV3LoopUnrollingPass(); +bool ContainsOp(mlir::Operation* op, + llvm::function_ref fn) { + auto visitor = [&](mlir::Operation* nested_op) { + return fn(nested_op) ? mlir::WalkResult::interrupt() + : mlir::WalkResult::advance(); + }; + return op->walk(visitor).wasInterrupted(); +} } // namespace xla::gpu - -#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_PREVENT_MMAV3_LOOP_UNROLLING_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h b/third_party/xla/xla/service/gpu/fusions/triton/passes.h similarity index 52% rename from third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h rename to third_party/xla/xla/service/gpu/fusions/triton/passes.h index 5d48a4353ae9d6..9bb3ab6a92d6cf 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.h @@ -13,25 +13,39 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_ -#define XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_ +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_ #include #include +#include "llvm/ADT/STLFunctionalExtras.h" +#include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" namespace xla::gpu { -std::unique_ptr CreateAddSparseDotEncodingPass( - int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas); +#define GEN_PASS_DECL +#include "xla/service/gpu/fusions/triton/passes.h.inc" + +std::unique_ptr CreateSparseAddEncodingPass( + int32_t num_warps = 4, int32_t threads_per_warp = 32, int32_t num_ctas = 1); std::unique_ptr CreateSparseBlockedToMMAPass(); +std::unique_ptr CreateSparseRemoveLayoutConversionPass(); std::unique_ptr CreateSparseLocalLoadToLLVMPass(); std::unique_ptr CreateSparseDotOpToLLVMPass(); std::unique_ptr CreateSparseWGMMAOpToLLVMPass(); +std::unique_ptr CreatePreventMmaV3LoopUnrollingPass(); +std::unique_ptr CreateGeneralizeKernelSignaturePass(); + +// Returns true if the `op` contains an operation in it's regions that satisfies +// the `fn`. +bool ContainsOp(mlir::Operation* op, + llvm::function_ref fn); -void RegisterSparsePasses(); +#define GEN_PASS_REGISTRATION +#include "xla/service/gpu/fusions/triton/passes.h.inc" } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_SPARSE_EXTENSIONS_H_ +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/passes.td b/third_party/xla/xla/service/gpu/fusions/triton/passes.td new file mode 100644 index 00000000000000..f437a44b37c8a4 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/passes.td @@ -0,0 +1,109 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_ + +include "mlir/Pass/PassBase.td" + +def SparseAddEncodingPass : Pass<"sparse-add-encoding", "mlir::ModuleOp"> { + let summary = "Add sparse encoding for all the arguments of a SparseDotOp."; + let options = [ + Option<"num_warps_", "num-warps", "int32_t", /*default=*/"4", + "Number of warps">, + Option<"threads_per_warp_", "threads-per-warp", "int32_t", /*default=*/"32", + "Number of threads per warp">, + Option<"num_ctas_", "num-ctas", "int32_t", /*default=*/"1", + "Number of CTAs in a CGA">, + ]; + let dependentDialects = [ + "triton::gpu::TritonGPUDialect", + ]; + let constructor = "CreateSparseAddEncodingPass()"; +} + +def SparseBlockedToMMAPass : Pass<"sparse-blocked-to-mma", "mlir::ModuleOp"> { + let summary = "Add convert layouts to/from MMA before and after SparseDotOp."; + let description = [{ + Add convert layouts to and from MMA before and after SparseDotOp. In MMAV3, + shared memory allocations will be used for A and B operands. + }]; + let dependentDialects = [ + "triton::gpu::TritonGPUDialect", + ]; + let constructor = "CreateSparseBlockedToMMAPass()"; +} + +def SparseRemoveLayoutConversionPass + : Pass<"sparse-remove-layout-conversion", "mlir::ModuleOp"> { + let summary = "Replaces ConvertLayoutOp with sparse dot encoding"; + let dependentDialects = [ + "triton::gpu::TritonGPUDialect", + ]; + let constructor = "CreateSparseRemoveLayoutConversionPass()"; +} + +def SparseLocalLoadToLLVMPass + : Pass<"sparse-local-load-to-llvm", "mlir::ModuleOp"> { + let summary = "Lowers sparse local load to LLVM"; + let dependentDialects = [ + "triton::gpu::TritonGPUDialect", + "mlir::LLVM::LLVMDialect" + ]; + let constructor = "CreateSparseLocalLoadToLLVMPass()"; +} + +def SparseDotOpToLLVMPass : Pass<"sparse-dot-to-llvm", "mlir::ModuleOp"> { + let summary = "Lowers sparse dot to LLVM"; + let constructor = "CreateSparseDotOpToLLVMPass()"; + let dependentDialects = [ + "triton::gpu::TritonGPUDialect", + "mlir::triton::nvgpu::NVGPUDialect", + ]; +} + +def SparseWGMMAOpToLLVMPass : Pass<"sparse-wgmma-to-llvm", "mlir::ModuleOp"> { + let summary = "Lowers sparse WGMMA to LLVM"; + let dependentDialects = [ + "triton::gpu::TritonGPUDialect", + "mlir::triton::nvgpu::NVGPUDialect", + ]; + let constructor = "CreateSparseWGMMAOpToLLVMPass()"; +} + +def PreventMmaV3LoopUnrollingPass + : Pass<"prevent-mmav3-loop-unrolling", "mlir::ModuleOp"> { + let summary = "Prevent MMAv3 loop unrolling."; + let description = [{ + This pass is a result of b/344841434: + PTX sometimes unrolls wgmma loops that can cause a 1000x slow down in + compilation time. Most unrolling has already been done before PTX, + this pragma prevents ptxas from doing more. + }]; + let constructor = "CreatePreventMmaV3LoopUnrollingPass()"; +} + + +def GeneralizeKernelSignaturePass + : Pass<"generalize-kernel-signature"> { + let summary = "Rewrite kernels to use generic data pointer arguments."; + let description = [{ + Rewrite signatures of kernel functions from global pointers to generic + pointers and cast them to global ones within the kernel. + }]; + let constructor = "CreateGeneralizeKernelSignaturePass()"; +} + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_PASSES_TD_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc b/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc index 7cb0a551548e66..e5b3d4e6dbaea9 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.cc @@ -13,29 +13,27 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h" - #include -#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" +#include "xla/service/gpu/fusions/triton/passes.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -class PreventMmaV3LoopUnrollingPass - : public mlir::PassWrapper> { - public: - llvm::StringRef getArgument() const override { - return "prevent-mmav3-loop-unrolling"; - } +namespace xla::gpu { +namespace { + +#define GEN_PASS_DEF_PREVENTMMAV3LOOPUNROLLINGPASS +#include "xla/service/gpu/fusions/triton/passes.h.inc" +struct PreventMmaV3LoopUnrollingPass + : public impl::PreventMmaV3LoopUnrollingPassBase< + PreventMmaV3LoopUnrollingPass> { // TODO(b/344841434): Remove this if NVIDIA fixes compile-time issue. // PTX sometimes unrolls wgmma loops that can cause a 1000x slow down in // compilation time. Most unrolling has already been done before PTX; @@ -60,10 +58,10 @@ class PreventMmaV3LoopUnrollingPass } }; -std::unique_ptr xla::gpu::CreatePreventMmaV3LoopUnrollingPass() { +} // namespace + +std::unique_ptr CreatePreventMmaV3LoopUnrollingPass() { return std::make_unique(); } -void xla::gpu::RegisterPreventMmaV3LoopUnrollingPass() { - registerPass(CreatePreventMmaV3LoopUnrollingPass); -} +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc index 8b2f1aba7ee14d..7bd6a3b04fbaa2 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/sparse_extensions.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusions/triton/sparse_extensions.h" - #include #include #include @@ -38,6 +36,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Types.h" @@ -47,9 +46,9 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Support/TypeID.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "xla/service/gpu/fusions/triton/passes.h" #include "triton/Analysis/Allocation.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h" @@ -64,6 +63,13 @@ limitations under the License. using namespace mlir; // NOLINT(build/namespaces) +namespace ttn = triton::nvgpu; +using ::mlir::LLVM::getSharedMemoryObjectFromStruct; +using ::mlir::triton::gpu::getShapePerCTA; +using ::mlir::triton::gpu::getShapePerCTATile; +using ::mlir::triton::gpu::SharedEncodingAttr; +using ttn::OperandsAndConstraints; + // The functions below are defined in AccelerateMatmul.cpp. namespace mlir::triton::gpu { SmallVector getWarpsPerTile( @@ -79,13 +85,32 @@ Value createDescriptor(ConversionPatternRewriter &rewriter, Location loc, int64_t swizzling, uint32_t stride); int64_t getSwizzlingFromLayout(const triton::gpu::SharedEncodingAttr &layout, uint32_t widthInByte); -triton::nvgpu::WGMMAEltType getMmaRetType(Value); -triton::nvgpu::WGMMAEltType getMmaOperandType(Value, bool); +ttn::WGMMAEltType getMmaRetType(Value); +ttn::WGMMAEltType getMmaOperandType(Value, bool); +namespace xla::gpu { namespace { -// Add sparse encoding for all the arguments of a SparseDotOp. -struct AddSparseEncoding +#define GEN_PASS_DEF_SPARSEADDENCODINGPASS +#define GEN_PASS_DEF_SPARSEBLOCKEDTOMMAPASS +#define GEN_PASS_DEF_SPARSEDOTOPTOLLVMPASS +#define GEN_PASS_DEF_SPARSELOCALLOADTOLLVMPASS +#define GEN_PASS_DEF_SPARSEREMOVELAYOUTCONVERSIONPASS +#define GEN_PASS_DEF_SPARSEWGMMAOPTOLLVMPASS +#include "xla/service/gpu/fusions/triton/passes.h.inc" + +constexpr int kThreadsPerWarp = 32; +// Each 16x16 original sparse matrix tile requires 16 metadata values of +// 16-bit size, where the first thread (T0) in each 4-thread group holds two +// such values in a register (32-bit). +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-matrix-storage +constexpr int kTileSize = 16; +constexpr int kMetaElementsBitSize = 2; +// Metadata elements are packed into 16-bits values. +constexpr int kMetaElementsPerPackedValue = 16 / kMetaElementsBitSize; +constexpr int kColumnsPerCtaTile = kTileSize / kMetaElementsPerPackedValue; + +struct SparseAddEncoding : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -178,29 +203,16 @@ struct AddSparseEncoding } }; -class AddSparseEncodingPass - : public PassWrapper> { - public: - AddSparseEncodingPass() = default; - AddSparseEncodingPass(int32_t num_warps, int32_t threads_per_warp, - int32_t num_ctas) { - num_warps_ = num_warps; - threads_per_warp_ = threads_per_warp; - num_ctas_ = num_ctas; - } - AddSparseEncodingPass(const AddSparseEncodingPass &other) { - num_warps_ = other.num_warps_; - threads_per_warp_ = other.threads_per_warp_; - num_ctas_ = other.num_ctas_; - }; - - StringRef getArgument() const override { return "add-sparse-encoding"; } +struct SparseAddEncodingPass + : public impl::SparseAddEncodingPassBase { + using impl::SparseAddEncodingPassBase< + SparseAddEncodingPass>::SparseAddEncodingPassBase; void runOnOperation() override { MLIRContext *context = &getContext(); TritonGPUTypeConverter type_converter(context, num_warps_, threads_per_warp_, num_ctas_); - auto pattern = std::make_unique(type_converter, context); + auto pattern = std::make_unique(type_converter, context); RewritePatternSet patterns(context, std::move(pattern)); TritonGPUConversionTarget target(*context, type_converter); target.addDynamicallyLegalOp( @@ -211,22 +223,8 @@ class AddSparseEncodingPass std::move(patterns)))) return signalPassFailure(); } - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddSparseEncodingPass) - - private: - Option num_warps_{ - *this, "num-warps", llvm::cl::desc("number of warps"), llvm::cl::init(4)}; - Option threads_per_warp_{ - *this, "threads-per-warp", llvm::cl::desc("number of threads per warp"), - llvm::cl::init(32)}; - Option num_ctas_{*this, "num-ctas", - llvm::cl::desc("number of ctas in a cga"), - llvm::cl::init(1)}; }; -// Add convert layouts to and from MMA before and after SparseDotOp. In MMAV3, -// shared memory allocations will be used for A and B operands. class SparseBlockedToMMA : public RewritePattern { using ConvertLayoutOp = triton::gpu::ConvertLayoutOp; using SparseDotOp = triton::gpu::SparseDotOp; @@ -331,13 +329,8 @@ class SparseBlockedToMMA : public RewritePattern { int compute_capability_; }; -class SparseBlockedToMMAPass - : public PassWrapper> { - public: - SparseBlockedToMMAPass() = default; - - StringRef getArgument() const override { return "sparse-blocked-to-mma"; } - +struct SparseBlockedToMMAPass + : public impl::SparseBlockedToMMAPassBase { void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp module = getOperation(); @@ -349,8 +342,42 @@ class SparseBlockedToMMAPass return signalPassFailure(); } } +}; + +struct SparseRemoveLayoutConversionPass + : public impl::SparseRemoveLayoutConversionPassBase< + SparseRemoveLayoutConversionPass> { + void runOnOperation() override { + getOperation().walk([&](triton::gpu::ConvertLayoutOp op) { + ImplicitLocOpBuilder builder(op.getLoc(), op); + // Skip if the source is already in shared memory. + auto src_encoding = + cast(op.getSrc().getType()).getEncoding(); + if (isa(src_encoding)) { + return; + } + auto dst_type = cast(op.getType()); + // Skip if the destination is not a sparse dot meta. + if (!isa( + dst_type.getEncoding())) { + return; + } - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseBlockedToMMAPass) + auto shared_layout = builder.getAttr( + // Packing metadata elements together. No swizzling. + /*vec=*/kMetaElementsPerPackedValue, /*perPhase=*/1, /*maxPhase=*/1, + triton::gpu::getOrder(src_encoding), + triton::gpu::getCTALayout(src_encoding)); + auto mem_type = triton::MemDescType::get( + dst_type.getShape(), dst_type.getElementType(), shared_layout, + builder.getAttr()); + Value alloc = + builder.create(mem_type, op.getSrc()); + Value convert = builder.create(dst_type, alloc); + op.replaceAllUsesWith(convert); + op.erase(); + }); + } }; class SparseLocalLoadToLLVM @@ -376,17 +403,6 @@ class SparseLocalLoadToLLVM LogicalResult lowerSharedToSparseMeta( triton::gpu::LocalLoadOp op, triton::gpu::LocalLoadOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - constexpr int kThreadsPerWarp = 32; - // Each 16x16 original sparse matrix tile requires 16 metadata values of - // 16-bit size, where the first thread (T0) in each 4-thread group holds two - // such values in a register (32-bit). - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#sparse-matrix-storage - constexpr int kTileSize = 16; - constexpr int kMetaElementsBitSize = 2; - // Metadata elements are packed into 16-bits values. - constexpr int kMetaElementsPerPackedValue = 16 / kMetaElementsBitSize; - constexpr int kColumnsPerCtaTile = kTileSize / kMetaElementsPerPackedValue; - auto loc = op.getLoc(); auto load_sparse_encoding = cast( cast(op.getResult().getType()).getEncoding()); @@ -468,35 +484,25 @@ class SparseLocalLoadToLLVM } }; -class SparseLocalLoadToLLVMPass - : public PassWrapper> { - public: - SparseLocalLoadToLLVMPass() = default; - StringRef getArgument() const override { return "sparse-local-load-to-llvm"; } - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } +bool IsLocalLoadWithSparseEncoding(Operation *op) { + auto local_load = mlir::dyn_cast(op); + if (!local_load) return false; + return isa( + local_load.getType().getEncoding()); +} +struct SparseLocalLoadToLLVMPass + : public impl::SparseLocalLoadToLLVMPassBase { void runOnOperation() override { // Exit early if there are no sparse ops. - mlir::ModuleOp mod = getOperation(); - if (!mod.walk([](triton::gpu::LocalLoadOp op) { - if (isa( - op.getType().getEncoding())) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }) - .wasInterrupted()) { - return; - } + ModuleOp mod = getOperation(); + if (!ContainsOp(mod, IsLocalLoadWithSparseEncoding)) return; + // Allocate shared memory and set barrier // This is also done in the TritonGPUToLLVMPass but we need to do it before // we write the local load op to LLVM to have barriers in the right place. - // See b/351986109. + // See b/358375493. ModuleAllocation allocation(getOperation()); ModuleMembarAnalysis membar_pass(&allocation); membar_pass.run(); @@ -510,7 +516,7 @@ class SparseLocalLoadToLLVMPass return !isa( op.getType().getEncoding()); }); - mlir::LowerToLLVMOptions option(context); + LowerToLLVMOptions option(context); TritonGPUToLLVMTypeConverter typeConverter(context, option); auto pattern = std::make_unique(typeConverter); RewritePatternSet patterns(context, std::move(pattern)); @@ -519,15 +525,8 @@ class SparseLocalLoadToLLVMPass return signalPassFailure(); } } - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass) }; -using ::mlir::LLVM::getSharedMemoryObjectFromStruct; -using ::mlir::triton::gpu::getShapePerCTA; -using ::mlir::triton::gpu::getShapePerCTATile; -using ::mlir::triton::gpu::SharedEncodingAttr; - using ValueTableV2 = std::map, Value>; constexpr int kContractingFactor = 2; // implied by N:M (2:4) @@ -658,7 +657,6 @@ LogicalResult convertSparseMMA(triton::gpu::SparseDotOp op, // ----- Hopper implementation. -constexpr int kThreadsPerWarp = 32; constexpr int kWarpsInGroup = 4; constexpr int kMmaAccumulatorCount = 2; constexpr int kMmaLineSize = 128; @@ -775,17 +773,17 @@ LogicalResult convertSparseWGMMA(triton::gpu::SparseDotOp op, assert(hMetaPacked.size() == repM * repK); // Generate prologue. - triton::nvgpu::WGMMAEltType eltTypeA = getMmaOperandType(op.getA(), false); - triton::nvgpu::WGMMAEltType eltTypeB = getMmaOperandType(op.getB(), false); - triton::nvgpu::WGMMAEltType eltTypeC = getMmaRetType(op.getD()); + ttn::WGMMAEltType eltTypeA = getMmaOperandType(op.getA(), false); + ttn::WGMMAEltType eltTypeB = getMmaOperandType(op.getB(), false); + ttn::WGMMAEltType eltTypeC = getMmaRetType(op.getD()); - triton::nvgpu::WGMMALayout layoutA = transA ? triton::nvgpu::WGMMALayout::col - : triton::nvgpu::WGMMALayout::row; - triton::nvgpu::WGMMALayout layoutB = transB ? triton::nvgpu::WGMMALayout::row - : triton::nvgpu::WGMMALayout::col; + ttn::WGMMALayout layoutA = + transA ? ttn::WGMMALayout::col : ttn::WGMMALayout::row; + ttn::WGMMALayout layoutB = + transB ? ttn::WGMMALayout::row : ttn::WGMMALayout::col; - rewriter.create(loc, 0); - rewriter.create(loc); + rewriter.create(loc, 0); + rewriter.create(loc); // Generate main loop. for (int m = 0; m < repM; ++m) { @@ -798,7 +796,7 @@ LogicalResult convertSparseWGMMA(triton::gpu::SparseDotOp op, Value a = loadA(m, k); Value b = loadB(n, k); Value meta = hMetaPacked[k * repM + m]; - d = rewriter.create( + d = rewriter.create( loc, accTy, a, meta, b, d, kWarpsInGroup * instrShape[0], instrShape[1], kContractingFactor * instrShape[2], eltTypeC, eltTypeA, eltTypeB, layoutA, layoutB); @@ -815,8 +813,8 @@ LogicalResult convertSparseWGMMA(triton::gpu::SparseDotOp op, op.getContext(), SmallVector(fc.size(), f32_ty)); Value res = packLLElements(loc, typeConverter, fc, rewriter, structTy); - rewriter.create(loc); - res = rewriter.create(loc, res, 0); + rewriter.create(loc); + res = rewriter.create(loc, res, 0); rewriter.replaceOp(op, res); return success(); @@ -856,42 +854,29 @@ struct SparseDotOpConversion } }; -class SparseDotOpToLLVMPass - : public PassWrapper> { - public: - SparseDotOpToLLVMPass() = default; - - StringRef getArgument() const override { return "sparse-dot-to-llvm"; } - - void getDependentDialects(mlir::DialectRegistry ®istry) const override { - registry.insert(); - } - +struct SparseDotOpToLLVMPass + : public impl::SparseDotOpToLLVMPassBase { void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); target.addLegalDialect(); + arith::ArithDialect, ttn::NVGPUDialect>(); target.addIllegalOp(); target.addIllegalDialect(); - mlir::LowerToLLVMOptions option(context); + LowerToLLVMOptions option(context); TritonGPUToLLVMTypeConverter typeConverter(context, option); RewritePatternSet patterns(context); patterns.add(typeConverter); + // TODO(b/358375493): Remove this once TritonGPUToLLVMTypeConverter is + // splitted into smaller passes. populateGpuToNVVMConversionPatterns(typeConverter, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { return signalPassFailure(); } } - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass) }; -namespace ttn = mlir::triton::nvgpu; -using ttn::OperandsAndConstraints; - class SparseWGMMAOpPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -978,13 +963,8 @@ class SparseWGMMAOpPattern : public OpRewritePattern { } }; -class SparseWGMMAOpToLLVMPass - : public PassWrapper> { - public: - SparseWGMMAOpToLLVMPass() = default; - - StringRef getArgument() const override { return "sparse-wgmma-to-llvm"; } - +struct SparseWGMMAOpToLLVMPass + : public impl::SparseWGMMAOpToLLVMPassBase { void runOnOperation() override { MLIRContext *context = &getContext(); auto pattern = std::make_unique(context); @@ -994,38 +974,38 @@ class SparseWGMMAOpToLLVMPass return signalPassFailure(); } } - - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SparseLocalLoadToLLVMPass) }; } // namespace -std::unique_ptr xla::gpu::CreateAddSparseDotEncodingPass( - int32_t num_warps, int32_t threads_per_warp, int32_t num_ctas) { - return std::make_unique(num_warps, threads_per_warp, - num_ctas); +std::unique_ptr CreateSparseAddEncodingPass(int32_t num_warps, + int32_t threads_per_warp, + int32_t num_ctas) { + SparseAddEncodingPassOptions options; + options.num_warps_ = num_warps; + options.threads_per_warp_ = threads_per_warp; + options.num_ctas_ = num_ctas; + return std::make_unique(options); } -std::unique_ptr xla::gpu::CreateSparseBlockedToMMAPass() { +std::unique_ptr CreateSparseBlockedToMMAPass() { return std::make_unique(); } -std::unique_ptr xla::gpu::CreateSparseLocalLoadToLLVMPass() { +std::unique_ptr CreateSparseRemoveLayoutConversionPass() { + return std::make_unique(); +} + +std::unique_ptr CreateSparseLocalLoadToLLVMPass() { return std::make_unique(); } -std::unique_ptr xla::gpu::CreateSparseDotOpToLLVMPass() { +std::unique_ptr CreateSparseDotOpToLLVMPass() { return std::make_unique(); } -std::unique_ptr xla::gpu::CreateSparseWGMMAOpToLLVMPass() { +std::unique_ptr CreateSparseWGMMAOpToLLVMPass() { return std::make_unique(); } -void xla::gpu::RegisterSparsePasses() { - registerPass([] { return std::make_unique(); }); - registerPass(CreateSparseBlockedToMMAPass); - registerPass(CreateSparseLocalLoadToLLVMPass); - registerPass(CreateSparseDotOpToLLVMPass); - registerPass(CreateSparseWGMMAOpToLLVMPass); -} +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index 20fa19c0a9fac6..97c3576f9d092d 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -47,7 +47,6 @@ limitations under the License. #include "llvm/IR/Module.h" #include "llvm/Linker/Linker.h" #include "llvm/Support/FileSystem.h" -#include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" @@ -77,7 +76,6 @@ limitations under the License. #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" @@ -87,7 +85,6 @@ limitations under the License. #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Support/TypeID.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" @@ -110,9 +107,10 @@ limitations under the License. #include "xla/service/algorithm_util.h" #include "xla/service/dump.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" -#include "xla/service/gpu/fusions/mlir/ir/xla_gpu_ops.h" -#include "xla/service/gpu/fusions/mlir/passes.h" +#include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/fusions/triton/passes.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/launch_dimensions.h" @@ -123,6 +121,7 @@ limitations under the License. #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/gpu/target_util.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/service/gpu/triton_tiling_propagation.h" @@ -135,6 +134,7 @@ limitations under the License. #include "xla/stream_executor/launch_dim.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/util.h" +#include "xla/xla.pb.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" @@ -153,7 +153,6 @@ namespace gpu { namespace ma = ::mlir::arith; namespace mm = ::mlir::math; -namespace ml = ::mlir::LLVM; namespace mn = ::mlir::NVVM; namespace mt = ::mlir::triton; @@ -188,6 +187,10 @@ absl::StatusOr TritonType(mlir::OpBuilder b, PrimitiveType t) { return b.getI1Type(); case S8: return b.getI8Type(); + case S4: // The unpacking to i8 is supported by the emitter. + // We pass the s4 tensor as i8 tensor with the minor dimension having 2x + // less elements and unpack in the inner loop of the triton kernel. + return b.getI8Type(); case F8E5M2: return b.getFloat8E5M2Type(); case F8E4M3FN: @@ -647,16 +650,27 @@ struct DimProperties { int split_value; }; -absl::StatusOr EmitBroadcast( - ImplicitLocOpBuilder& b, const TritonFusionAnalysis* analysis, - TritonFusionAnalysis::Scope scope, - absl::Span tiled_dimensions, - const HloInstruction& broadcast, Value input) { +struct Side { + explicit Side(TritonFusionAnalysis::Scope scope, + std::vector tiled_dims = {}, + std::optional batch_dim_idx = std::nullopt) + : scope(scope), tiled_dims(tiled_dims), batch_dim_idx(batch_dim_idx) {} + TritonFusionAnalysis::Scope scope; + std::vector tiled_dims; + std::optional batch_dim_idx; + int64_t unpack_dim_idx = 0; +}; + +absl::StatusOr EmitBroadcast(ImplicitLocOpBuilder& b, + const TritonFusionAnalysis* analysis, + const Side& side, + const HloInstruction& broadcast, + Value input) { TF_RET_CHECK(analysis != nullptr); std::vector out_shape; - for (const DimProperties& dim : tiled_dimensions) { + for (const DimProperties& dim : side.tiled_dims) { const TensorIterationSpec::DimIterationSpec* spec = - analysis->IterSpec(scope, &broadcast, dim.index); + analysis->IterSpec(side.scope, &broadcast, dim.index); if (spec != nullptr && spec->at(0).stride > 0) { out_shape.push_back(dim.block_size); } @@ -673,10 +687,10 @@ absl::StatusOr EmitBroadcast( // Add broadcasted dimensions one by one. Value expanded_input = tensor_input; int dim_idx = 0; - for (const DimProperties& dim : tiled_dimensions) { - if (analysis->IterSpec(scope, &broadcast, dim.index) != nullptr && - analysis->IterSpec(scope, &broadcast, dim.index)->at(0).stride > 0) { - if (analysis->IterSpec(scope, broadcast.operand(0), dim.index) == + for (const DimProperties& dim : side.tiled_dims) { + if (auto* spec = analysis->IterSpec(side.scope, &broadcast, dim.index); + spec != nullptr && spec->at(0).stride > 0) { + if (analysis->IterSpec(side.scope, broadcast.operand(0), dim.index) == nullptr) { // Broadcasted dimension. expanded_input = b.create(expanded_input, dim_idx); @@ -690,8 +704,7 @@ absl::StatusOr EmitBroadcast( absl::StatusOr EmitScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const TritonFusionAnalysis* analysis, TritonFusionAnalysis::Scope scope, - absl::Span tiled_dimensions, + const TritonFusionAnalysis* analysis, const Side& side, absl::Span instructions, absl::flat_hash_map& values); @@ -797,7 +810,7 @@ absl::StatusOr EmitReduce( TF_ASSIGN_OR_RETURN( Value result, EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, - TritonFusionAnalysis::Scope::OUTPUT, {}, to_emit, + Side(TritonFusionAnalysis::Scope::OUTPUT), to_emit, region_values)); b.create(SmallVector({result})); b.setInsertionPointAfter(reduction); @@ -851,7 +864,7 @@ absl::StatusOr EmitNestedFusion( TF_RET_CHECK(to_emit.back() == fusion_computation->root_instruction()); return EmitScope(b, libdevice_path, device_info, /*analysis=*/nullptr, - TritonFusionAnalysis::Scope::OUTPUT, {}, to_emit, + Side(TritonFusionAnalysis::Scope::OUTPUT), to_emit, region_values); } @@ -1014,19 +1027,58 @@ absl::StatusOr EmitTiledScope( return values[tiled_computation.GetRoot()]; } +// Emit sequence of operations for unpacking 2xi4 -> i8. +absl::StatusOr EmitUnpackInt4(ImplicitLocOpBuilder& b, + const HloInstruction* hlo, + const Side& side, Value& value) { + VLOG(6) << "EmitUnpackInt4: " << hlo->ToString(); + auto input_type = mlir::cast(value.getType()); + if (input_type.getShape().size() != 2) { + return absl::InvalidArgumentError( + absl::StrCat("UnpackInt4 works only for 2d inputs: ", hlo->ToString())); + } + // We use shifts instead the mask because we need to keep the sign bit. + Value shift4 = + Splat(b, CreateConst(b, b.getI8Type(), 4), input_type.getShape()); + Value lo = b.create(b.create(value, shift4), shift4); + Value hi = b.create(value, shift4); + Value result = b.create(hi, lo); + SmallVector result_shape(input_type.getShape()); + result_shape[side.unpack_dim_idx] *= 2; + if (side.unpack_dim_idx == 0) { + result = b.create(result, b.getDenseI32ArrayAttr({0, 2, 1})); + } + auto type = mlir::RankedTensorType::get(result_shape, b.getI8Type()); + return b.create(type, result, /*allow_reorder=*/false); +} + // Emit sequence of instructions using compatible tiling ordered producers // before consumers. absl::StatusOr EmitScope( ImplicitLocOpBuilder& b, absl::string_view libdevice_path, const se::DeviceDescription& device_info, - const TritonFusionAnalysis* analysis, TritonFusionAnalysis::Scope scope, - absl::Span tiled_dimensions, + const TritonFusionAnalysis* analysis, const Side& side, absl::Span instructions, absl::flat_hash_map& values) { for (const HloInstruction* hlo : instructions) { Value result; - if (hlo->opcode() == HloOpcode::kConcatenate || - hlo->opcode() == HloOpcode::kDynamicSlice) { + if (hlo->opcode() == HloOpcode::kConvert && + hlo->operand(0)->shape().element_type() == S4) { + if (!hlo->GetModule() + ->config() + .debug_options() + .xla_gpu_enable_triton_gemm_int4()) { + return absl::UnimplementedError( + "Int4 support is not enabled in the debug options."); + } + + TF_ASSIGN_OR_RETURN( + auto unpacked, EmitUnpackInt4(b, hlo, side, values[hlo->operand(0)])); + std::vector operands({unpacked}); + TF_ASSIGN_OR_RETURN(result, EmitElementwise(b, libdevice_path, + device_info, *hlo, operands)); + } else if (hlo->opcode() == HloOpcode::kConcatenate || + hlo->opcode() == HloOpcode::kDynamicSlice) { // Parameter loads and their concatenations are handled outside EmitScope. TF_RET_CHECK(values.contains(hlo)) << hlo->ToString(); continue; @@ -1042,9 +1094,8 @@ absl::StatusOr EmitScope( // Splat makes it a tensor to avoid type mismatches. result = Splat(b, constant, {}); } else if (hlo->opcode() == HloOpcode::kBroadcast) { - TF_ASSIGN_OR_RETURN( - result, EmitBroadcast(b, analysis, scope, tiled_dimensions, *hlo, - values[hlo->operand(0)])); + TF_ASSIGN_OR_RETURN(result, EmitBroadcast(b, analysis, side, *hlo, + values[hlo->operand(0)])); } else if (HloInstruction::IsOpElementwise(hlo->opcode())) { std::vector operands; operands.reserve(hlo->operands().size()); @@ -1079,86 +1130,6 @@ absl::StatusOr EmitScope( return values[instructions.back()]; } -// Extract additional attributes from an LLVM function that are not passed -// to the builder directly. -SmallVector GetExtraAttrs(ml::LLVMFuncOp func) { - llvm::StringSet<> registered_attr_names{ - func.getSymNameAttrName().getValue(), - func.getFunctionTypeAttrName().getValue(), - func.getLinkageAttrName().getValue(), - func.getDsoLocalAttrName().getValue(), - func.getCConvAttrName().getValue(), - func.getArgAttrsAttrName().getValue(), - func.getFunctionEntryCountAttrName().getValue()}; - return llvm::to_vector( - llvm::make_filter_range(func->getAttrs(), [&](mlir::NamedAttribute attr) { - return !registered_attr_names.contains(attr.getName().getValue()); - })); -} - -// Strip address spaces from function parameters. -void StripParameterAddressSpaces(mlir::RewriterBase& rewriter, - ml::LLVMFuncOp func) { - // Figure out what the new signature should be. - ml::LLVMFunctionType func_ty = func.getFunctionType(); - SmallVector generic_func_params( - llvm::map_range(func_ty.getParams(), [](Type type) -> Type { - auto ptr_ty = mlir::dyn_cast(type); - if (!ptr_ty) return type; - if (ptr_ty.getAddressSpace() != mn::kGlobalMemorySpace) return type; - return ml::LLVMPointerType::get(ptr_ty.getContext()); - })); - ml::LLVMFunctionType generic_func_ty = - func_ty.clone(generic_func_params, func_ty.getReturnTypes()); - - // Create a function with the new signature. - SmallVector arg_attrs(llvm::map_range( - func.getArgAttrsAttr().getValue(), [](mlir::Attribute attr) { - return mlir::cast(attr); - })); - auto generic_func = rewriter.create( - func.getLoc(), func.getSymName(), generic_func_ty, func.getLinkage(), - func.getDsoLocal(), func.getCConv(), /*comdat=*/nullptr, - GetExtraAttrs(func), arg_attrs, func.getFunctionEntryCount()); - - // Convert generic address spaces back to original ones within the function - // body. - mlir::Block* entry = generic_func.addEntryBlock(rewriter); - rewriter.setInsertionPointToEnd(entry); - SmallVector converted_args; - for (auto [arg, type] : - llvm::zip(generic_func.getArguments(), func_ty.getParams())) { - Value converted = arg; - if (arg.getType() != type) { - converted = rewriter.create(arg.getLoc(), type, arg); - } - converted_args.push_back(converted); - } - - // Move the rest of function body from the original function. - rewriter.cloneRegionBefore(func.getBody(), generic_func.getBody(), - generic_func.getBody().end()); - rewriter.eraseOp(func); - rewriter.mergeBlocks(entry->getNextNode(), entry, converted_args); -} - -// Rewrite signatures of kernel functions to use generic data pointers and -// cast them to global ones within the kernel. -struct GeneralizeKernelSignaturePass - : mlir::PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GeneralizeKernelSignaturePass); - void runOnOperation() override { - mlir::IRRewriter rewriter(&getContext()); - getOperation()->walk([&](ml::LLVMFuncOp func) { - if (!func->hasAttr(mn::NVVMDialect::getKernelFuncAttrName())) { - return; - } - rewriter.setInsertionPointAfter(func); - StripParameterAddressSpaces(rewriter, func); - }); - } -}; - const TensorIterationSpec::DimIterationSpec* GetLhsNoncontractingSplitSpec( const TritonFusionAnalysis& analysis, int64_t lhs_noncontracting_dim_idx) { const TensorIterationSpec::DimIterationSpec* result = nullptr; @@ -1385,12 +1356,6 @@ absl::Status ValidateMatMulConfig(const TritonGemmConfig& config, return absl::OkStatus(); } -struct Side { - TritonFusionAnalysis::Scope scope; - std::vector tiled_dims; - std::optional batch_dim_idx; -}; - // if (index < limits[0]) { // return choices[0]; // } else if (index < limits[1]) { @@ -1507,7 +1472,7 @@ class MatMulEmitterHelper { } } } - CHECK(to_order.insert(current).second); + to_order.insert(current); to_add.pop(); } } @@ -1522,11 +1487,10 @@ class MatMulEmitterHelper { return to_emit; } - Value MakeInput(Side& side, int64_t operand_index, + Value MakeInput(const Side& side, int64_t operand_index, absl::flat_hash_map& values) { return *EmitScope( - b_, libdevice_path_, device_info_, &analysis_, side.scope, - side.tiled_dims, + b_, libdevice_path_, device_info_, &analysis_, side, dot_instr_->parent()->MakeInstructionPostOrderFrom( const_cast(*dot_instr_->operand(operand_index))), values); @@ -1551,6 +1515,7 @@ class MatMulEmitterHelper { Value base; std::vector bounds; std::vector strides; + std::vector strides_sizes; // We use it to detect the minor dim. // Offsets from tensor origin, same for all thread blocks. std::vector tensor_offsets; std::vector block_dims; @@ -1641,7 +1606,9 @@ class MatMulEmitterHelper { for (const HloInstruction* input : inputs) { specs.push_back( analysis_.IterSpec(side.scope, input, properties.index)); - input_strides.push_back(Cst64(specs.back()->at(0).stride)); + const auto stride = specs.back()->at(0).stride; + strides_sizes.push_back(stride); + input_strides.push_back(Cst64(stride)); input_offsets.push_back(b_.create( pid_offset, Cst32(specs.back()->at(0).slice_start))); input_bounds.push_back(Cst64(specs.back()->at(0).count)); @@ -1816,9 +1783,14 @@ class MatMulEmitterHelper { if (has_batch_offset) { Value pid_batch = b_.create(launch_config_.batch_program_id_dim); + Value pid_offset_batch = b_.create( b_.create(Cst(offset_batch), ConvertScalar(pid_batch)), batch_stride); + + if (hlo->shape().element_type() == PrimitiveType::S4) { + pid_offset_batch = b_.create(pid_offset_batch, Cst(2)); + } base = AddPtr(b_, base, pid_offset_batch); } @@ -1837,6 +1809,18 @@ class MatMulEmitterHelper { // Load of a scalar. return base; } + if (hlo->shape().element_type() == PrimitiveType::S4) { + // Divide the stride by 2 for S4 inputs except for the minor dimension. + for (int i = 0; i < strides.size(); ++i) { + // We assume that the pack happens along the minor dimension. + if (strides_sizes[i] == 1) { // minor dimension + auto s4_bound = b_.create(bounds[i], Cst64(2)); + bounds[i] = s4_bound; + continue; + } + strides[i] = b_.create(strides[i], Cst64(2)); + } + } auto tensor_ptr = mlir::cast( b_.create(base, bounds, strides, tensor_offsets, block_dims, dim_order) @@ -2156,6 +2140,140 @@ absl::Status CheckGemmTilingComplexityHeuristic( return absl::OkStatus(); } +class Scopes { + public: + Scopes(ImplicitLocOpBuilder& b, const TritonFusionAnalysis& analysis, + const MatMulDims& dims, const TritonGemmConfig& config, + const MatMulLaunchConfig launch_config, bool is_sparse) + : lhs_(TritonFusionAnalysis::Scope::LHS), + rhs_(TritonFusionAnalysis::Scope::RHS), + out_(TritonFusionAnalysis::Scope::OUTPUT) { + constexpr int group_m = 8; + const int64_t width = group_m * launch_config.grid_n; + + auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; + + auto pid_nc = b.create( + launch_config.noncontracting_program_id_dim); + pid_k_ = (config.split_k > 1) + ? b.create(mt::ProgramIDDim::Z) + : Value{}; + + auto group_id = b.create(pid_nc, c32(width)); + ma::ConstantOp group_m_op = c32(group_m); + auto first_pid_m = b.create(group_id, group_m_op); + auto sub0 = b.create(c32(launch_config.grid_m), first_pid_m); + auto group_size = b.create( + b.create(ma::CmpIPredicate::slt, sub0, group_m_op), sub0, + group_m_op); + + pid_m_ = b.create(first_pid_m, + b.create(pid_nc, group_size)); + + pid_n_ = b.create(b.create(pid_nc, c32(width)), + group_size); + + int lhs_non_contracting_block_size = config.block_m; + int lhs_contracting_block_size = config.block_k; + int lhs_unpack_dim_idx = 0; + if (is_int4_param(analysis, TritonFusionAnalysis::Scope::LHS)) { + if (dims.lhs_contracting_dim_idx > dims.lhs_noncontracting_dim_idx) { + // lhs is int4 and the contracting dimension is minor. + lhs_contracting_block_size /= 2; + lhs_unpack_dim_idx = 1; + } else { + // lhs is int4 and the contracting dimension is major. + lhs_non_contracting_block_size /= 2; + lhs_unpack_dim_idx = 0; + } + } + if (is_sparse) { + lhs_contracting_block_size /= 2; + } + lhs_.tiled_dims = { + DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_, + lhs_non_contracting_block_size, + /*split_value=*/1), + DimProperties(dims.lhs_contracting_dim_idx, pid_k_, + lhs_contracting_block_size, config.split_k)}; + lhs_.batch_dim_idx = dims.lhs_batch_dim_idx; + lhs_.unpack_dim_idx = lhs_unpack_dim_idx; + + int rhs_contracting_block_size = config.block_k; + int rhs_non_contracting_block_size = config.block_n; + int rhs_unpack_dim_idx = 0; + if (is_int4_param(analysis, TritonFusionAnalysis::Scope::RHS)) { + if (dims.rhs_contracting_dim_idx > dims.rhs_noncontracting_dim_idx) { + // rhs is int4 and the contracting dimension is minor. + rhs_contracting_block_size /= 2; + rhs_unpack_dim_idx = 0; + } else { + // rhs is int4 and the contracting dimension is major. + rhs_non_contracting_block_size /= 2; + rhs_unpack_dim_idx = 1; + } + } + rhs_.tiled_dims = { + DimProperties(dims.rhs_contracting_dim_idx, pid_k_, + rhs_contracting_block_size, config.split_k), + DimProperties(dims.rhs_noncontracting_dim_idx, pid_n_, + rhs_non_contracting_block_size, + /*split_value=*/1)}; + rhs_.batch_dim_idx = dims.rhs_batch_dim_idx; + rhs_.unpack_dim_idx = rhs_unpack_dim_idx; + + out_.tiled_dims = {DimProperties(dims.out_lhs_noncontracting_dim_idx, + pid_m_, config.block_m, + /*split_value=*/1), + DimProperties(dims.out_rhs_noncontracting_dim_idx, + pid_n_, config.block_n, + /*split_value=*/1)}; + out_.batch_dim_idx = dims.out_batch_dim_idx; + + if (is_sparse) { + meta_ = Side{TritonFusionAnalysis::Scope::META, + /*tiled_dims=*/ + {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m_, + config.block_m, + /*split_value=*/1), + DimProperties(dims.lhs_contracting_dim_idx, pid_k_, + config.block_k / 16, config.split_k)}, + dims.lhs_batch_dim_idx}; + } + } + + std::vector input_scopes() const { + if (meta_.has_value()) { + return {&lhs_, &rhs_, &meta_.value()}; + } + return {&lhs_, &rhs_}; + } + const Side& lhs() const { return lhs_; } + const Side& rhs() const { return rhs_; } + const Side& out() const { return out_; } + const std::optional& meta() const { return meta_; } + const Value& pid_m() const { return pid_m_; } + const Value& pid_k() const { return pid_k_; } + const Value& pid_n() const { return pid_n_; } + + static bool is_int4_param(const TritonFusionAnalysis& analysis, + TritonFusionAnalysis::Scope scope) { + const ConstHloInstructionSet& params = analysis.ScopeParameters(scope); + return params.size() == 1 && + (*params.cbegin())->shape().element_type() == S4; + } + + private: + Side lhs_; + Side rhs_; + Side out_; + std::optional meta_; + + Value pid_m_; + Value pid_k_; + Value pid_n_; +}; + } // namespace // Variable naming: lhs [m, k] x rhs [k, n] -> out [m, n]. @@ -2240,30 +2358,6 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, MatMulEmitterHelper emitter(libdevice_path, device_info, dot_instr, b, index_ty, dims, launch_config, analysis); - constexpr int group_m = 8; - const int64_t width = group_m * launch_config.grid_n; - - auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; - - auto pid_nc = - b.create(launch_config.noncontracting_program_id_dim); - Value pid_k = (split_k > 1) - ? b.create(mt::ProgramIDDim::Z) - : Value{}; - - auto group_id = b.create(pid_nc, c32(width)); - ma::ConstantOp group_m_op = c32(group_m); - auto first_pid_m = b.create(group_id, group_m_op); - auto sub0 = b.create(c32(launch_config.grid_m), first_pid_m); - auto group_size = b.create( - b.create(ma::CmpIPredicate::slt, sub0, group_m_op), sub0, - group_m_op); - - auto pid_m = b.create(first_pid_m, - b.create(pid_nc, group_size)); - auto pid_n = b.create(b.create(pid_nc, c32(width)), - group_size); - TF_ASSIGN_OR_RETURN(mlir::FloatType acc_ty, emitter.GetDotAccumulatorType()); ma::ConstantOp accumulator_init = @@ -2274,46 +2368,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, absl::flat_hash_map iter_args_to_inputs; absl::flat_hash_map> iter_args_to_boundary_checks; - Side lhs{TritonFusionAnalysis::Scope::LHS, - /*tiled_dims=*/ - {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m, block_m, - /*split_value=*/1), - DimProperties(dims.lhs_contracting_dim_idx, pid_k, - block_k / (1 + is_sparse), split_k)}, - dims.lhs_batch_dim_idx}; - Side rhs{ - TritonFusionAnalysis::Scope::RHS, - /*tiled_dims=*/ - {DimProperties(dims.rhs_contracting_dim_idx, pid_k, block_k, split_k), - DimProperties(dims.rhs_noncontracting_dim_idx, pid_n, block_n, - /*split_value=*/1)}, - dims.rhs_batch_dim_idx}; - Side out{TritonFusionAnalysis::Scope::OUTPUT, - /*tiled_dims=*/ - {DimProperties(dims.out_lhs_noncontracting_dim_idx, pid_m, block_m, - /*split_value=*/1), - DimProperties(dims.out_rhs_noncontracting_dim_idx, pid_n, block_n, - /*split_value=*/1)}, - dims.out_batch_dim_idx}; - - std::vector scopes = {lhs, rhs}; - if (is_sparse) { - scopes.push_back( - {TritonFusionAnalysis::Scope::META, - /*tiled_dims=*/ - {DimProperties(dims.lhs_noncontracting_dim_idx, pid_m, block_m, - /*split_value=*/1), - DimProperties(dims.lhs_contracting_dim_idx, pid_k, block_k / 16, - split_k)}, - dims.lhs_batch_dim_idx}); - } + // Calculate the sizes of the lhs, rhs, meta, and output sides. + Scopes scopes(b, analysis, dims, config, launch_config, is_sparse); + + auto c32 = [&](int64_t v) { return CreateConst(b, b.getI32Type(), v); }; constexpr size_t kLhsMetaOperandIdx = HloDotInstruction::kOperands; size_t lsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::LHS).size(); size_t rsize = ScopeInputs(analysis, TritonFusionAnalysis::Scope::RHS).size(); absl::flat_hash_map triton_type_for_input; - for (const Side& side : {lhs, rhs}) { + for (const Side& side : {scopes.lhs(), scopes.rhs()}) { for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { TF_ASSIGN_OR_RETURN(Type input_ty, TritonType(b, input->shape().element_type())); @@ -2330,7 +2395,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, // Load tiles of all parameters of LHS and RHS scopes and advance pointers. for (int i = 0; i < iter_args.size() - 1; ++i) { const int index = i < lsize ? 0 : i < lsize + rsize ? 1 : 2; - Side& side = scopes[index]; + const Side& side = *(scopes.input_scopes()[index]); const HloInstruction* param_hlo = iter_args_to_inputs[i]; Type param_ty = index == kLhsMetaOperandIdx @@ -2370,10 +2435,10 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, } // Emit all operations of LHS and RHS scopes. - Value dot_input_lhs = emitter.MakeInput(lhs, 0, values[0]); - Value dot_input_rhs = emitter.MakeInput(rhs, 1, values[1]); + Value dot_input_lhs = emitter.MakeInput(scopes.lhs(), 0, values[0]); + Value dot_input_rhs = emitter.MakeInput(scopes.rhs(), 1, values[1]); Value dot_input_meta = - is_sparse ? emitter.MakeInput(scopes.back(), 2, values[2]) : Value{}; + is_sparse ? emitter.MakeInput(*scopes.meta(), 2, values[2]) : Value{}; // Operation in the fusion before the dot can alter the elements of the // tiles that were zero masked during loads. These have to be zeroed here @@ -2386,9 +2451,10 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, auto elements_in_tile = b.create(c32(dims.k / denom), ki); int size = block_k / denom; auto range_k = Range(b, size); - if (pid_k != nullptr) { + if (scopes.pid_k() != nullptr) { range_k = b.create( - range_k, Splat(b, b.create(pid_k, c32(size)), size)); + range_k, + Splat(b, b.create(scopes.pid_k(), c32(size)), size)); } auto ty = mlir::cast(input.getType()); TensorValue range_expanded = mlir::cast( @@ -2464,15 +2530,15 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, SmallVector iter_args; iter_args.reserve(lsize + rsize + 1 + is_sparse); - for (const Side& side : scopes) { - for (const HloInstruction* input : ScopeInputs(analysis, side.scope)) { + for (const Side* side : scopes.input_scopes()) { + for (const HloInstruction* input : ScopeInputs(analysis, side->scope)) { TF_RET_CHECK( iter_args_to_inputs.insert({iter_args.size(), input}).second); TF_ASSIGN_OR_RETURN(SmallVector arguments, GetArguments(fn, *input)); TF_ASSIGN_OR_RETURN(Value tensor_ptr, emitter.EmitTensorPointer( - input, side, arguments, pid_k, + input, *side, arguments, scopes.pid_k(), iter_args_to_boundary_checks[iter_args.size()])); iter_args.push_back(tensor_ptr); } @@ -2499,17 +2565,17 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, std::vector boundary_checks; TF_ASSIGN_OR_RETURN(SmallVector arguments, GetArguments(fn, *input)); - TF_ASSIGN_OR_RETURN(Value tensor_pointer, - emitter.EmitTensorPointer(input, out, arguments, - pid_k, boundary_checks)); + TF_ASSIGN_OR_RETURN( + Value tensor_pointer, + emitter.EmitTensorPointer(input, scopes.out(), arguments, + scopes.pid_k(), boundary_checks)); TF_RET_CHECK(values_out .insert({input, EmitParameterLoad(b, tensor_pointer, boundary_checks)}) .second); } TF_RETURN_IF_ERROR(EmitScope(b, libdevice_path, device_info, &analysis, - TritonFusionAnalysis::Scope::OUTPUT, - out.tiled_dims, to_emit, values_out) + scopes.out(), to_emit, values_out) .status()); } @@ -2522,9 +2588,9 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, TF_ASSIGN_OR_RETURN( Value tensor_pointer, emitter.EmitTensorPointer( - producer, out, - {fn.getArgument(i + dot_instr->parent()->num_parameters())}, pid_k, - boundary_checks)); + producer, scopes.out(), + {fn.getArgument(i + dot_instr->parent()->num_parameters())}, + scopes.pid_k(), boundary_checks)); b.create(tensor_pointer, values_out[producer], boundary_checks, mt::CacheModifier::NONE, mt::EvictionPolicy::NORMAL); } @@ -2665,8 +2731,9 @@ absl::Status EmitGeneric(mlir::OpBuilder builder, const BlockLevelParameters& block_level_parameters) { const HloComputation* computation = fusion->fused_instructions_computation(); SymbolicTileAnalysisOrError symbolic_tile_analysis_or = - SymbolicTileAnalysis::AnalyzeComputation(*computation, - builder.getContext()); + SymbolicTileAnalysis::AnalyzeComputation( + *computation, builder.getContext(), + TritonEmitterConstraints::GetBuilder()); if (std::holds_alternative(symbolic_tile_analysis_or)) { return Internal( "Unsupported fusion in EmitGeneric: %s", @@ -2756,10 +2823,19 @@ absl::Status CreateInternalError(std::string_view message, os << message << "\n"; os << fusion->fused_instructions_computation()->ToString() << "\n"; os << "triton_module: \n"; - triton_module->print(os); + triton_module->print(os, mlir::OpPrintingFlags().enableDebugInfo(true, true)); return absl::InternalError(err); } +absl::Status DoSupportType(const DebugOptions& debug_options, + PrimitiveType type) { + if (type == S4 && !debug_options.xla_gpu_enable_triton_gemm_int4()) { + return absl::FailedPreconditionError( + "Int4 support is not enabled in the debug options."); + } + return absl::OkStatus(); +} + absl::StatusOr> CreateTritonModule( absl::string_view fn_name, const HloFusionInstruction* fusion, const se::DeviceDescription& device_info, @@ -2776,10 +2852,12 @@ absl::StatusOr> CreateTritonModule( llvm_ir::CreateMlirModuleOp(loc); b.setInsertionPointToEnd(triton_module->getBody()); + const auto debug_options = fusion->GetModule()->config().debug_options(); // Build Triton kernel. SmallVector fn_arg_types; for (HloInstruction* p : hlo_computation->parameter_instructions()) { PrimitiveType type = p->shape().element_type(); + TF_RETURN_IF_ERROR(DoSupportType(debug_options, type)); Type ir_type; if (type == U16) { ir_type = b.getI16Type(); @@ -2837,10 +2915,17 @@ absl::StatusOr> CreateTritonModule( "Failed to create Triton module for fusion:", fusion, *triton_module); } - VLOG(6) << llvm_ir::DumpToString(*triton_module); + auto dump_triton_ir = [&]() { + std::string triton_ir; + llvm::raw_string_ostream os(triton_ir); + triton_module->print(os, + mlir::OpPrintingFlags().enableDebugInfo(true, true)); + return triton_ir; + }; + VLOG(6) << dump_triton_ir(); if (DumpingEnabledForHloModule(*hlo_computation->parent())) { DumpToFileInDirOrStdout(*hlo_computation->parent(), "triton_ir", "ttir", - llvm_ir::DumpToString(*triton_module)); + dump_triton_ir()); } return std::move(triton_module); @@ -2948,16 +3033,16 @@ absl::StatusOr CompileTritonToLLVM( .ok()) { return Internal("Failed to create Triton pipeline."); } - if (log_stream.has_value()) { - pm.printAsTextualPipeline(log_stream.value()); - log_stream->write("\n\n", 2); - } // Triton generates pointers to the global address space, while XLA needs a // kernel signature with pointers to the generic address space. - pm.addPass(std::make_unique()); + pm.addPass(CreateGeneralizeKernelSignaturePass()); // llvm::Linker::linkModules() segfaults if we don't strip locations. pm.addPass(mlir::createStripDebugInfoPass()); + if (log_stream.has_value()) { + pm.printAsTextualPipeline(log_stream.value()); + log_stream->write("\n\n", 2); + } bool succeeded = mlir::succeeded(pm.run(triton_module)); if (log_stream.has_value()) { diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h index fe133d88f44e45..3f7c3bc26417d9 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h @@ -17,7 +17,6 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_FUSION_EMITTER_H_ #include -#include #include #include @@ -87,19 +86,6 @@ absl::Status EmitMatMul(mlir::OpBuilder b, absl::string_view libdevice_path, mlir::triton::FuncOp fn, const BlockLevelParameters& block_level_parameters); -// Generate Softmax in Triton IR inside 'fn'. -// Use execution parameters from 'block_level_parameters'. -absl::Status EmitSoftMax(mlir::OpBuilder b, absl::string_view libdevice_path, - const se::DeviceDescription& device_info, - const HloFusionInstruction* fusion, - mlir::triton::FuncOp fn, - const BlockLevelParameters& block_level_parameters); - -using TritonIrEmitter = std::function; - // Load the MLIR dialects required for Triton IR generation. void LoadMlirDialectsForTriton(mlir::MLIRContext& mlir_context); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc index 9a2d2d6dbb7520..016d9925c4dcff 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_legacy_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include #include #include "absl/algorithm/container.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" @@ -51,8 +52,8 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/path.h" @@ -108,6 +109,7 @@ class TritonGemmTest : public TritonTest { debug_options.set_xla_gpu_enable_split_k_autotuning(false); // Always rewrite Gemms with Triton regardless of size. debug_options.set_xla_gpu_gemm_rewrite_size_threshold(0); + debug_options.set_xla_gpu_enable_triton_gemm_int4(true); return debug_options; } @@ -136,6 +138,245 @@ class TritonGemmTestWithoutTritonGemmAny : public TritonGemmTest { } }; +TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDim) { + // We prove that triton can handle int4 dot with non minor + // lhs_contracting_dim. + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[1024,8]{1,0} parameter(0) + lhs_converted = bf16[1024,8]{1,0} convert(lhs) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={0}, + rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[1024,8]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, LHSInt4NonMinorContractingDimWithBatchDim0) { + // We prove that triton can handle int4 dot with non minor + // lhs_contracting_dim. + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[16,1024,8]{2,1,0} parameter(0) + lhs_converted = bf16[16,1024,8]{2,1,0} convert(lhs) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} dot(lhs_converted, rhs), + lhs_batch_dims={0}, + lhs_contracting_dims={1}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = s4[16,1024,8]{2,1,0} parameter(0) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmTest, LHSInt4MinorContractingDim) { + // We prove that triton can handle int4 dot with minor lhs_contracting_dim. + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[8,1024]{1,0} parameter(0) + lhs_converted = bf16[8,1024]{1,0} convert(lhs) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[8,1024]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, Int4ConvertPlusNegate) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[8,1024]{1,0} parameter(0) + lhs_converted = bf16[8,1024]{1,0} convert(lhs) + lhs_negated = bf16[8,1024]{1,0} negate(lhs_converted) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} dot(lhs_negated, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY main { + lhs = s4[8,1024]{1,0} parameter(0) + rhs = bf16[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4]{1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, LHSInt4MinorContractingDimWithBatchDim0) { + // We prove that triton can handle int4 dot with minor lhs_contracting_dim. + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = s4[16,8,1024]{2,1,0} parameter(0) + lhs_converted = bf16[16,8,1024]{2,1,0} convert(lhs) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} dot(lhs_converted, rhs), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = s4[16,8,1024]{2,1,0} parameter(0) + rhs = bf16[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4]{2,1,0} fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDim) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[1024,4]{1,0} parameter(1) + rhs_converted = bf16[1024,4]{1,0} convert(rhs) + ROOT dot = bf16[8,4] dot(lhs, rhs_converted), + lhs_contracting_dims={1}, + rhs_contracting_dims={0} + } + + ENTRY main { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[1024,4]{1,0} parameter(1) + ROOT dot = bf16[8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDim) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[4,1024]{1,0} parameter(1) + rhs_converted = bf16[4,1024]{1,0} convert(rhs) + ROOT dot = bf16[8,4] dot(lhs, rhs_converted), + lhs_contracting_dims={1}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = bf16[8,1024]{1,0} parameter(0) + rhs = s4[4,1024]{1,0} parameter(1) + ROOT dot = bf16[8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, RHSInt4TestWithMinorContractingDimWithBatchDim) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,1024,4]{2,1,0} parameter(1) + rhs_converted = bf16[16,1024,4]{2,1,0} convert(rhs) + ROOT dot = bf16[16,8,4] dot(lhs, rhs_converted), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + } + + ENTRY main { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,1024,4]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + +TEST_F(TritonGemmTest, RHSInt4TestWithNotMinorContractingDimWithBatchDim0) { + const std::string kHloText = R"( + HloModule t + + triton_computation { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,4,1024]{2,1,0} parameter(1) + rhs_converted = bf16[16,4,1024]{2,1,0} convert(rhs) + ROOT dot = bf16[16,8,4] dot(lhs, rhs_converted), + lhs_batch_dims={0}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={2} + } + + ENTRY main { + lhs = bf16[16,8,1024]{2,1,0} parameter(0) + rhs = s4[16,4,1024]{2,1,0} parameter(1) + ROOT dot = bf16[16,8,4] fusion(lhs, rhs), kind=kCustom, + calls=triton_computation, + backend_config={"fusion_backend_config": {"kind":"__triton_gemm"}} + } + )"; + EXPECT_TRUE(RunAndCompareNoHloPasses( + kHloText, ErrorSpec{/*aabs=*/1e-2, /*arel=*/1e-2})); +} + TEST_F(TritonTest, TestGemm) { const std::string kHloText = R"( HloModule t, is_scheduled=true @@ -1549,6 +1790,35 @@ ENTRY e { kHloText, ErrorSpec{/*aabs=*/1e-4, /*arel=*/1e-6})); } +TEST_F(TritonGemmTest, MultiplePathsToSameOperandWorks) { + const std::string kHloText = R"( +triton_computation { + p0 = bf16[8192,512]{1,0} parameter(0) + p1 = bf16[512,512]{1,0} parameter(1) + dot = bf16[8192,512]{1,0} dot(bf16[8192,512]{1,0} p0, bf16[512,512]{1,0} p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + p2 = bf16[8192,512]{1,0} parameter(2) + multiply.1 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} dot, bf16[8192,512]{1,0} p2) + ROOT multiply.2 = bf16[8192,512]{1,0} multiply(bf16[8192,512]{1,0} multiply.1, bf16[8192,512]{1,0} p2) +} + +ENTRY e { + p0 = bf16[8192,512]{1,0} parameter(0) + p1 = bf16[512,512]{1,0} parameter(1) + p2 = bf16[8192,512]{1,0} parameter(2) + ROOT fusion = bf16[8192,512]{1,0} fusion(p0,p1,p2), kind=kCustom, calls=triton_computation, + backend_config={"fusion_backend_config": + {"kind":"__triton_gemm", "triton_gemm_config":{"block_m":"64","block_n":"256","block_k":"32","split_k":"1","num_stages":"4","num_warps":"4","num_ctas":"1"}}} +})"; + + TF_ASSERT_OK( + CreateTritonIrAndFileCheckForDot(this, kHloText, "triton_computation", R"( + CHECK: tt.dot + CHECK-SAME: tensor<64x32xbf16> * tensor<32x256xbf16> -> tensor<64x256xf32> + CHECK: arith.mulf + CHECK: arith.mulf + )")); +} + class TritonGemmDynamicSliceClampingTest : public TritonTest, public ::testing::WithParamInterface {}; @@ -2529,10 +2799,12 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetOptimizedModule(kHloText)); - EXPECT_THAT(module->entry_computation()->root_instruction(), - GmockMatch(m::Transpose( - m::Fusion(m::Parameter(), m::Parameter()) - .WithFusionKind(HloInstruction::FusionKind::kCustom)))); + EXPECT_THAT( + module->entry_computation()->root_instruction(), + GmockMatch(m::Bitcast( + m::Fusion(m::Fusion(m::Parameter(), m::Parameter()) + .WithFusionKind(HloInstruction::FusionKind::kCustom)) + .WithFusionKind(HloInstruction::FusionKind::kInput)))); EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); } @@ -2620,9 +2892,9 @@ TEST_F(TritonGemmTestAny, HloModule t ENTRY e { - parameter_0 = f32[32,4000] parameter(0) - parameter_1 = f32[32,4000,6400] parameter(1) - ROOT dot = f32[32,6400] dot(parameter_0, parameter_1), lhs_batch_dims={0}, + parameter_0 = f32[1,40] parameter(0) + parameter_1 = f32[1,40,250000] parameter(1) + ROOT dot = f32[1,250000] dot(parameter_0, parameter_1), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; @@ -2642,9 +2914,9 @@ TEST_F(TritonGemmTestAny, HloModule t ENTRY e { - parameter_0 = f32[32,4000,6400] parameter(0) - parameter_1 = f32[32,4000] parameter(1) - ROOT dot = f32[32,6400] dot(parameter_0, parameter_1), lhs_batch_dims={0}, + parameter_0 = f32[1,40,250000] parameter(0) + parameter_1 = f32[1,40] parameter(1) + ROOT dot = f32[1,250000] dot(parameter_0, parameter_1), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={0}, rhs_contracting_dims={1} })"; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc index 9ca1b90100e0a9..28f630ba5f5a98 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_device_test.cc @@ -38,8 +38,8 @@ limitations under the License. #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -214,7 +214,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127) CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK: %[[PID:.*]] = tt.get_program_id x : i32 CHECK: arith.index_castui %[[PID]] : i32 to index @@ -272,7 +272,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127) CHECK-LABEL: tt.func @triton_fn( CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -339,7 +339,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 1, 127}), "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127) CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P2:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P3:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK-DAG: %[[PID:.*]] = tt.get_program_id x : i32 CHECK-DAG: %[[PID_INDEX:.*]] = arith.index_castui %[[PID]] : i32 to index @@ -517,7 +517,7 @@ ENTRY main { TF_ASSERT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 1, 16}), "triton_softmax_computation", R"( -// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)> +// CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 16) // CHECK-LABEL: tt.func @triton_fn( // CHECK-SAME: %[[P0:[A-Za-z0-9_]*]]: !tt.ptr // CHECK-SAME: %[[P1:[A-Za-z0-9_]*]]: !tt.ptr @@ -609,6 +609,64 @@ ENTRY entry { // TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should be // moved to deviceless test file. +TEST_F(TritonEmitterTest, + EmitterFailsIfFusionBackendConfigDoesNotSatisfyConstraints) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +max_computation { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(param_0, param_1) +} + +fused_computation { + param_0 = f32[8192,50304] parameter(0) + constant = f32[] constant(-inf) + reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=max_computation + broadcast = f32[8192,50304] broadcast(reduce), dimensions={0} + ROOT subtract = f32[8192,50304] subtract(param_0, broadcast) +} + +ENTRY entry_computation { + param_0 = f32[8192,50304] parameter(0) + ROOT fusion = f32[8192,50304] fusion(param_0), + kind=kCustom, calls=fused_computation, + backend_config={"fusion_backend_config": { + "kind":"__triton", + "block_level_fusion_config": {"output_tile_sizes": ["1024","1"], + "num_warps": "1"}}} +})")); + const HloFusionInstruction* triton_fusion = Cast( + hlo_module->entry_computation()->root_instruction()); + + auto compute_capability = + se::CudaComputeCapability{se::CudaComputeCapability::HOPPER, /*minor=*/0}; + const se::DeviceDescription dev_info = + TestGpuDeviceInfo::RTXA6000DeviceInfo(compute_capability); + llvm::LLVMContext llvm_ctx; + llvm::Module llvm_module("module", llvm_ctx); + mlir::MLIRContext mlir_context; + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = {1024, 1}; + block_level_parameters.num_warps = 1; + + // Because of reduce, we need to load full rows from param_0 and the load tile + // will be 1024 * 65536 = 67108864 elements, that is larger than the limit of + // 1048576. + EXPECT_THAT( + TritonWrapper("test_fn", triton_fusion, compute_capability, dev_info, + block_level_parameters, &llvm_module, mlir_context), + tsl::testing::StatusIs( + absl::StatusCode::kInvalidArgument, + ::testing::HasSubstr( + "Tile parameters 1024, 1 do not satisfy constraints."))); +} + +// TODO(b/353484968): Tests that don't run RunAndCompareNoHloPasses should b +// moved to deviceless test file. TEST_F(TritonEmitterTest, TestGenericEmitterReductionFusion) { const std::string kHloText = R"( HloModule t @@ -674,7 +732,7 @@ ENTRY main { TF_EXPECT_OK(CreateTritonIrAndFileCheck(this, kHloText, FromOutputTileSizes({1, 127}), "triton_softmax_computation", R"( -CHECK: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 127)> +CHECK: #[[MAP:.*]] = #xla_gpu.indexing_map<(d0) -> (d0 * 127) CHECK: tt.func @triton_fn(%[[P0:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}, %[[P1:[^:]*]]: !tt.ptr {tt.divisibility = 16 : i32}) { CHECK: %[[PID:.*]] = tt.get_program_id x : i32 CHECK: arith.index_castui %[[PID]] : i32 to index diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc index 44611bda590dfa..92616fa78f7225 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_mem_utils_test.cc @@ -48,10 +48,11 @@ limitations under the License. #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" // IWYU pragma: keep #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" @@ -112,7 +113,9 @@ TritonMakeTensorPtrTest::CreateAndTileParameterHloInstruction( verified_hlo_module->entry_computation()->root_instruction()); SymbolicTileAnalysisOrError symbolic_tile_analysis_or = - SymbolicTileAnalysis::AnalyzeFusion(*fusion_adaptor, &mlir_context_); + SymbolicTileAnalysis::AnalyzeFusion( + *fusion_adaptor, &mlir_context_, + TritonEmitterConstraints::GetBuilder()); CHECK( std::holds_alternative(symbolic_tile_analysis_or)); diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc index f946cc4257e5e0..0c66c03d8aed7c 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter_parametrized_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" @@ -52,13 +52,6 @@ struct MixTypeParams { class MixedTypeTest : public GpuCodegenTest, public ::testing::WithParamInterface { public: - se::CudaComputeCapability GetCudaComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - } - DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); // We are testing Triton, remove cuBLAS fallback for these tests. @@ -803,13 +796,6 @@ class TritonSoftmaxTest : public GpuCodegenTest, debug_options.clear_xla_disable_hlo_passes(); return debug_options; } - - se::CudaComputeCapability GetCudaComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - } }; TEST_P(TritonSoftmaxTest, CanFuseAndEmitExactSoftmax) { diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc index 44c9d51c5921d0..942e27f3226982 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.cc @@ -15,10 +15,7 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_support.h" -#include -#include #include -#include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_set.h" @@ -33,279 +30,43 @@ limitations under the License. #include "xla/layout.h" #include "xla/primitive_util.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/variant_visitor.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/tensor_float_32_utils.h" namespace xla { namespace gpu { -namespace legacy_triton { - -bool IsDistributiveOverAddition(const HloInstruction& hlo) { - // The list is most likely incomplete. - // For example division can be added too but only for operand #0. - if (hlo.opcode() == HloOpcode::kMultiply || - hlo.opcode() == HloOpcode::kNegate || - hlo.opcode() == HloOpcode::kBitcast || - hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kCopy || - hlo.opcode() == HloOpcode::kTranspose || - hlo.opcode() == HloOpcode::kConvert || - hlo.opcode() == HloOpcode::kBroadcast || - hlo.opcode() == HloOpcode::kSlice) { - return true; - } - return false; -} - -// Types that are supported by Triton as dot output. -// -// BF16 is supported in a sense that all operations on it are implemented -// through F32 and converts have to be inserted into the HLO graph, but -// they can be missing during fusion. -bool IsTritonSupportedDotOutputType( - const PrimitiveType t, const se::GpuComputeCapability& gpu_version) { - switch (t) { - case F16: - case F32: - return true; - case F8E5M2: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return cc.IsAtLeastAmpere(); - }, - [](const se::RocmComputeCapability& cc) { - return false; - }}, - gpu_version); - - case F8E4M3FN: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return cc.IsAtLeastHopper(); - }, - [](const se::RocmComputeCapability& cc) { - return false; - }}, - gpu_version); - case BF16: - return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { - return true; - }, - [](const se::RocmComputeCapability& cc) { - return cc.has_bf16_dtype_support(); - }}, - gpu_version); - default: - return false; - } -}; - -// Data types that are supported by the Triton emitters. -// TODO(b/266862493): Support more data types (F8, F64, etc.). -bool IsTritonSupportedDataType(PrimitiveType type, - const se::GpuComputeCapability& gpu_version) { - if (IsTritonSupportedDotOutputType(type, gpu_version)) { - return true; - } - switch (type) { - case PRED: - case S8: - case S16: - case S32: - return true; - default: - return false; - } -} -std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( - PrimitiveType element_type) { - std::vector ret = {HloOpcode::kConvert}; - if (element_type == PrimitiveType::PRED) { - ret.push_back(HloOpcode::kNot); - return ret; - } - ret.push_back(HloOpcode::kAbs); - ret.push_back(HloOpcode::kNegate); - if (element_type == PrimitiveType::F32 || - element_type == PrimitiveType::BF16 || - element_type == PrimitiveType::F64) { - absl::c_copy(std::vector{HloOpcode::kCos, HloOpcode::kExp, - HloOpcode::kExpm1, HloOpcode::kFloor, - HloOpcode::kCeil, HloOpcode::kLog, - HloOpcode::kLog1p, HloOpcode::kRsqrt, - HloOpcode::kSin, HloOpcode::kSqrt, - HloOpcode::kCbrt, HloOpcode::kTan, - HloOpcode::kTanh, HloOpcode::kErf}, - std::back_inserter(ret)); - } - return ret; -} - -std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( - PrimitiveType element_type) { - if (element_type == PrimitiveType::PRED) { - return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, - HloOpcode::kCompare}; - } - std::vector ret = {HloOpcode::kAdd, HloOpcode::kCompare, - HloOpcode::kMaximum, HloOpcode::kMinimum, - HloOpcode::kMultiply, HloOpcode::kSubtract}; - if (element_type == PrimitiveType::F32 || - element_type == PrimitiveType::BF16 || - element_type == PrimitiveType::F64) { - ret.push_back(HloOpcode::kAtan2); - ret.push_back(HloOpcode::kDivide); - ret.push_back(HloOpcode::kPower); - } - return ret; -} - -std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( - PrimitiveType element_type) { - return {HloOpcode::kSelect, HloOpcode::kClamp}; -} - -bool IsTritonSupportedElementwiseUpToFloatNormalization( - HloOpcode opcode, PrimitiveType element_type) { - return absl::c_linear_search( - TritonSupportedUnaryElementwiseUpToFloatNormalization( - element_type), - opcode) || - absl::c_linear_search( - TritonSupportedBinaryElementwiseUpToFloatNormalization( - element_type), - opcode) || - absl::c_linear_search( - TritonSupportedTernaryElementwiseUpToFloatNormalization( - element_type), - opcode); -} - -CodegenDecision CanTritonHandleElementwise( - const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { - if (!IsTritonSupportedDataType(instr.shape().element_type(), gpu_version)) { - return "Unsupported output data type."; - } - - for (const HloInstruction* operand : instr.operands()) { - if (!IsTritonSupportedDataType(operand->shape().element_type(), - gpu_version)) { - return "Unsupported input data type."; - } - } - - if (instr.opcode() == HloOpcode::kConstant) { - return CodegenDecision{}; - } else if (!IsTritonSupportedElementwiseUpToFloatNormalization( - instr.opcode(), instr.operand(0)->shape().element_type())) { - return "Unsupported elementwise operation."; - } - return CodegenDecision{}; -} - -bool IsDotAlgorithmSupportedByTriton( - PrecisionConfig::Algorithm algorithm, - const se::GpuComputeCapability& gpu_version) { - auto cuda_compute_capability = - std::get_if(&gpu_version); - auto rocm_compute_capability = - std::get_if(&gpu_version); - switch (algorithm) { - case PrecisionConfig::ALG_DOT_TF32_TF32_F32: - if (cuda_compute_capability) { - return true; - } - return false; - case PrecisionConfig::ALG_DOT_BF16_BF16_F32: - case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: - case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: - if (cuda_compute_capability) { - return true; - } - if (rocm_compute_capability) { - return rocm_compute_capability->has_bf16_dtype_support(); - } - return false; - - // TODO(b/326579472): Fix the support of this algorithm and maybe allow it - // here. - case PrecisionConfig::ALG_DOT_F16_F16_F32: - // TODO(b/311331155): Triton F32 is about 3x slower than Triton TF32 and is - // slow to compile. Disable it for now. - case PrecisionConfig::ALG_DOT_F32_F32_F32: - default: - return false; - } -} - -// Filters GEMMs which can be handled using Triton. -CodegenDecision CanTritonHandleGEMM( - const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { - auto cuda_compute_capability = - std::get_if(&gpu_version); - auto rocm_compute_capability = - std::get_if(&gpu_version); - - CHECK(cuda_compute_capability || rocm_compute_capability); - - if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) { - if (!tsl::tensor_float_32_execution_enabled() || - absl::c_any_of(dot.precision_config().operand_precision(), - [](int x) { return x != PrecisionConfig::DEFAULT; })) { - return "Having non-default operand precisions or TensorFloat-32 disabled " - "for Dot op with unset algorithm."; - } - } else { - if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), - gpu_version)) { - return "Unsupported algorithm on the current device(s)."; - } - } - - // TODO(b/266862493): Support more output types. - if (!IsTritonSupportedDotOutputType(dot.shape().element_type(), - gpu_version)) { - return "Unsupported output data type for Dot op."; - } - - if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), - gpu_version) || - !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), - gpu_version)) { - return "Unsupported input data type for Dot op."; - } - - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); - - // TODO(b/269580541): support multiple batch dimensions. - if (dim_numbers.lhs_batch_dimensions().size() > 1) { - return "Multiple batch dimensions."; - } - - return CodegenDecision{}; -} +namespace legacy_triton { // Filters Reduces which can be handled using Triton. +// TODO(b/345763510): The function is in use by the new version of the triton +// support but the implementation of this function relies on the legacy +// IsTritonSupport... functions. It should be rewritten for the new +// infrastructure. legacy_triton:: prefix is used to avoid name collision with +// the new implementation and for clarity. CodegenDecision CanTritonHandleReduce( const HloReduceInstruction& reduce, const se::GpuComputeCapability& gpu_version) { - if (!IsTritonSupportedDataType(reduce.shape().element_type(), gpu_version)) { + if (!legacy_triton::IsTritonSupportedDataType(reduce.shape().element_type(), + gpu_version)) { return "Unsupported output data type for Reduce op."; } for (const HloInstruction* operand : reduce.operands()) { - if (!IsTritonSupportedDataType(operand->shape().element_type(), - gpu_version)) { + if (!legacy_triton::IsTritonSupportedDataType( + operand->shape().element_type(), gpu_version)) { return "Unsupported input data type for Reduce op."; } } bool is_triton_supported_reduction_computation = [&]() { - return absl::c_all_of( - reduce.to_apply()->instructions(), [&](const HloInstruction* instr) { - return IsTritonSupportedInstruction(*instr, gpu_version); - }); + return absl::c_all_of(reduce.to_apply()->instructions(), + [&](const HloInstruction* instr) { + return legacy_triton::IsTritonSupportedInstruction( + *instr, gpu_version); + }); }(); if (!is_triton_supported_reduction_computation) { return "Unsupported reduction computation by Triton."; @@ -317,96 +78,6 @@ CodegenDecision CanTritonHandleReduce( return "Reduction is not a row-reduction of a single operand."; } -bool NoNonContractingDimension(const HloDotInstruction& dot) { - const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); - if (dim_numbers.lhs_batch_dimensions().size() + - dim_numbers.lhs_contracting_dimensions().size() == - dot.operand(0)->shape().rank() || - dim_numbers.rhs_batch_dimensions().size() + - dim_numbers.rhs_contracting_dimensions().size() == - dot.operand(1)->shape().rank()) { - return true; - } - return false; -} - -CodegenDecision IsTritonSupportedDynamicSlice( - const HloDynamicSliceInstruction& instr) { - for (const HloInstruction* index_operand : instr.index_operands()) { - switch (index_operand->shape().element_type()) { - case S8: - case S16: - case S32: - break; // supported - default: - return CodegenDecision( - "Dynamic slice is only supported with S8, S16, or S32 indices."); - } - } - - // Similar to normal slice, we cannot slice a non-major-most dimension as - // that would introduce non-contiguous strides under tiling. The existing - // check against this in GetRequirementsIfSupportedOrder is not suitable for - // dynamic slices, so we instead check for this here. - const HloInstruction* input = instr.operand(0); - Layout in_layout = input->shape().layout(); - int64_t majormost_dim_id = - in_layout.minor_to_major(in_layout.minor_to_major_size() - 1); - - for (int i = 0; i < input->shape().dimensions_size(); ++i) { - if (i == majormost_dim_id) { - continue; - } else if (input->shape().dimensions(i) != instr.slice_sizes(i)) { - return CodegenDecision( - "Unsupported dynamic slice on non-major-most dimension."); - } - } - - // TODO(b/343143854): Check the subtleties of which dynamic slices are - // supported, for example that a fragmented dimension cannot be sliced. - return CodegenDecision{}; -} - -CodegenDecision IsTritonSupportedInstruction( - const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { - if (instr.IsElementwise()) { - return CanTritonHandleElementwise(instr, gpu_version); - } - - switch (instr.opcode()) { - case HloOpcode::kDot: { - auto* dot = Cast(&instr); - // Cases where lhs or rhs have no non-contracting dims are not handled. - if (NoNonContractingDimension(*dot)) { - return "No non-contracting dimensions."; - } - return CanTritonHandleGEMM(*dot, gpu_version); - } - case HloOpcode::kTuple: { - if (instr.IsRoot()) { - return CodegenDecision{}; - } - return "Only supports root tuples."; - } - case HloOpcode::kDynamicSlice: { - return IsTritonSupportedDynamicSlice( - *Cast(&instr)); - } - case HloOpcode::kBitcast: - case HloOpcode::kTranspose: - case HloOpcode::kSlice: - case HloOpcode::kReshape: - case HloOpcode::kPad: - case HloOpcode::kConcatenate: - case HloOpcode::kParameter: - case HloOpcode::kBroadcast: - return CodegenDecision{}; - default: - break; - } - return "Unsupported opcode."; -} - } // namespace legacy_triton namespace { @@ -563,8 +234,8 @@ CodegenDecision IsTritonSupportedInstructionImpl( return CodegenDecision{}; } - bool output_type_is_supported = - IsTritonSupportedDataType(instr.shape().element_type(), gpu_version); + auto type = instr.shape().element_type(); + bool output_type_is_supported = IsTritonSupportedDataType(type, gpu_version); if (!output_type_is_supported) { return "Unsupported output data type."; diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h index abd2a4087216a7..14431e85b74f33 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support.h @@ -18,91 +18,16 @@ limitations under the License. // This file is the home of the basic Triton support checks which are used by // multiple other components. -#include - #include "absl/status/status.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" #include "xla/xla_data.pb.h" namespace xla { namespace gpu { -using CodegenDecision = FusionDecision; - -namespace legacy_triton { - -// Tells if f(a+b) == f(a) + f(b). -bool IsDistributiveOverAddition(const HloInstruction& hlo); - -// Allowlist of unary elementwise operations supported by the legacy Triton -// emitters. -// -// Note: this is not an accurate representation of what is actually supported by -// the Triton emitters, because operations affected by FloatNormalization may -// be tagged as "supported" here, even though FloatNormalization is required to -// make them work. We could fix this, but this is code we aim to delete soon, so -// it doesn't seem worth it. We'll revisit this decision if the code doesn't go -// away soon. -std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( - PrimitiveType); - -// Allowlist of binary elementwise operations supported by the legacy Triton -// emitters. -// -// Note: this is not an accurate representation of what is actually supported by -// the Triton emitters, because operations affected by FloatNormalization may -// be tagged as "supported" here, even though FloatNormalization is required to -// make them work. We could fix this, but this is code we aim to delete soon, so -// it doesn't seem worth it. We'll revisit this decision if the code doesn't go -// away soon. -std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( - PrimitiveType); - -// Allowlist of ternary elementwise operations supported by the legacy Triton -// emitters. -// -// Note: this is not an accurate representation of what is actually supported by -// the Triton emitters, because operations affected by FloatNormalization may -// be tagged as "supported" here, even though FloatNormalization is required to -// make them work. We could fix this, but this is code we aim to delete soon, so -// it doesn't seem worth it. We'll revisit this decision if the code doesn't go -// away soon. -std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( - PrimitiveType); -// Data types that are supported by the legacy Triton emitters. -bool IsTritonSupportedDataType(PrimitiveType, const se::GpuComputeCapability&); - -// Checks elementwise operation against unary, binary, and ternary elementwise -// operations supported by the legacy Triton emitters. -// -// Note: this is not an accurate representation of what is actually supported by -// the Triton emitters, because operations affected by FloatNormalization may -// be tagged as "supported" here, even though FloatNormalization is required to -// make them work. We could fix this, but this is code we aim to delete soon, so -// it doesn't seem worth it. We'll revisit this decision if the code doesn't go -// away soon. -bool IsTritonSupportedElementwiseUpToFloatNormalization(HloOpcode, - PrimitiveType); - -CodegenDecision CanTritonHandleGEMM( - const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version); - -// Checks instruction against the requirements of the legacy Triton emitters. -CodegenDecision IsTritonSupportedInstruction( - const HloInstruction& instr, const se::GpuComputeCapability& gpu_version); - -// Checks dynamic slice against the requirements of the legacy Triton emitters. -// -// This is exposed separately from IsTritonSupportedInstruction because we can -// use it in the dimension order propagation without adding a dependency on the -// GPU version. -CodegenDecision IsTritonSupportedDynamicSlice( - const HloDynamicSliceInstruction& instr); -} // namespace legacy_triton +using CodegenDecision = FusionDecision; // Checks that Triton officially supports the provided compute capability. // diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc new file mode 100644 index 00000000000000..b07630b7cb7734 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.cc @@ -0,0 +1,396 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/fusions/triton/triton_support.h" + +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/variant_visitor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" +#include "tsl/platform/tensor_float_32_utils.h" + +namespace xla { +namespace gpu { +namespace legacy_triton { + +bool IsDistributiveOverAddition(const HloInstruction& hlo) { + // The list is most likely incomplete. + // For example division can be added too but only for operand #0. + if (hlo.opcode() == HloOpcode::kMultiply || + hlo.opcode() == HloOpcode::kNegate || + hlo.opcode() == HloOpcode::kBitcast || + hlo.opcode() == HloOpcode::kReshape || hlo.opcode() == HloOpcode::kCopy || + hlo.opcode() == HloOpcode::kTranspose || + hlo.opcode() == HloOpcode::kConvert || + hlo.opcode() == HloOpcode::kBroadcast || + hlo.opcode() == HloOpcode::kSlice) { + return true; + } + return false; +} + +// Types that are supported by Triton as dot output. +// +// BF16 is supported in a sense that all operations on it are implemented +// through F32 and converts have to be inserted into the HLO graph, but +// they can be missing during fusion. +bool IsTritonSupportedDotOutputType( + const PrimitiveType t, const se::GpuComputeCapability& gpu_version) { + switch (t) { + case F16: + case F32: + return true; + case F8E5M2: + return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return cc.IsAtLeastAmpere(); + }, + [](const se::RocmComputeCapability& cc) { + return false; + }}, + gpu_version); + + case F8E4M3FN: + return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return cc.IsAtLeastHopper(); + }, + [](const se::RocmComputeCapability& cc) { + return false; + }}, + gpu_version); + case BF16: + return std::visit(VariantVisitor{[](const se::CudaComputeCapability& cc) { + return true; + }, + [](const se::RocmComputeCapability& cc) { + return cc.has_bf16_dtype_support(); + }}, + gpu_version); + default: + return false; + } +}; + +// Data types that are supported by the Triton emitters. +// TODO(b/266862493): Support more data types (F8, F64, etc.). +bool IsTritonSupportedDataType(PrimitiveType type, + const se::GpuComputeCapability& gpu_version) { + if (IsTritonSupportedDotOutputType(type, gpu_version)) { + return true; + } + switch (type) { + case PRED: + case S8: + case S16: + case S32: + return true; + default: + return false; + } +} + +CodegenDecision IsInstructionSupportsDataTypes( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (!IsTritonSupportedDataType(instr.shape().element_type(), gpu_version)) { + return "Unsupported output data type."; + } + + for (const HloInstruction* operand : instr.operands()) { + const auto operand_type = operand->shape().element_type(); + switch (instr.opcode()) { + case HloOpcode::kConvert: + // TODO(b/358580281): remove DebugOptions from this function after + // enabling int4 in Triton GEMM. + if (operand_type == S4 && instr.GetModule() + ->config() + .debug_options() + .xla_gpu_enable_triton_gemm_int4()) { + continue; + } + [[fallthrough]]; + default: + if (!IsTritonSupportedDataType(operand_type, gpu_version)) { + return "Unsupported input data type."; + } + } + } + return CodegenDecision{}; +} + +std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( + PrimitiveType element_type) { + std::vector ret = {HloOpcode::kConvert}; + if (element_type == PrimitiveType::PRED) { + ret.push_back(HloOpcode::kNot); + return ret; + } + ret.push_back(HloOpcode::kAbs); + ret.push_back(HloOpcode::kNegate); + if (element_type == PrimitiveType::F32 || + element_type == PrimitiveType::BF16 || + element_type == PrimitiveType::F64) { + absl::c_copy(std::vector{HloOpcode::kCos, HloOpcode::kExp, + HloOpcode::kExpm1, HloOpcode::kFloor, + HloOpcode::kCeil, HloOpcode::kLog, + HloOpcode::kLog1p, HloOpcode::kRsqrt, + HloOpcode::kSin, HloOpcode::kSqrt, + HloOpcode::kCbrt, HloOpcode::kTan, + HloOpcode::kTanh, HloOpcode::kErf}, + std::back_inserter(ret)); + } + return ret; +} + +std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( + PrimitiveType element_type) { + if (element_type == PrimitiveType::PRED) { + return {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, + HloOpcode::kCompare}; + } + std::vector ret = {HloOpcode::kAdd, HloOpcode::kCompare, + HloOpcode::kMaximum, HloOpcode::kMinimum, + HloOpcode::kMultiply, HloOpcode::kSubtract}; + if (element_type == PrimitiveType::F32 || + element_type == PrimitiveType::BF16 || + element_type == PrimitiveType::F64) { + ret.push_back(HloOpcode::kAtan2); + ret.push_back(HloOpcode::kDivide); + ret.push_back(HloOpcode::kPower); + } + return ret; +} + +std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( + PrimitiveType element_type) { + return {HloOpcode::kSelect, HloOpcode::kClamp}; +} + +bool IsTritonSupportedElementwiseUpToFloatNormalization( + HloOpcode opcode, PrimitiveType element_type) { + return absl::c_linear_search( + TritonSupportedUnaryElementwiseUpToFloatNormalization( + element_type), + opcode) || + absl::c_linear_search( + TritonSupportedBinaryElementwiseUpToFloatNormalization( + element_type), + opcode) || + absl::c_linear_search( + TritonSupportedTernaryElementwiseUpToFloatNormalization( + element_type), + opcode); +} + +CodegenDecision CanTritonHandleElementwise( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (auto decision = IsInstructionSupportsDataTypes(instr, gpu_version); + !decision.CanFuse()) { + return decision; + } + if (instr.opcode() == HloOpcode::kConstant) { + return CodegenDecision{}; + } else if (!IsTritonSupportedElementwiseUpToFloatNormalization( + instr.opcode(), instr.operand(0)->shape().element_type())) { + return "Unsupported elementwise operation."; + } + return CodegenDecision{}; +} + +bool IsDotAlgorithmSupportedByTriton( + PrecisionConfig::Algorithm algorithm, + const se::GpuComputeCapability& gpu_version) { + auto cuda_compute_capability = + std::get_if(&gpu_version); + auto rocm_compute_capability = + std::get_if(&gpu_version); + switch (algorithm) { + case PrecisionConfig::ALG_DOT_TF32_TF32_F32: + if (cuda_compute_capability) { + return true; + } + return false; + case PrecisionConfig::ALG_DOT_BF16_BF16_F32: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X3: + case PrecisionConfig::ALG_DOT_BF16_BF16_F32_X6: + if (cuda_compute_capability) { + return true; + } + if (rocm_compute_capability) { + return rocm_compute_capability->has_bf16_dtype_support(); + } + return false; + + // TODO(b/326579472): Fix the support of this algorithm and maybe allow it + // here. + case PrecisionConfig::ALG_DOT_F16_F16_F32: + // TODO(b/311331155): Triton F32 is about 3x slower than Triton TF32 and is + // slow to compile. Disable it for now. + case PrecisionConfig::ALG_DOT_F32_F32_F32: + default: + return false; + } +} + +// Filters GEMMs which can be handled using Triton. +CodegenDecision CanTritonHandleGEMM( + const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version) { + auto cuda_compute_capability = + std::get_if(&gpu_version); + auto rocm_compute_capability = + std::get_if(&gpu_version); + + CHECK(cuda_compute_capability || rocm_compute_capability); + + if (dot.precision_config().algorithm() == PrecisionConfig::ALG_UNSET) { + if (!tsl::tensor_float_32_execution_enabled() || + absl::c_any_of(dot.precision_config().operand_precision(), + [](int x) { return x != PrecisionConfig::DEFAULT; })) { + return "Having non-default operand precisions or TensorFloat-32 disabled " + "for Dot op with unset algorithm."; + } + } else { + if (!IsDotAlgorithmSupportedByTriton(dot.precision_config().algorithm(), + gpu_version)) { + return "Unsupported algorithm on the current device(s)."; + } + } + + // TODO(b/266862493): Support more output types. + if (!IsTritonSupportedDotOutputType(dot.shape().element_type(), + gpu_version)) { + return "Unsupported output data type for Dot op."; + } + + if (!IsTritonSupportedDataType(dot.operand(0)->shape().element_type(), + gpu_version) || + !IsTritonSupportedDataType(dot.operand(1)->shape().element_type(), + gpu_version)) { + return "Unsupported input data type for Dot op."; + } + + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + + // TODO(b/269580541): support multiple batch dimensions. + if (dim_numbers.lhs_batch_dimensions().size() > 1) { + return "Multiple batch dimensions."; + } + + return CodegenDecision{}; +} + +bool NoNonContractingDimension(const HloDotInstruction& dot) { + const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); + if (dim_numbers.lhs_batch_dimensions().size() + + dim_numbers.lhs_contracting_dimensions().size() == + dot.operand(0)->shape().rank() || + dim_numbers.rhs_batch_dimensions().size() + + dim_numbers.rhs_contracting_dimensions().size() == + dot.operand(1)->shape().rank()) { + return true; + } + return false; +} + +CodegenDecision IsTritonSupportedDynamicSlice( + const HloDynamicSliceInstruction& instr) { + for (const HloInstruction* index_operand : instr.index_operands()) { + switch (index_operand->shape().element_type()) { + case S8: + case S16: + case S32: + break; // supported + default: + return CodegenDecision( + "Dynamic slice is only supported with S8, S16, or S32 indices."); + } + } + + // Similar to normal slice, we cannot slice a non-major-most dimension as + // that would introduce non-contiguous strides under tiling. The existing + // check against this in GetRequirementsIfSupportedOrder is not suitable for + // dynamic slices, so we instead check for this here. + const HloInstruction* input = instr.operand(0); + Layout in_layout = input->shape().layout(); + int64_t majormost_dim_id = + in_layout.minor_to_major(in_layout.minor_to_major_size() - 1); + + for (int i = 0; i < input->shape().dimensions_size(); ++i) { + if (i == majormost_dim_id) { + continue; + } else if (input->shape().dimensions(i) != instr.slice_sizes(i)) { + return CodegenDecision( + "Unsupported dynamic slice on non-major-most dimension."); + } + } + + // TODO(b/343143854): Check the subtleties of which dynamic slices are + // supported, for example that a fragmented dimension cannot be sliced. + return CodegenDecision{}; +} + +CodegenDecision IsTritonSupportedInstruction( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version) { + if (instr.IsElementwise()) { + return CanTritonHandleElementwise(instr, gpu_version); + } + + switch (instr.opcode()) { + case HloOpcode::kDot: { + auto* dot = Cast(&instr); + // Cases where lhs or rhs have no non-contracting dims are not handled. + if (NoNonContractingDimension(*dot)) { + return "No non-contracting dimensions."; + } + return CanTritonHandleGEMM(*dot, gpu_version); + } + case HloOpcode::kTuple: { + if (instr.IsRoot()) { + return CodegenDecision{}; + } + return "Only supports root tuples."; + } + case HloOpcode::kDynamicSlice: { + return IsTritonSupportedDynamicSlice( + *Cast(&instr)); + } + case HloOpcode::kBitcast: + case HloOpcode::kTranspose: + case HloOpcode::kSlice: + case HloOpcode::kReshape: + case HloOpcode::kPad: + case HloOpcode::kConcatenate: + case HloOpcode::kParameter: + case HloOpcode::kBroadcast: + return CodegenDecision{}; + default: + break; + } + return "Unsupported opcode."; +} + +} // namespace legacy_triton +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.h new file mode 100644 index 00000000000000..da088465fa43e8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy.h @@ -0,0 +1,110 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_SUPPORT_LEGACY_H_ +#define XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_SUPPORT_LEGACY_H_ + +// This file is the home of the basic Triton support checks which are used by +// multiple other components. + +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/instruction_fusion.h" +#include "xla/stream_executor/device_description.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace gpu { + +using CodegenDecision = FusionDecision; + +namespace legacy_triton { + +// Tells if f(a+b) == f(a) + f(b). +bool IsDistributiveOverAddition(const HloInstruction& hlo); + +// Allowlist of unary elementwise operations supported by the legacy Triton +// emitters. +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +std::vector TritonSupportedUnaryElementwiseUpToFloatNormalization( + PrimitiveType); + +// Allowlist of binary elementwise operations supported by the legacy Triton +// emitters. +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +std::vector TritonSupportedBinaryElementwiseUpToFloatNormalization( + PrimitiveType); + +// Allowlist of ternary elementwise operations supported by the legacy Triton +// emitters. +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +std::vector TritonSupportedTernaryElementwiseUpToFloatNormalization( + PrimitiveType); + +// Data types that are supported by the legacy Triton emitters. +bool IsTritonSupportedDataType(PrimitiveType, const se::GpuComputeCapability&); + +// Checks elementwise operation against unary, binary, and ternary elementwise +// operations supported by the legacy Triton emitters. +// +// Note: this is not an accurate representation of what is actually supported by +// the Triton emitters, because operations affected by FloatNormalization may +// be tagged as "supported" here, even though FloatNormalization is required to +// make them work. We could fix this, but this is code we aim to delete soon, so +// it doesn't seem worth it. We'll revisit this decision if the code doesn't go +// away soon. +bool IsTritonSupportedElementwiseUpToFloatNormalization(HloOpcode, + PrimitiveType); + +CodegenDecision CanTritonHandleGEMM( + const HloDotInstruction& dot, const se::GpuComputeCapability& gpu_version); + +// Checks instruction against the requirements of the legacy Triton emitters. +CodegenDecision IsTritonSupportedInstruction( + const HloInstruction& instr, const se::GpuComputeCapability& gpu_version); + +// Checks dynamic slice against the requirements of the legacy Triton emitters. +// +// This is exposed separately from IsTritonSupportedInstruction because we can +// use it in the dimension order propagation without adding a dependency on the +// GPU version. +CodegenDecision IsTritonSupportedDynamicSlice( + const HloDynamicSliceInstruction& instr); + +} // namespace legacy_triton +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_FUSIONS_TRITON_TRITON_SUPPORT_LEGACY_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc index 89b1e8d1bc297b..41adc715e3849e 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_legacy_test.cc @@ -15,6 +15,8 @@ limitations under the License. // TODO(b/343158720): Simplify the tests in this file after a generic emitter // has landed. +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" + #include #include #include @@ -34,15 +36,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/fusions/triton/triton_test_utils.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/stream_executor/device_description.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc index d93817af5efc6e..89da38e7350091 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_support_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_support.h" +#include #include #include #include @@ -26,8 +27,10 @@ limitations under the License. #include #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "absl/strings/substitute.h" +#include "absl/types/span.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/primitive_util.h" #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" @@ -145,7 +148,7 @@ auto AllDevicesToTest() { // Generates all the possible test combinations for a given opcodes. A test // combination is a tuple of the form (data_type, opcode, compute_capability). -auto AllTestCombinationsForOpcodes(std::vector&& opcodes) { +auto AllTestCombinationsForOpcodes(absl::Span opcodes) { std::vector> test_combinations; for (PrimitiveType data_type : AllXlaDataTypes()) { @@ -226,10 +229,13 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{16}, cc); } -INSTANTIATE_TEST_SUITE_P(BitcastOrReshapeTestSuite, BitcastOrReshapeTest, - AllTestCombinationsForOpcodes({HloOpcode::kBitcast, - HloOpcode::kReshape}), - TritonSupportTestTypeOpcodeAndDeviceToString); +constexpr std::array kTestedOpsBitcastReshape = {HloOpcode::kBitcast, + HloOpcode::kReshape}; + +INSTANTIATE_TEST_SUITE_P( + BitcastOrReshapeTestSuite, BitcastOrReshapeTest, + AllTestCombinationsForOpcodes(kTestedOpsBitcastReshape), + TritonSupportTestTypeOpcodeAndDeviceToString); using UnaryElementwiseTest = TritonSupportTestWithParam; @@ -280,36 +286,38 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc); } +constexpr std::array kTestedOpsUnaryElementwise = {HloOpcode::kAbs, + HloOpcode::kCbrt, + HloOpcode::kCeil, + HloOpcode::kClz, + HloOpcode::kConvert, + HloOpcode::kCos, + HloOpcode::kErf, + HloOpcode::kExp, + HloOpcode::kExpm1, + HloOpcode::kFloor, + HloOpcode::kImag, + HloOpcode::kIsFinite, + HloOpcode::kLog, + HloOpcode::kLog1p, + HloOpcode::kLogistic, + HloOpcode::kNegate, + HloOpcode::kNot, + HloOpcode::kPopulationCount, + HloOpcode::kReal, + HloOpcode::kReducePrecision, + HloOpcode::kRoundNearestAfz, + HloOpcode::kRoundNearestEven, + HloOpcode::kRsqrt, + HloOpcode::kSign, + HloOpcode::kSin, + HloOpcode::kSqrt, + HloOpcode::kTan, + HloOpcode::kTanh}; + INSTANTIATE_TEST_SUITE_P( UnaryElementwiseTestSuite, UnaryElementwiseTest, - AllTestCombinationsForOpcodes({HloOpcode::kAbs, - HloOpcode::kCbrt, - HloOpcode::kCeil, - HloOpcode::kClz, - HloOpcode::kConvert, - HloOpcode::kCos, - HloOpcode::kErf, - HloOpcode::kExp, - HloOpcode::kExpm1, - HloOpcode::kFloor, - HloOpcode::kImag, - HloOpcode::kIsFinite, - HloOpcode::kLog, - HloOpcode::kLog1p, - HloOpcode::kLogistic, - HloOpcode::kNegate, - HloOpcode::kNot, - HloOpcode::kPopulationCount, - HloOpcode::kReal, - HloOpcode::kReducePrecision, - HloOpcode::kRoundNearestAfz, - HloOpcode::kRoundNearestEven, - HloOpcode::kRsqrt, - HloOpcode::kSign, - HloOpcode::kSin, - HloOpcode::kSqrt, - HloOpcode::kTan, - HloOpcode::kTanh}), + AllTestCombinationsForOpcodes(kTestedOpsUnaryElementwise), TritonSupportTestTypeOpcodeAndDeviceToString); using BinaryElementwiseTest = TritonSupportTestWithParam; @@ -353,15 +361,27 @@ ENTRY triton_computation { skip_failure_branch_to_avoid_crash); } +constexpr std::array kTestedOpsBinaryElementwise = { + HloOpcode::kAnd, + HloOpcode::kOr, + HloOpcode::kXor, + HloOpcode::kAdd, + HloOpcode::kMultiply, + HloOpcode::kMaximum, + HloOpcode::kMinimum, + HloOpcode::kSubtract, + HloOpcode::kAtan2, + HloOpcode::kDivide, + HloOpcode::kRemainder, + HloOpcode::kPower, + HloOpcode::kShiftLeft, + HloOpcode::kShiftRightArithmetic, + HloOpcode::kShiftRightLogical, + HloOpcode::kCompare}; + INSTANTIATE_TEST_SUITE_P( BinaryElementwiseTestSuite, BinaryElementwiseTest, - AllTestCombinationsForOpcodes( - {HloOpcode::kAnd, HloOpcode::kOr, HloOpcode::kXor, HloOpcode::kAdd, - HloOpcode::kMultiply, HloOpcode::kMaximum, HloOpcode::kMinimum, - HloOpcode::kSubtract, HloOpcode::kAtan2, HloOpcode::kDivide, - HloOpcode::kRemainder, HloOpcode::kPower, HloOpcode::kShiftLeft, - HloOpcode::kShiftRightArithmetic, HloOpcode::kShiftRightLogical, - HloOpcode::kCompare}), + AllTestCombinationsForOpcodes(kTestedOpsBinaryElementwise), TritonSupportTestTypeOpcodeAndDeviceToString); using TernaryElementwiseTest = TritonSupportTestWithParam; @@ -387,10 +407,13 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1, 32}, cc); } -INSTANTIATE_TEST_SUITE_P(TernaryElementwiseTestSuite, TernaryElementwiseTest, - AllTestCombinationsForOpcodes({HloOpcode::kSelect, - HloOpcode::kClamp}), - TritonSupportTestTypeOpcodeAndDeviceToString); +constexpr std::array kTestedOpsTernaryElementwise = {HloOpcode::kSelect, + HloOpcode::kClamp}; + +INSTANTIATE_TEST_SUITE_P( + TernaryElementwiseTestSuite, TernaryElementwiseTest, + AllTestCombinationsForOpcodes(kTestedOpsTernaryElementwise), + TritonSupportTestTypeOpcodeAndDeviceToString); using ReduceTest = TritonSupportTestWithParam; @@ -488,7 +511,6 @@ ENTRY triton_computation { TF_ASSERT_OK_AND_ASSIGN( TestedInstruction ti, ParseTemplateAndGetInstruction(kHloTestTemplate, data_type, opcode)); - EXPECT_TRUE(IsTritonSupportedInstruction(ti.Instruction(), cc)); RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } @@ -570,8 +592,10 @@ ENTRY triton_computation { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } +constexpr std::array kTestedOpsReduction = {HloOpcode::kReduce}; + INSTANTIATE_TEST_SUITE_P(ReduceTestSuite, ReduceTest, - AllTestCombinationsForOpcodes({HloOpcode::kReduce}), + AllTestCombinationsForOpcodes(kTestedOpsReduction), TritonSupportTestTypeOpcodeAndDeviceToString); using CollectiveTest = TritonSupportTestWithParam; @@ -643,13 +667,119 @@ TEST_P(CollectiveTest, UnsupportedCollectivesFailGracefullyWithTriton) { RunSupportTest(std::move(ti), /*output_tile_sizes=*/{1}, cc); } -INSTANTIATE_TEST_SUITE_P( - CollectiveTestSuite, CollectiveTest, - AllTestCombinationsForOpcodes({HloOpcode::kAllGather, HloOpcode::kAllReduce, - HloOpcode::kAllToAll, - HloOpcode::kCollectivePermute, - HloOpcode::kReduceScatter}), - TritonSupportTestTypeOpcodeAndDeviceToString); +constexpr std::array kTestedOpsCollectives = { + HloOpcode::kAllGather, HloOpcode::kAllReduce, HloOpcode::kAllToAll, + HloOpcode::kCollectivePermute, HloOpcode::kReduceScatter}; + +INSTANTIATE_TEST_SUITE_P(CollectiveTestSuite, CollectiveTest, + AllTestCombinationsForOpcodes(kTestedOpsCollectives), + TritonSupportTestTypeOpcodeAndDeviceToString); + +absl::flat_hash_set AllTestedOpcodes() { + // The return set is initialized with ops that are implicitly tested. + absl::flat_hash_set ret{HloOpcode::kParameter}; + + ret.insert(kTestedOpsBitcastReshape.begin(), kTestedOpsBitcastReshape.end()); + ret.insert(kTestedOpsUnaryElementwise.begin(), + kTestedOpsUnaryElementwise.end()); + ret.insert(kTestedOpsBinaryElementwise.begin(), + kTestedOpsBinaryElementwise.end()); + ret.insert(kTestedOpsTernaryElementwise.begin(), + kTestedOpsTernaryElementwise.end()); + ret.insert(kTestedOpsReduction.begin(), kTestedOpsReduction.end()); + ret.insert(kTestedOpsCollectives.begin(), kTestedOpsCollectives.end()); + return ret; +} + +absl::flat_hash_set AllUntestedOpcodes() { + return absl::flat_hash_set{HloOpcode::kAddDependency, + HloOpcode::kAfterAll, + HloOpcode::kAllGatherDone, + HloOpcode::kAllGatherStart, + HloOpcode::kAllReduceDone, + HloOpcode::kAllReduceStart, + HloOpcode::kAsyncDone, + HloOpcode::kAsyncStart, + HloOpcode::kAsyncUpdate, + HloOpcode::kBatchNormGrad, + HloOpcode::kBatchNormInference, + HloOpcode::kBatchNormTraining, + HloOpcode::kBitcastConvert, + HloOpcode::kBroadcast, + HloOpcode::kCall, + HloOpcode::kCholesky, + HloOpcode::kCollectiveBroadcast, + HloOpcode::kCollectivePermuteDone, + HloOpcode::kCollectivePermuteStart, + HloOpcode::kComplex, + HloOpcode::kConcatenate, + HloOpcode::kConditional, + HloOpcode::kConstant, + HloOpcode::kConvolution, + HloOpcode::kCopy, + HloOpcode::kCopyDone, + HloOpcode::kCopyStart, + HloOpcode::kCustomCall, + HloOpcode::kDomain, + HloOpcode::kDot, + HloOpcode::kDynamicReshape, + HloOpcode::kDynamicSlice, + HloOpcode::kDynamicUpdateSlice, + HloOpcode::kFft, + HloOpcode::kFusion, + HloOpcode::kGather, + HloOpcode::kGetDimensionSize, + HloOpcode::kGetTupleElement, + HloOpcode::kInfeed, + HloOpcode::kIota, + HloOpcode::kMap, + HloOpcode::kOptimizationBarrier, + HloOpcode::kOutfeed, + HloOpcode::kPad, + HloOpcode::kPartitionId, + HloOpcode::kRecv, + HloOpcode::kRecvDone, + HloOpcode::kReduceWindow, + HloOpcode::kReplicaId, + HloOpcode::kReverse, + HloOpcode::kRng, + HloOpcode::kRngBitGenerator, + HloOpcode::kRngGetAndUpdateState, + HloOpcode::kScatter, + HloOpcode::kSelectAndScatter, + HloOpcode::kSend, + HloOpcode::kSendDone, + HloOpcode::kSetDimensionSize, + HloOpcode::kSlice, + HloOpcode::kSort, + HloOpcode::kStochasticConvert, + HloOpcode::kTopK, + HloOpcode::kTranspose, + HloOpcode::kTriangularSolve, + HloOpcode::kTuple, + HloOpcode::kWhile}; +} + +TEST(OpCoverage, TestedAndUntestedDoNotOverlap) { + absl::flat_hash_set untested_opcodes = AllUntestedOpcodes(); + for (HloOpcode tested : AllTestedOpcodes()) { + EXPECT_FALSE(untested_opcodes.contains(tested)) + << "Opcode `" << HloOpcodeString(tested) + << "` appears in both tested and untested opcodes."; + } +} + +TEST(OpCoverage, AllOpcodesAppearInTestedOrUntested) { + absl::flat_hash_set untested_opcodes = AllUntestedOpcodes(); + absl::flat_hash_set tested_opcodes = AllTestedOpcodes(); + for (int opcode_index = 0; opcode_index < HloOpcodeCount(); ++opcode_index) { + auto opcode = static_cast(opcode_index); + EXPECT_TRUE(untested_opcodes.contains(opcode) || + tested_opcodes.contains(opcode)) + << "Opcode `" << HloOpcodeString(opcode) + << "` does not appear in tested or untested opcodes."; + } +} } // namespace } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/fusions/triton_test.cc b/third_party/xla/xla/service/gpu/fusions/triton_test.cc index c2cfabfa8f2292..1738d6fcad353d 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton_test.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton_test.cc @@ -64,7 +64,7 @@ ENTRY entry_computation { TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - HloFusionAnalysis analysis = AnalyzeFusion(*root, device_info); + HloFusionAnalysis analysis = HloFusionAnalysis::Create(*root, device_info); std::unique_ptr emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis}); @@ -100,7 +100,7 @@ ENTRY entry_computation { TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - HloFusionAnalysis analysis = AnalyzeFusion(*root, device_info); + HloFusionAnalysis analysis = HloFusionAnalysis::Create(*root, device_info); std::unique_ptr emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis}); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 15debb4020b854..be7ba5c92c3da9 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -74,7 +74,6 @@ limitations under the License. #include "xla/service/all_reduce_folder.h" #include "xla/service/all_reduce_promotion.h" #include "xla/service/all_reduce_reassociate.h" -#include "xla/service/all_reduce_splitter.h" #include "xla/service/async_collective_creator.h" #include "xla/service/batchnorm_expander.h" #include "xla/service/bitcast_dtypes_expander.h" @@ -110,43 +109,18 @@ limitations under the License. #include "xla/service/float_support.h" #include "xla/service/gather_expander.h" #include "xla/service/gather_simplifier.h" -#include "xla/service/gpu/algorithm_checker.h" -#include "xla/service/gpu/all_reduce_blueconnect.h" -#include "xla/service/gpu/autotuner_util.h" -#include "xla/service/gpu/collective_permute_cycle_decomposer.h" -#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h" -#include "xla/service/gpu/command_buffer_scheduling.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" +#include "xla/service/gpu/autotuning/custom_kernel_fusion_autotuner.h" #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/conv_layout_normalization.h" -#include "xla/service/gpu/custom_kernel_fusion_autotuner.h" -#include "xla/service/gpu/custom_kernel_fusion_rewriter.h" -#include "xla/service/gpu/dot_dimension_sorter.h" -#include "xla/service/gpu/dot_operand_converter.h" -#include "xla/service/gpu/double_buffer_loop_unrolling.h" -#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h" #include "xla/service/gpu/execution_stream_assignment.h" #include "xla/service/gpu/fusion_pipeline.h" -#include "xla/service/gpu/fusion_wrapper.h" -#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h" -#include "xla/service/gpu/gemm_fusion.h" -#include "xla/service/gpu/gemm_rewriter.h" -#include "xla/service/gpu/gemv_rewriter.h" -#include "xla/service/gpu/gpu_algebraic_simplifier.h" -#include "xla/service/gpu/gpu_all_gather_optimizer.h" -#include "xla/service/gpu/gpu_async_collective_annotator.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" -#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h" #include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/gpu_float_support.h" #include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" -#include "xla/service/gpu/gpu_layout_assignment.h" #include "xla/service/gpu/gpu_p2p_pipeliner.h" -#include "xla/service/gpu/gpu_reduce_scatter_creator.h" -#include "xla/service/gpu/gpu_sanitize_constant_names.h" -#include "xla/service/gpu/gpu_scatter_expander.h" #include "xla/service/gpu/gpu_spmd_pipeline.h" -#include "xla/service/gpu/gpu_windowed_einsum_handler.h" #include "xla/service/gpu/hlo_fusion_stats.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/ir_emitter_context.h" @@ -156,26 +130,55 @@ limitations under the License. #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/model/gpu_cost_model_stats_collection.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" -#include "xla/service/gpu/move_copy_to_users.h" -#include "xla/service/gpu/pipelined_p2p_rewriter.h" #include "xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h" -#include "xla/service/gpu/reduction_degenerate_dim_remover.h" -#include "xla/service/gpu/reduction_dimension_grouper.h" -#include "xla/service/gpu/reduction_layout_normalizer.h" -#include "xla/service/gpu/reduction_splitter.h" #include "xla/service/gpu/reduction_utils.h" -#include "xla/service/gpu/rename_fusions.h" #include "xla/service/gpu/runtime/thunk.h" #include "xla/service/gpu/runtime_intrinsics.h" -#include "xla/service/gpu/scatter_slice_simplifier.h" -#include "xla/service/gpu/softmax_rewriter_triton.h" -#include "xla/service/gpu/stream_attribute_annotator.h" -#include "xla/service/gpu/stream_attribute_async_wrapper.h" #include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/gpu/topk_specializer.h" -#include "xla/service/gpu/topk_splitter.h" -#include "xla/service/gpu/tree_reduction_rewriter.h" -#include "xla/service/gpu/triton_fusion_numerics_verifier.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" +#include "xla/service/gpu/transforms/algorithm_checker.h" +#include "xla/service/gpu/transforms/all_gather_optimizer.h" +#include "xla/service/gpu/transforms/all_reduce_blueconnect.h" +#include "xla/service/gpu/transforms/all_reduce_splitter.h" +#include "xla/service/gpu/transforms/async_collective_annotator.h" +#include "xla/service/gpu/transforms/async_wrapper.h" +#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h" +#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h" +#include "xla/service/gpu/transforms/command_buffer_scheduling.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" +#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h" +#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h" +#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" +#include "xla/service/gpu/transforms/dot_dimension_sorter.h" +#include "xla/service/gpu/transforms/dot_operand_converter.h" +#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h" +#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h" +#include "xla/service/gpu/transforms/fusion_wrapper.h" +#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h" +#include "xla/service/gpu/transforms/gemm_fusion.h" +#include "xla/service/gpu/transforms/gemm_rewriter.h" +#include "xla/service/gpu/transforms/gemv_rewriter.h" +#include "xla/service/gpu/transforms/layout_assignment.h" +#include "xla/service/gpu/transforms/move_copy_to_users.h" +#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h" +#include "xla/service/gpu/transforms/reduce_scatter_creator.h" +#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h" +#include "xla/service/gpu/transforms/reduction_dimension_grouper.h" +#include "xla/service/gpu/transforms/reduction_layout_normalizer.h" +#include "xla/service/gpu/transforms/reduction_splitter.h" +#include "xla/service/gpu/transforms/rename_fusions.h" +#include "xla/service/gpu/transforms/sanitize_constant_names.h" +#include "xla/service/gpu/transforms/scatter_expander.h" +#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" +#include "xla/service/gpu/transforms/softmax_rewriter_triton.h" +#include "xla/service/gpu/transforms/stream_attribute_annotator.h" +#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h" +#include "xla/service/gpu/transforms/topk_specializer.h" +#include "xla/service/gpu/transforms/topk_splitter.h" +#include "xla/service/gpu/transforms/transpose_dimension_grouper.h" +#include "xla/service/gpu/transforms/tree_reduction_rewriter.h" +#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h" +#include "xla/service/gpu/transforms/windowed_einsum_handler.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_computation_deduplicator.h" #include "xla/service/hlo_constant_folding.h" @@ -407,6 +410,16 @@ GpuThunkAotCompilationResult::LoadExecutable( platform_name, gpu_device_info, mlir_context.get(), llvm_module.get(), /*llvm_module_constants=*/nullptr, /*emit_kernels=*/false); + + absl::string_view cache_file_path = + hlo_module->config().debug_options().xla_gpu_kernel_cache_file(); + if (!cache_file_path.empty() && + hlo_module->config() + .debug_options() + .xla_gpu_enable_llvm_module_compilation_parallelism()) { + TF_RETURN_IF_ERROR(LoadCache(ir_emitter_context, cache_file_path)); + } + auto ir_emitter = IrEmitterUnnested::Create(&ir_emitter_context); TF_RETURN_IF_ERROR( ir_emitter->EmitHloComputation(hlo_module->entry_computation())); @@ -494,7 +507,7 @@ AlgebraicSimplifierOptions LayoutInsensitiveAlgebraicSimplifierOptions( AlgebraicSimplifierOptions layout_insensitive_algsimp_opts = opts_from_compiler; layout_insensitive_algsimp_opts.set_conv_is_lowerable_callback( - GpuConvRewriter::ConvIsLowerable); + ConvRewriter::ConvIsLowerable); layout_insensitive_algsimp_opts.set_enable_dot_strength_reduction( hlo_module_config.debug_options() .xla_gpu_enable_dot_strength_reduction()); @@ -526,6 +539,7 @@ absl::Status RunPreSPMDPartitionerPasses(HloModule* hlo_module) { HloPassPipeline pre_spmd_pipeline("pre-spmd-partitioner"); // Run some IR cleanup passes before running the SPMD partitioning // passes. + pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); pre_spmd_pipeline.AddPass(); @@ -627,7 +641,7 @@ absl::Status RunOptimizationPasses( HloPassPipeline pipeline("optimization"); AddHloVerifier(&pipeline); if (debug_options.xla_gpu_multi_streamed_windowed_einsum()) { - pipeline.AddPass(); + pipeline.AddPass(); } pipeline.AddPass(); pipeline.AddPass(); @@ -1121,7 +1135,7 @@ absl::Status RunPostFusionCollectiveOptimizationPasses(HloModule* hlo_module) { return false; } }; - pipeline.AddPass(convert_to_async); + pipeline.AddPass(convert_to_async); return pipeline.Run(hlo_module).status(); } @@ -1177,6 +1191,8 @@ absl::Status RunPostFusionVerificationPasses( absl::Status GpuCompiler::OptimizeHloModule( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config) { + tsl::profiler::TraceMe traceme("GpuCompiler::OptimizeHloModule"); + CheckNotScheduled(hlo_module); LogDebugOptions(hlo_module); @@ -1255,7 +1271,7 @@ absl::Status GpuCompiler::OptimizeHloModule( // This is a "low effort, high impact" fusion that should be run first. if (hlo_module->config() .debug_options() - .xla_gpu_enable_address_computation_fusion()) { + .xla_gpu_enable_dynamic_slice_fusion()) { HloPassPipeline pipeline("dynamic-slice"); TF_ASSIGN_OR_RETURN(se::Platform * platform, se::PlatformManager::PlatformWithId(PlatformId())); @@ -1305,6 +1321,30 @@ absl::Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { .status(); } +namespace { +void AddGemmRewriterPasses(HloPassPipeline& pipeline, + const DebugOptions& debug_options, + const se::GpuComputeCapability gpu_version, + const int32_t toolkit_version) { + // Adding bias to GEMMs is helpful for skipping kernel launches for `add` + // operations. However, the bias term can add dependencies between the GEMMs + // that could otherwise be parallelized. Because of this, we disable bias + // addition when async dot is enabled. + GemmRewriterOptions::BiasMode bias_mode = + GemmRewriterOptions::BiasMode::kBias; + if (debug_options.xla_gpu_async_dot()) { + bias_mode = GemmRewriterOptions::BiasMode::kNoBias; + } + + pipeline.AddPass( + gpu_version, toolkit_version, + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only, bias_mode}); + pipeline.AddPass( + gpu_version, toolkit_version, + GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only, bias_mode}); +} +} // namespace + absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( HloModule* hlo_module, se::StreamExecutor* stream_exec, const CompileOptions& options, const TargetConfig& gpu_target_config, @@ -1376,6 +1416,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // heuristic, so we can mix and match various Gemm implementations based // on projected (measured) performance. if (debug_options.xla_gpu_enable_custom_fusions()) { + pipeline.AddPass(); pipeline.AddPass( &gpu_target_config.device_description); pipeline.AddPass(autotune_config); @@ -1396,10 +1437,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(gpu_version); } - pipeline.AddPass(gpu_version, GetToolkitVersion(), - /*f8_rewrite=*/true); - pipeline.AddPass(gpu_version, GetToolkitVersion(), - /*f8_rewrite=*/false); + // Rewrite GEMMs into custom calls. + AddGemmRewriterPasses(pipeline, debug_options, gpu_version, + GetToolkitVersion()); // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); @@ -1414,6 +1454,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); // Run Softmax fusion after layout normalization. We expect a default layout @@ -1421,8 +1462,9 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( // ReductionDimensionGrouper, as that makes matching the softmax pattern // harder. if (debug_options.xla_gpu_enable_triton_softmax_fusion() && - cuda_cc != nullptr && - cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) { + ((cuda_cc != nullptr && + cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) || + rocm_cc != nullptr)) { // Triton compilation needs normalized operations on bf16 (i.e. converted // to f32). add_float_normalization(pipeline); @@ -1441,7 +1483,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( bool ignore_small_reduce_dims = !debug_options.xla_gpu_enable_priority_fusion(); pipeline.AddPass>(ignore_small_reduce_dims); - pipeline.AddPass>(gpu_version); + pipeline.AddPass>(gpu_version); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -1466,13 +1508,23 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); // TODO(tdanyluk): Apply CublasPadForGemms to the cuBLAS GEMMs generated // here for possibly better cuBLAS performance. - pipeline.AddPass(gpu_version, GetToolkitVersion(), - /*f8_rewrite=*/true); - pipeline.AddPass(gpu_version, GetToolkitVersion(), - /*f8_rewrite=*/false); + AddGemmRewriterPasses(pipeline, debug_options, gpu_version, + GetToolkitVersion()); + // Rewrite GEMMs with broadcasted inputs as strided GEMMs. pipeline.AddPass(); + // Wrap `dot` operations into async computations in an effort to parallelize + // matrix operations. This pass needs to run after the GEMM rewriter so that + // we still use the native GEMM implementation. + if (debug_options.xla_gpu_async_dot()) { + pipeline.AddPass([](HloInstruction* instruction) { + // TODO(b/339654953): Use a better heuristic to determine whether a + // `dot` operation should be wrapped in an async computation. + return instruction->opcode() == HloOpcode::kCustomCall; + }); + } + pipeline.AddPass( static_cast(stream_executor::MemoryType::kHost), /* after_layout= */ true); @@ -2058,6 +2110,8 @@ GpuCompiler::CompileToBackendResult( HloModule* module, llvm::LLVMContext* llvm_context, se::StreamExecutor* executor, const CompileOptions& options, const se::DeviceDescription& gpu_device_info) { + tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); + TF_RETURN_IF_ERROR(RunPreSchedulingPasses(module, executor)); TF_ASSIGN_OR_RETURN( ScheduleMetadata schedule_metadata, @@ -2146,8 +2200,8 @@ absl::StatusOr> GpuCompiler::RunBackend( }}; BinaryMap dnn_compiled_graphs; if (stream_exec) { - TF_RETURN_IF_ERROR(RunCudnnFusionCompilerPass(module.get(), stream_exec, - &dnn_compiled_graphs)); + TF_RETURN_IF_ERROR(RunCudnnCompilerPasses(module.get(), stream_exec, + &dnn_compiled_graphs)); } const DebugOptions& debug_opts = module->config().debug_options(); @@ -2480,7 +2534,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( pipeline.AddPass( gpu_device_info, toolkit_version, driver_version.value_or(toolkit_version)); - pipeline.AddPass(); + pipeline.AddPass(); } AddHloVerifier(&main_pipeline, diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index 27a434f5a5d035..456e6755b0d83a 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -31,7 +31,7 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/compiler.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/buffer_sharing.h" #include "xla/service/gpu/compile_module_to_llvm_ir.h" #include "xla/service/gpu/executable.pb.h" @@ -171,10 +171,10 @@ class GpuCompiler : public LLVMCompiler { return absl::OkStatus(); } - // Runs cuDNN fusion compiler pass. - virtual absl::Status RunCudnnFusionCompilerPass( - HloModule* module, se::StreamExecutor* stream_exec, - BinaryMap* dnn_compiled_graphs) { + // Runs cuDNN fusion and custom call compiler passes. + virtual absl::Status RunCudnnCompilerPasses(HloModule* module, + se::StreamExecutor* stream_exec, + BinaryMap* dnn_compiled_graphs) { return absl::OkStatus(); } @@ -235,7 +235,8 @@ class GpuCompiler : public LLVMCompiler { absl::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); virtual absl::StatusOr> LinkModules( - se::GpuComputeCapability cc, se::StreamExecutor* stream_exec, + se::GpuComputeCapability gpu_compute_capability, + se::StreamExecutor* stream_exec, std::vector> modules, const DebugOptions& debug_options) { return Unimplemented("LinkModules is not implemented."); diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc index ad41f791bb6bbb..62ab8c7b7d6b56 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -33,10 +34,14 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_module_group.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/primitive_util.h" +#include "xla/service/compiler.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_hlo_schedule.h" #include "xla/service/gpu/metrics.h" #include "xla/service/hlo_module_config.h" @@ -44,11 +49,14 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/xla_debug_info_manager.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/casts.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" @@ -439,8 +447,7 @@ ENTRY main { HloModuleConfig config; DebugOptions triton_enabled_debug_options = GetDebugOptionsForTest(); - triton_enabled_debug_options.set_xla_gpu_enable_address_computation_fusion( - false); + triton_enabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); triton_enabled_debug_options .set_xla_gpu_require_complete_aot_autotune_results(true); config.set_debug_options(triton_enabled_debug_options); @@ -459,8 +466,7 @@ ENTRY main { GetOptimizedModule(std::move(module))); AutotunerUtil::ClearAutotuneResults(); DebugOptions triton_disabled_debug_options = GetDebugOptionsForTest(); - triton_disabled_debug_options.set_xla_gpu_enable_address_computation_fusion( - false); + triton_disabled_debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); triton_disabled_debug_options.set_xla_gpu_enable_triton_gemm(false); config.set_debug_options(triton_disabled_debug_options); TF_ASSERT_OK_AND_ASSIGN(module, @@ -658,12 +664,12 @@ CHECK: %[[AFTER_ALL:.*]] = after-all CHECK: %[[RESULT_RECV:.*]] = recv(%[[AFTER_ALL]]) CHECK-SAME: channel_id=[[CHANNEL_ID]] CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", -CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, +CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}}, CHECK-SAME: control-predecessors={%[[CUSTOM_CALL]]} CHECK: %[[RESULT_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[AFTER_ALL]]) CHECK-SAME: channel_id=1 CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", -CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, +CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}}, CHECK-SAME: control-predecessors={%[[RESULT_RECV]]} CHECK: ROOT // We actually expect both RESULT_RECV and RESULT_SEND to match on this line. @@ -677,11 +683,11 @@ CHECK: %[[ENTRY_AFTER_ALL:.*]] = after-all CHECK: %[[ENTRY_RECV:.*]] = recv(%[[ENTRY_AFTER_ALL]]) CHECK-SAME: channel_id=[[CHANNEL_ID]] CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", -CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"} +CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}} CHECK: %[[ENTRY_SEND:.*]] = send(%[[SOME_SEND_ARG:.*]], %[[ENTRY_AFTER_ALL]]) CHECK-SAME: channel_id=1 CHECK-SAME: frontend_attributes={_xla_send_recv_pipeline="0", -CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs="{{0,1},{1,2},{2,3},{3,4}}"}, +CHECK-SAME{LITERAL}: _xla_send_recv_source_target_pairs={{0,1},{1,2},{2,3},{3,4}}}, CHECK-SAME: control-predecessors={%[[ENTRY_RECV]]} CHECK: %[[WHILE_INIT:.*]] = tuple // Check here that the send argument is likewise passed to the while loop, as @@ -818,6 +824,78 @@ TEST_F(KernelCacheTest, AllKernelsAreCachedBecauseSplitModuleUsesRoundRobin) { EXPECT_EQ(CacheEntryCount(), 4); } +TEST_F(KernelCacheTest, CachingWorksWithLoadedExecutables) { + const std::string kHloAdd1 = R"( +add1 { + p = s32[] parameter(0) + c = s32[] constant(1) + ROOT a = s32[] add(p, c) +} + +ENTRY e { + p = s32[] parameter(0) + ROOT r = s32[] fusion(p), kind=kLoop, calls=add1 +})"; + + const std::string kHloAdd2 = R"( +add2 { + p = s32[] parameter(0) + c = s32[] constant(2) + ROOT a = s32[] add(p, c) +} + +ENTRY e { + p = s32[] parameter(0) + ROOT r = s32[] fusion(p), kind=kLoop, calls=add2 +})"; + + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + se::PlatformManager::PlatformWithName("cuda")); + TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * stream_exec, + platform->ExecutorForDevice(0)); + + Compiler* compiler = backend().compiler(); + AotCompilationOptions aot_options(compiler->PlatformId()); + aot_options.set_executor(stream_exec); + + auto test = [this, &compiler, &aot_options](absl::string_view hlo, int input, + int expected_result) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto module_group = std::make_unique(std::move(module)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector> aot_results, + compiler->CompileAheadOfTime(std::move(module_group), aot_options)); + + TF_ASSERT_OK_AND_ASSIGN(std::string serialized_aot_result, + aot_results[0]->SerializeAsString()); + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr aot_result, + compiler->LoadAotCompilationResult(serialized_aot_result)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr executable, + aot_result->LoadExecutable(compiler, aot_options.executor())); + + const xla::Literal literal_input = + xla::LiteralUtil::CreateR0(input); + const xla::Literal literal_expected_result = + xla::LiteralUtil::CreateR0(expected_result); + + TF_ASSERT_OK_AND_ASSIGN(Literal result, + GetHloRunner().value()->ExecuteWithExecutable( + executable.get(), {&literal_input})); + + EXPECT_TRUE(LiteralTestUtil::Equal(result, literal_expected_result)); + }; + + test(kHloAdd1, 1, 2); + test(kHloAdd2, 1, 3); + // The test used to fail on the second execution of the second module when it + // was already cached. + test(kHloAdd2, 1, 3); +} + class KernelCacheTestSingleThreaded : public KernelCacheTest { public: DebugOptions GetDebugOptionsForTest() override { @@ -874,10 +952,10 @@ TEST_F(GpuCompilerTest, TestFlag_xla_gpu_unsafe_pipelined_loop_annotator) { })"; const char* kExpected = R"( - // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{3,0}}",_xla_send_recv_validation="{{[{]}}{3,9}}"} - // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{0,1},{1,2},{2,3}}",_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8}}"} - // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{3,0}}",_xla_send_recv_validation="{{[{]}}{3,9}}"} - // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs="{{[{]}}{0,1},{1,2},{2,3}}",_xla_send_recv_validation="{{[{]}}{0,6},{1,7},{2,8}}"} + // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{3,0}},_xla_send_recv_validation={{[{]}}{3,9}}} + // CHECK: {{.+}} = send({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{0,1},{1,2},{2,3}},_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8}}} + // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{3,0}},_xla_send_recv_validation={{[{]}}{3,9}}} + // CHECK: {{.+}} = recv({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_source_target_pairs={{[{]}}{0,1},{1,2},{2,3}},_xla_send_recv_validation={{[{]}}{0,6},{1,7},{2,8}}} )"; DebugOptions debug_options; diff --git a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto index 3549c95f7fdccd..51caadb7bd2d06 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto +++ b/third_party/xla/xla/service/gpu/gpu_compiler_test_autotune_db.textproto @@ -13,7 +13,7 @@ results { } results { device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 1555 GB/s, L2 cache: 40 MB" - hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}" + hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}" result { run_time { nanos: 1 @@ -37,7 +37,7 @@ results { } results { device: "CUDA: 8.0, Cores: 108, GPU clock: 1.41 GHz, Memory bandwidth: 2039 GB/s, L2 cache: 40 MB" - hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}" + hlo: "(bf16[128,1024,1024]{2,1,0}, s8[4194304]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}" result { run_time { nanos: 1 @@ -61,7 +61,7 @@ results { } results { device: "CUDA: 9.0, Cores: 132, GPU clock: 1.98 GHz, Memory bandwidth: 3352 GB/s, L2 cache: 50 MB" - hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"2\"],\"rhs_contracting_dimensions\":[\"1\"],\"lhs_batch_dimensions\":[\"0\"],\"rhs_batch_dimensions\":[\"0\"]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"1048576\",\"rhs_stride\":\"1048576\",\"grad_x\":false,\"grad_y\":false,\"damax_output\":false},\"force_earliest_schedule\":false}" + hlo: "(bf16[128,1024,1024]{2,1,0}, s8[33554432]{0}) custom-call(bf16[128,1024,1024]{2,1,0}, bf16[128,1024,1024]{2,1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"damax_output\":false,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[\"0\"],\"lhs_contracting_dimensions\":[\"2\"],\"rhs_batch_dimensions\":[\"0\"],\"rhs_contracting_dimensions\":[\"1\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"1048576\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"1048576\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}" result { gemm { algorithm: -1 diff --git a/third_party/xla/xla/service/gpu/gpu_executable.cc b/third_party/xla/xla/service/gpu/gpu_executable.cc index bf9774711fcfd6..40055de2fa0f16 100644 --- a/third_party/xla/xla/service/gpu/gpu_executable.cc +++ b/third_party/xla/xla/service/gpu/gpu_executable.cc @@ -841,9 +841,14 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( TF_ASSIGN_OR_RETURN(globals, ResolveConstantGlobals(run_options->stream())); } - auto device_ordinal = executor->device_ordinal(); + // Use the `device_ordinal` from the `run_options` if it is provided. This is + // the ordinal of the logical devices (e.g., virtual GPUs). If it is not + // provided, the ordinals of the logical and physical devices are the same. + const int device_ordinal = run_options->device_ordinal() != -1 + ? run_options->device_ordinal() + : executor->device_ordinal(); ExecutionOutput result(/*on_device_shape=*/output_shape_, memory_allocator, - device_ordinal); + device_ordinal, executor->device_ordinal()); TF_ASSIGN_OR_RETURN( BufferAllocations buffer_allocations, @@ -873,9 +878,7 @@ absl::StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( } module_allocations_[executor][i] = buffer_allocations.GetDeviceAddress(i); - VLOG(5) << "Gpu address changed for module " << module_name_ - << ", allocation info: \n" - << allocations[i].ToShortString(); + VLOG(5) << "Gpu address changed for module " << module_name_; } } } diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc deleted file mode 100644 index 566c0068f5dbba..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.cc +++ /dev/null @@ -1,719 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/gpu_fused_mha_runner.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" -#include "Eigen/Core" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -namespace { -using se::DeviceMemory; -using se::DeviceMemoryBase; -using se::dnn::DataType; -using se::dnn::MatmulTensorDescriptor; -using se::dnn::TensorDescriptor; - -template -absl::Status RunFusedMHA(GpufMHAParams params, se::Stream *stream, - RunFusedMHAOptions options, - DeviceMemory lhs_bmm1_buffer, - DeviceMemory rhs_bmm1_buffer, - DeviceMemory rhs_bmm2_buffer, - DeviceMemory output_buffer, - DeviceMemoryBase bias_buffer, - DeviceMemoryBase scratch_memory, - DeviceMemoryBase activation_output, - DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { - se::dnn::LazyOpRunner *lazy_runner = - options.runner_cache->AsFusedMHARunner(); - std::optional> local_runner; - if (!lazy_runner) { - local_runner.emplace(params.config->algorithm); - lazy_runner = &*local_runner; - } - std::optional dropout_rate; - if (params.config->dropout_rate) { - dropout_rate = *params.config->dropout_rate; - } - - std::optional seed; - if (params.config->seed) { - seed = *params.config->seed; - } - - TF_ASSIGN_OR_RETURN(se::dnn::FusedMHAOp::Config config, - params.config->AsDnnFusedMHAOpConfig()); - TF_ASSIGN_OR_RETURN(auto *runner, - lazy_runner->GetOrCreateRunner(config, stream)); - return (*runner)(stream, options.profile_result, scratch_memory, - lhs_bmm1_buffer, rhs_bmm1_buffer, rhs_bmm2_buffer, - output_buffer, bias_buffer, activation_output, seqlen_q, - seqlen_k); -} - -template -absl::Status RunGpuFMHAImpl(const GpufMHAParams ¶ms, se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHAOptions options) { - auto lhs_bmm1_buffer = se::DeviceMemory(params.lhs_bmm1_buffer); - auto rhs_bmm1_buffer = se::DeviceMemory(params.rhs_bmm1_buffer); - auto rhs_bmm2_buffer = se::DeviceMemory(params.rhs_bmm2_buffer); - auto output_buffer = se::DeviceMemory(params.output_buffer); - auto activation_buffer = - params.activation_buffer.has_value() - ? se::DeviceMemory(*params.activation_buffer) - : se::DeviceMemoryBase(); - auto bias_buffer = params.bias_buffer.has_value() - ? se::DeviceMemory(*params.bias_buffer) - : se::DeviceMemoryBase(); - auto seqlen_q_buffer = - params.seqlen_q_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_q_buffer) - : se::DeviceMemoryBase(); - auto seqlen_k_buffer = - params.seqlen_k_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_k_buffer) - : se::DeviceMemoryBase(); - se::dnn::AlgorithmDesc algorithm = params.config->algorithm; - if (options.runner_cache) { - algorithm = options.runner_cache->ToAlgorithmDesc(); - } - - absl::Status run_status = absl::OkStatus(); - switch (params.config->kind) { - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - run_status = RunFusedMHA( - params, stream, options, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, bias_buffer, scratch_memory, - activation_buffer, seqlen_q_buffer, seqlen_k_buffer); - break; - default: - return Internal("Invalid cuDNN fMHA kind"); - } - - if (!run_status.ok()) { - return run_status; - } - - if (!stream->ok()) { - return Internal("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); - } - - return absl::OkStatus(); -} - -template -absl::Status RunFusedMHABackward( - GpufMHABackwardParams params, se::Stream *stream, - RunFusedMHABackwardOptions options, - DeviceMemory bmm1_grad_gemm1_rhs_buffer, - DeviceMemory bmm1_grad_gemm2_rhs_buffer, - DeviceMemory bmm2_grad_gemm1_lhs_buffer, - DeviceMemory bmm2_grad_gemm2_rhs_buffer, - DeviceMemory d_output_buffer, - DeviceMemory d_bmm1_lhs_buffer, - DeviceMemory d_bmm1_rhs_buffer, - DeviceMemory d_bmm2_rhs_buffer, DeviceMemoryBase d_s_buffer, - DeviceMemoryBase d_bias_buffer, DeviceMemoryBase fwd_output_buffer, - DeviceMemoryBase bias_buffer, DeviceMemoryBase scratch_memory, - DeviceMemoryBase seqlen_q, DeviceMemoryBase seqlen_k) { - se::dnn::LazyOpRunner *lazy_runner = - options.runner_cache->AsFusedMHABackwardRunner(); - std::optional> - local_runner; - if (!lazy_runner) { - local_runner.emplace(params.config->algorithm); - lazy_runner = &*local_runner; - } - std::optional dropout_rate; - if (params.config->dropout_rate) { - dropout_rate = *params.config->dropout_rate; - } - - std::optional seed; - if (params.config->seed) { - seed = *params.config->seed; - } - - TF_ASSIGN_OR_RETURN(se::dnn::FusedMHABackwardOp::Config config, - params.config->AsDnnFusedMHABackwardOpConfig()); - TF_ASSIGN_OR_RETURN(auto *runner, - lazy_runner->GetOrCreateRunner(config, stream)); - // TODO: pass in real softmax_sum, dQ_accum, fwd_output - return (*runner)(stream, options.profile_result, scratch_memory, - bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, seqlen_q, seqlen_k); - return absl::OkStatus(); -} - -template -absl::Status RunGpuFMHABackwardImpl(const GpufMHABackwardParams ¶ms, - se::Stream *stream, - se::DeviceMemoryBase scratch_memory, - RunFusedMHABackwardOptions options) { - auto bmm1_grad_gemm1_rhs_buffer = - se::DeviceMemory(params.bmm1_grad_gemm1_rhs_buffer); - auto bmm1_grad_gemm2_rhs_buffer = - se::DeviceMemory(params.bmm1_grad_gemm2_rhs_buffer); - auto bmm2_grad_gemm1_lhs_buffer = - se::DeviceMemory(params.bmm2_grad_gemm1_lhs_buffer); - auto bmm2_grad_gemm2_rhs_buffer = - se::DeviceMemory(params.bmm2_grad_gemm2_rhs_buffer); - auto d_output_buffer = se::DeviceMemory(params.d_output_buffer); - auto d_bmm1_lhs_buffer = - se::DeviceMemory(params.d_bmm1_lhs_buffer); - auto d_bmm1_rhs_buffer = - se::DeviceMemory(params.d_bmm1_rhs_buffer); - auto d_bmm2_rhs_buffer = - se::DeviceMemory(params.d_bmm2_rhs_buffer); - - // optional buffers - auto d_s_buffer = params.d_s_buffer.has_value() - ? se::DeviceMemory(*params.d_s_buffer) - : se::DeviceMemoryBase(); - - auto d_bias_buffer = params.d_bias_buffer.has_value() - ? se::DeviceMemory(*params.d_bias_buffer) - : se::DeviceMemoryBase(); - - auto fwd_output_buffer = - params.fwd_output_buffer.has_value() - ? se::DeviceMemory(*params.fwd_output_buffer) - : se::DeviceMemoryBase(); - - auto bias_buffer = params.bias_buffer.has_value() - ? se::DeviceMemory(*params.bias_buffer) - : se::DeviceMemoryBase(); - - auto seqlen_q_buffer = - params.seqlen_q_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_q_buffer) - : se::DeviceMemoryBase(); - - auto seqlen_k_buffer = - params.seqlen_k_buffer.has_value() - ? se::DeviceMemory(*params.seqlen_k_buffer) - : se::DeviceMemoryBase(); - - se::dnn::AlgorithmDesc algorithm = params.config->algorithm; - if (options.runner_cache) { - algorithm = options.runner_cache->ToAlgorithmDesc(); - } - - absl::Status run_status = absl::OkStatus(); - switch (params.config->kind) { - case CudnnfMHAKind::kBackwardSoftmaxDropout: - case CudnnfMHAKind::kBackwardSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - run_status = RunFusedMHABackward( - params, stream, options, bmm1_grad_gemm1_rhs_buffer, - bmm1_grad_gemm2_rhs_buffer, bmm2_grad_gemm1_lhs_buffer, - bmm2_grad_gemm2_rhs_buffer, d_output_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, scratch_memory, seqlen_q_buffer, - seqlen_k_buffer); - break; - default: - return Internal("Invalid cuDNN fMHA kind"); - } - - if (!run_status.ok()) { - return run_status; - } - - if (!stream->ok()) { - return Internal("Unable to launch FMHA with type %s and algorithm %s", - CudnnfMHAKindToString(params.config->kind), - algorithm.ToString()); - } - - return run_status; -} -} // namespace - -/*static*/ absl::StatusOr GpufMHAConfig::For( - const GpufMHADescriptor &desc) { - // Get shapes from desc. - const Shape &lhs_bmm1_shape = desc.lhs_bmm1_shape; - const Shape &rhs_bmm1_shape = desc.rhs_bmm1_shape; - const Shape &rhs_bmm2_shape = desc.rhs_bmm2_shape; - const Shape &intermediate_lhs_bmm2_shape = desc.intermediate_lhs_bmm2_shape; - const Shape &output_shape = desc.output_shapes[0]; - - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN( - DataType lhs_bmm1_type, - GetDNNDataTypeFromPrimitiveType(lhs_bmm1_shape.element_type())); - TF_ASSIGN_OR_RETURN( - DataType rhs_bmm1_type, - GetDNNDataTypeFromPrimitiveType(rhs_bmm1_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType rhs_bmm2_type, - GetDNNDataTypeFromPrimitiveType(rhs_bmm2_shape.element_type())); - TF_ASSIGN_OR_RETURN(DataType lhs_bmm2_type, - GetDNNDataTypeFromPrimitiveType( - intermediate_lhs_bmm2_shape.element_type())); - TF_ASSIGN_OR_RETURN(DataType output_type, GetDNNDataTypeFromPrimitiveType( - output_shape.element_type())); - GpufMHAConfig config; - config.input_type = lhs_bmm1_shape.element_type(); - config.output_type = output_shape.element_type(); - - // Get MatmulTensorDescriptors for BMM1 - config.lhs_bmm1 = - MatmulTensorDescriptor::For(lhs_bmm1_type, lhs_bmm1_shape.dimensions(), - desc.lhs_bmm1_shape.layout().minor_to_major(), - desc.bmm1_dnums.lhs_batch_dimensions(), - desc.bmm1_dnums.lhs_contracting_dimensions()); - config.rhs_bmm1 = - MatmulTensorDescriptor::For(rhs_bmm1_type, rhs_bmm1_shape.dimensions(), - desc.rhs_bmm1_shape.layout().minor_to_major(), - desc.bmm1_dnums.rhs_batch_dimensions(), - desc.bmm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 - config.rhs_bmm2 = - MatmulTensorDescriptor::For(rhs_bmm2_type, rhs_bmm2_shape.dimensions(), - desc.rhs_bmm2_shape.layout().minor_to_major(), - desc.bmm2_dnums.rhs_batch_dimensions(), - desc.bmm2_dnums.rhs_contracting_dimensions()); - - config.intermediate_lhs_bmm2 = MatmulTensorDescriptor::For( - lhs_bmm2_type, intermediate_lhs_bmm2_shape.dimensions(), - desc.intermediate_lhs_bmm2_shape.layout().minor_to_major(), - desc.bmm2_dnums.lhs_batch_dimensions(), - desc.bmm2_dnums.lhs_contracting_dimensions()); - - config.output = TensorDescriptor::For(output_type, output_shape.dimensions(), - output_shape.layout().minor_to_major()); - - if (desc.output_shapes.size() > 1) { - const Shape &activation_shape = desc.output_shapes.back(); - // Generally, activation should have same type as output, but set it - // explicityly just to be safe. - TF_ASSIGN_OR_RETURN( - DataType activation_type, - GetDNNDataTypeFromPrimitiveType(activation_shape.element_type())); - config.activation = - TensorDescriptor::For(activation_type, activation_shape.dimensions(), - activation_shape.layout().minor_to_major()); - } - - if (desc.mask_shape) { - const Shape &mask_shape = *desc.mask_shape; - TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( - mask_shape.element_type())); - config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), - mask_shape.layout().minor_to_major()); - } - - if (desc.bias_shape) { - const Shape &bias_shape = *desc.bias_shape; - TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( - bias_shape.element_type())); - config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), - bias_shape.layout().minor_to_major()); - } - config.kind = desc.kind; - config.mask_type = desc.mask_type; - const CudnnfMHABackendConfig &backend_config = desc.backend_config; - config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - config.fmha_scale.emplace(backend_config.fmha_scale()); - config.dropout_rate.emplace(backend_config.dropout_rate()); - config.seed.emplace(backend_config.seed()); - return config; -} - -absl::StatusOr -GpufMHAConfig::AsDnnFusedMHAOpConfig() const { - double scale = 1.0; - if (fmha_scale.has_value()) { - scale = *fmha_scale; - } - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type)); - - return se::dnn::FusedMHAOp::Config{ - scale, lhs_bmm1, rhs_bmm1, rhs_bmm2, intermediate_lhs_bmm2, - output, bias, activation, dropout_rate, seed, - mask_type}; -} - -/*static*/ absl::StatusOr GpufMHABackwardConfig::For( - const GpufMHABackwardDescriptor &desc) { - // Get shapes from desc. - - const Shape &bmm1_grad_gemm1_rhs_shape = desc.bmm1_grad_gemm1_rhs_shape; - const Shape &bmm1_grad_gemm2_rhs_shape = desc.bmm1_grad_gemm2_rhs_shape; - const Shape &bmm2_grad_gemm1_lhs_shape = desc.bmm2_grad_gemm1_lhs_shape; - const Shape &bmm2_grad_gemm2_rhs_shape = desc.bmm2_grad_gemm2_rhs_shape; - const Shape &d_output_shape = desc.d_output_shape; - const Shape &d_bmm1_lhs_shape = desc.d_bmm1_lhs_shape; - const Shape &d_bmm1_rhs_shape = desc.d_bmm1_rhs_shape; - const Shape &d_bmm2_rhs_shape = desc.d_bmm2_rhs_shape; - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm1_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm1_grad_gemm1_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm1_grad_gemm2_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm1_grad_gemm2_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm1_lhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm2_grad_gemm1_lhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN(DataType bmm2_grad_gemm2_rhs_type, - GetDNNDataTypeFromPrimitiveType( - bmm2_grad_gemm2_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_output_type, - GetDNNDataTypeFromPrimitiveType(d_output_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm1_lhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm1_lhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm1_rhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm1_rhs_shape.element_type())); - - TF_ASSIGN_OR_RETURN( - DataType d_bmm2_rhs_type, - GetDNNDataTypeFromPrimitiveType(d_bmm2_rhs_shape.element_type())); - - GpufMHABackwardConfig config; - config.input_type = bmm1_grad_gemm1_rhs_shape.element_type(); - config.output_type = d_bmm1_lhs_shape.element_type(); - - // Get MatmulTensorDescriptors for lhs of BMM1 grad GEMM 1 - config.bmm1_grad_gemm1_rhs = MatmulTensorDescriptor::For( - bmm1_grad_gemm1_rhs_type, bmm1_grad_gemm1_rhs_shape.dimensions(), - desc.bmm1_grad_gemm1_rhs_shape.layout().minor_to_major(), - desc.bmm1_grad_gemm1_dnums.rhs_batch_dimensions(), - desc.bmm1_grad_gemm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for rhs of BMM1 grad GEMM 2 - config.bmm1_grad_gemm2_rhs = MatmulTensorDescriptor::For( - bmm1_grad_gemm2_rhs_type, bmm1_grad_gemm2_rhs_shape.dimensions(), - desc.bmm1_grad_gemm2_rhs_shape.layout().minor_to_major(), - desc.bmm1_grad_gemm2_dnums.rhs_batch_dimensions(), - desc.bmm1_grad_gemm2_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 grad GEMM 1 - config.bmm2_grad_gemm1_lhs = MatmulTensorDescriptor::For( - bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), - desc.bmm2_grad_gemm1_lhs_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm1_dnums.lhs_batch_dimensions(), - desc.bmm2_grad_gemm1_dnums.lhs_contracting_dimensions()); - - config.d_output = MatmulTensorDescriptor::For( - d_output_type, d_output_shape.dimensions(), - desc.d_output_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm1_dnums.rhs_batch_dimensions(), - desc.bmm2_grad_gemm1_dnums.rhs_contracting_dimensions()); - - // Get MatmulTensorDescriptors for BMM2 grad GEMM 2 - config.bmm2_grad_gemm2_rhs = MatmulTensorDescriptor::For( - bmm2_grad_gemm2_rhs_type, bmm2_grad_gemm2_rhs_shape.dimensions(), - desc.bmm2_grad_gemm2_rhs_shape.layout().minor_to_major(), - desc.bmm2_grad_gemm2_dnums.rhs_batch_dimensions(), - desc.bmm2_grad_gemm2_dnums - .rhs_contracting_dimensions()); // FMHA TODO: transpose here? - - config.d_bmm1_lhs = - TensorDescriptor::For(d_bmm1_lhs_type, d_bmm1_lhs_shape.dimensions(), - d_bmm1_lhs_shape.layout().minor_to_major()); - config.d_bmm1_rhs = - TensorDescriptor::For(d_bmm1_rhs_type, d_bmm1_rhs_shape.dimensions(), - d_bmm1_rhs_shape.layout().minor_to_major()); - config.d_bmm2_rhs = - TensorDescriptor::For(d_bmm2_rhs_type, d_bmm2_rhs_shape.dimensions(), - d_bmm2_rhs_shape.layout().minor_to_major()); - config.d_s = TensorDescriptor::For( - bmm2_grad_gemm1_lhs_type, bmm2_grad_gemm1_lhs_shape.dimensions(), - bmm2_grad_gemm1_lhs_shape.layout().minor_to_major()); - - if (desc.d_bias_shape) { - const Shape &d_bias_shape = *desc.d_bias_shape; - // Get DNN dtype from primtive types - TF_ASSIGN_OR_RETURN(DataType d_bias_type, GetDNNDataTypeFromPrimitiveType( - d_bias_shape.element_type())); - config.d_bias = - TensorDescriptor::For(d_bias_type, d_bias_shape.dimensions(), - d_bias_shape.layout().minor_to_major()); - } - - if (desc.mask_shape) { - const Shape &mask_shape = *desc.mask_shape; - TF_ASSIGN_OR_RETURN(DataType mask_type, GetDNNDataTypeFromPrimitiveType( - mask_shape.element_type())); - config.mask = TensorDescriptor::For(mask_type, mask_shape.dimensions(), - mask_shape.layout().minor_to_major()); - } - if (desc.fwd_output_shape) { - const Shape &fwd_output_shape = *desc.fwd_output_shape; - TF_ASSIGN_OR_RETURN( - DataType fwd_output_type, - GetDNNDataTypeFromPrimitiveType(fwd_output_shape.element_type())); - config.fwd_output = - TensorDescriptor::For(fwd_output_type, fwd_output_shape.dimensions(), - fwd_output_shape.layout().minor_to_major()); - } - - if (desc.bias_shape) { - const Shape &bias_shape = *desc.bias_shape; - TF_ASSIGN_OR_RETURN(DataType bias_type, GetDNNDataTypeFromPrimitiveType( - bias_shape.element_type())); - config.bias = TensorDescriptor::For(bias_type, bias_shape.dimensions(), - bias_shape.layout().minor_to_major()); - } - - config.kind = desc.kind; - config.mask_type = desc.mask_type; - config.force_deterministic = desc.force_deterministic; - const CudnnfMHABackendConfig &backend_config = desc.backend_config; - config.algorithm = se::dnn::AlgorithmDesc(backend_config.algorithm()); - config.fmha_scale.emplace(backend_config.fmha_scale()); - config.dropout_rate.emplace(backend_config.dropout_rate()); - config.seed.emplace(backend_config.seed()); - return config; -} - -absl::StatusOr -GpufMHABackwardConfig::AsDnnFusedMHABackwardOpConfig() const { - double scale = 1.0; - if (fmha_scale.has_value()) { - scale = *fmha_scale; - } - TF_ASSIGN_OR_RETURN(se::dnn::FMHAMaskKind mask_type, - GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(mask_type)); - - return se::dnn::FusedMHABackwardOp::Config{scale, - bmm1_grad_gemm1_rhs, - bmm1_grad_gemm2_rhs, - bmm2_grad_gemm1_lhs, - bmm2_grad_gemm2_rhs, - d_output, - d_bmm1_lhs, - d_bmm1_rhs, - d_bmm2_rhs, - d_s, - d_bias, - fwd_output, - bias, - dropout_rate, - seed, - mask_type, - force_deterministic}; -} - -/*static*/ absl::StatusOr GpufMHAParams::For( - const GpufMHAConfig &config, se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer) { - GpufMHAParams params; - params.config = &config; - params.lhs_bmm1_buffer = lhs_bmm1_buffer; - params.rhs_bmm1_buffer = rhs_bmm1_buffer; - params.rhs_bmm2_buffer = rhs_bmm2_buffer; - params.output_buffer = output_buffer; - params.activation_buffer = activation_buffer; - params.bias_buffer = bias_buffer; - params.seqlen_q_buffer = seqlen_q_buffer; - params.seqlen_k_buffer = seqlen_k_buffer; - return params; -} - -/*static*/ absl::StatusOr GpufMHABackwardParams::For( - const GpufMHABackwardConfig &config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer) { - GpufMHABackwardParams params; - params.config = &config; - params.bmm1_grad_gemm1_rhs_buffer = bmm1_grad_gemm1_rhs_buffer; - params.bmm1_grad_gemm2_rhs_buffer = bmm1_grad_gemm2_rhs_buffer; - params.bmm2_grad_gemm1_lhs_buffer = bmm2_grad_gemm1_lhs_buffer; - params.bmm2_grad_gemm2_rhs_buffer = bmm2_grad_gemm2_rhs_buffer; - params.d_output_buffer = d_output_buffer; - params.d_bmm1_lhs_buffer = d_bmm1_lhs_buffer; - params.d_bmm1_rhs_buffer = d_bmm1_rhs_buffer; - params.d_bmm2_rhs_buffer = d_bmm2_rhs_buffer; - params.d_s_buffer = d_s_buffer; - params.d_bias_buffer = d_bias_buffer; - params.fwd_output_buffer = fwd_output_buffer; - params.bias_buffer = bias_buffer; - params.seqlen_q_buffer = seqlen_q_buffer; - params.seqlen_k_buffer = seqlen_k_buffer; - return params; -} - -absl::Status RunGpuFMHA(const GpufMHAConfig &fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, - se::Stream *stream, RunFusedMHAOptions options) { - TF_ASSIGN_OR_RETURN( - GpufMHAParams params, - GpufMHAParams::For(fmha_config, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, bias_buffer, - activation_buffer, seqlen_q_buffer, seqlen_k_buffer)); - PrimitiveType input_primitive_type = fmha_config.input_type; - switch (input_primitive_type) { - case F16: - return RunGpuFMHAImpl( - params, stream, scratch_buffer, options); - case BF16: - return RunGpuFMHAImpl( - params, stream, scratch_buffer, options); - default: - return absl::UnimplementedError(absl::StrFormat( - "Unimplemented fused MHA with %s", ToString(fmha_config))); - } - return absl::OkStatus(); -} - -absl::Status RunGpuFMHABackward( - const GpufMHABackwardConfig &fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, se::Stream *stream, - RunFusedMHABackwardOptions options) { - TF_ASSIGN_OR_RETURN( - GpufMHABackwardParams params, - GpufMHABackwardParams::For( - fmha_config, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, - d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, fwd_output_buffer, - bias_buffer, seqlen_q_buffer, seqlen_k_buffer)); - PrimitiveType input_primitive_type = fmha_config.input_type; - switch (input_primitive_type) { - case F16: - return RunGpuFMHABackwardImpl( - params, stream, scratch_buffer, options); - case BF16: - return RunGpuFMHABackwardImpl(params, stream, - scratch_buffer, options); - default: - return Unimplemented("Unimplemented fused MHA backward"); - } - return absl::OkStatus(); -} - -std::string ToString(const GpufMHAConfig &config) { - std::string result = "GpufMHAConfig:\n"; - absl::StrAppend(&result, - "input_type: ", PrimitiveType_Name(config.input_type), ", "); - absl::StrAppend( - &result, "output_type: ", PrimitiveType_Name(config.output_type), ", "); - absl::StrAppend(&result, "Kind: ", CudnnfMHAKindToString(config.kind), ", "); - if (config.fmha_scale) { - absl::StrAppend(&result, "fmha_scale: ", *config.fmha_scale, ", "); - } - if (config.dropout_rate) { - absl::StrAppend(&result, "dropout_rate: ", *config.dropout_rate, ", "); - } - if (config.seed) { - absl::StrAppend(&result, "seed: ", *config.seed, ", "); - } - absl::StrAppend(&result, "Algorithm Desc: ", config.algorithm.ToString(), - "\n"); - absl::StrAppend(&result, "lhs_bmm1: ", config.lhs_bmm1.ToString(), "\n"); - absl::StrAppend(&result, "rhs_bmm1: ", config.rhs_bmm1.ToString(), "\n"); - absl::StrAppend(&result, "rhs_bmm2: ", config.rhs_bmm2.ToString(), "\n"); - absl::StrAppend(&result, "intermediate_lhs_bmm2: ", - config.intermediate_lhs_bmm2.ToString(), "\n"); - absl::StrAppend(&result, "output: ", config.output.ToString(), "\n"); - - if (config.mask) { - absl::StrAppend(&result, "mask: ", (*config.mask).ToString(), "\n"); - } - - if (config.bias) { - absl::StrAppend(&result, "bias: ", (*config.bias).ToString(), "\n"); - } - - return result; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h b/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h deleted file mode 100644 index d0621cbdff6d74..00000000000000 --- a/third_party/xla/xla/service/gpu/gpu_fused_mha_runner.h +++ /dev/null @@ -1,431 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ -#define XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ - -#include -#include -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/cublas_cudnn.h" -#include "xla/shape.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/dnn.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -inline absl::StatusOr AsCudnnFmhaMaskKind( - xla::gpu::CudnnfMHABackendConfig_MaskType mask_type) { - switch (mask_type) { - case xla::gpu::CudnnfMHABackendConfig::NO_MASK: - return xla::gpu::CudnnfMHAMaskKind::kNoMask; - case xla::gpu::CudnnfMHABackendConfig::PADDING: - return xla::gpu::CudnnfMHAMaskKind::kPadding; - case xla::gpu::CudnnfMHABackendConfig::CAUSAL: - return xla::gpu::CudnnfMHAMaskKind::kCausal; - case xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL: - return xla::gpu::CudnnfMHAMaskKind::kPaddingCausal; - case xla::gpu::CudnnfMHABackendConfig::ALIBI: - return xla::gpu::CudnnfMHAMaskKind::kAlibi; - default: - return xla::Internal("Unknown fmha mask kind."); - } -} - -// This is an interim structure to hold the parameters to construct a -// GpufMHAConfig. -// Struct to describe properties of a FMHA without being tied to specific -// IR. Will be used to help build FMHA thunks from either XLA HLO or -// LHLO GPU dialect in MLIR. -struct GpufMHADescriptor { - CudnnfMHAKind kind; - CudnnfMHABackendConfig backend_config; - CudnnfMHAMaskKind mask_type; - Shape lhs_bmm1_shape; - Shape rhs_bmm1_shape; - Shape rhs_bmm2_shape; - Shape intermediate_lhs_bmm2_shape; - // This will contain both output shape and activation shape - absl::InlinedVector output_shapes; - DotDimensionNumbers bmm1_dnums; - DotDimensionNumbers bmm2_dnums; - - std::optional mask_shape; - std::optional bias_shape; -}; - -struct GpufMHABackwardDescriptor { - CudnnfMHAKind kind; - CudnnfMHABackendConfig backend_config; - CudnnfMHAMaskKind mask_type; - Shape bmm1_grad_gemm1_rhs_shape; - Shape bmm1_grad_gemm2_rhs_shape; - Shape bmm2_grad_gemm1_lhs_shape; - Shape bmm2_grad_gemm2_rhs_shape; - Shape d_output_shape; - Shape d_bmm1_lhs_shape; - Shape d_bmm1_rhs_shape; - Shape d_bmm2_rhs_shape; - DotDimensionNumbers bmm1_grad_gemm1_dnums; - DotDimensionNumbers bmm1_grad_gemm2_dnums; - DotDimensionNumbers bmm2_grad_gemm1_dnums; - DotDimensionNumbers bmm2_grad_gemm2_dnums; - - std::optional d_s_shape; - std::optional fwd_output_shape; - std::optional mask_shape; - std::optional d_bias_shape; - std::optional bias_shape; - bool force_deterministic; -}; - -// Structure to describe static properties of a GPU fused Multi-Headed -// Attention. -struct GpufMHAConfig { - static absl::StatusOr For(const GpufMHADescriptor& fmha_desc); - - absl::StatusOr AsDnnFusedMHAOpConfig() const; - - PrimitiveType - input_type; // Capture the primitive type of one of the inputs of BMM1 - PrimitiveType output_type; - CudnnfMHAKind kind; - std::optional fmha_scale; - std::optional dropout_rate; - std::optional seed; - - se::dnn::AlgorithmDesc algorithm; - CudnnfMHAMaskKind mask_type; - // bias -> [1, num_attn_heads, q_seq_len, kv_seq_len] - // mask -> [batch_size, 1, q_seq_len, kv_seq_len] - se::dnn::MatmulTensorDescriptor lhs_bmm1; - se::dnn::MatmulTensorDescriptor rhs_bmm1; - se::dnn::MatmulTensorDescriptor rhs_bmm2; - se::dnn::MatmulTensorDescriptor intermediate_lhs_bmm2; - se::dnn::TensorDescriptor output; - - std::optional activation; - std::optional mask; - std::optional bias; -}; - -// Structure to describe static properties of a GPU fused Multi-Headed -// Attention backward. -struct GpufMHABackwardConfig { - static absl::StatusOr For( - const GpufMHABackwardDescriptor& fmha_desc); - - absl::StatusOr - AsDnnFusedMHABackwardOpConfig() const; - - PrimitiveType - input_type; // Capture the primitive type of one of the inputs of BMM1 - PrimitiveType output_type; - CudnnfMHAKind kind; - std::optional fmha_scale; - std::optional dropout_rate; - std::optional seed; - - se::dnn::AlgorithmDesc algorithm; - CudnnfMHAMaskKind mask_type; - // mask -> [batch_size, 1, q_seq_len, kv_seq_len] - // d_bias -> [1, num_heads, q_seq_len, kv_seq_len] - se::dnn::MatmulTensorDescriptor bmm1_grad_gemm1_rhs; - se::dnn::MatmulTensorDescriptor bmm1_grad_gemm2_rhs; - se::dnn::MatmulTensorDescriptor bmm2_grad_gemm1_lhs; - se::dnn::MatmulTensorDescriptor bmm2_grad_gemm2_rhs; - se::dnn::MatmulTensorDescriptor d_output; - se::dnn::TensorDescriptor d_bmm1_lhs; - se::dnn::TensorDescriptor d_bmm1_rhs; - se::dnn::TensorDescriptor d_bmm2_rhs; - std::optional d_s; - std::optional mask; - std::optional d_bias; - std::optional fwd_output; - std::optional bias; - bool force_deterministic; -}; - -// Implementation struct exposed for debugging and log analysis. -struct GpufMHAParams { - static absl::StatusOr For( - const GpufMHAConfig& config, se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, se::DeviceMemoryBase output_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer); - - const GpufMHAConfig* config; // Not owned - se::DeviceMemoryBase lhs_bmm1_buffer; - se::DeviceMemoryBase rhs_bmm1_buffer; - se::DeviceMemoryBase rhs_bmm2_buffer; - se::DeviceMemoryBase output_buffer; - std::optional activation_buffer; - std::optional bias_buffer; - std::optional seqlen_q_buffer; - std::optional seqlen_k_buffer; -}; - -struct GpufMHABackwardParams { - static absl::StatusOr For( - const GpufMHABackwardConfig& config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer); - - const GpufMHABackwardConfig* config; // Not owned - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer; - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer; - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer; - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer; - se::DeviceMemoryBase d_output_buffer; - se::DeviceMemoryBase d_bmm1_lhs_buffer; - se::DeviceMemoryBase d_bmm1_rhs_buffer; - se::DeviceMemoryBase d_bmm2_rhs_buffer; - std::optional d_s_buffer; - std::optional d_bias_buffer; - std::optional fwd_output_buffer; - std::optional bias_buffer; - std::optional seqlen_q_buffer; - std::optional seqlen_k_buffer; -}; - -class FusedMultiHeadedAttentionRunner { - public: - using Repr = - std::variant>>; - - FusedMultiHeadedAttentionRunner() = default; - - explicit FusedMultiHeadedAttentionRunner( - std::unique_ptr> runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionRunner(Repr runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionRunner(const GpufMHAConfig& config) - : FusedMultiHeadedAttentionRunner(CreateRunner(config)) { - if (std::holds_alternative(repr_)) { - CHECK(false) << "Cannot construct FusedMultiHeadedAttentionRunner with " - "std::monostate"; - } - } - - se::dnn::AlgorithmDesc ToAlgorithmDesc() const { - return std::visit(ToAlgorithmDescVisitor{}, repr_); - } - - se::dnn::LazyOpRunner* AsFusedMHARunner() { - CHECK(std::holds_alternative< - std::unique_ptr>>(repr_)); - return std::get< - std::unique_ptr>>( - repr_) - .get(); - } - - private: - // The CreateRunner function is defined as static because it - // doesn't need access to any non-static member variables of the - // FusedMultiHeadedAttentionRunner class. Defining it static makes it easy to - // use and makes it clear that it is a utility function that doesn't rely on - // the state of any specific instance of the class. - static Repr CreateRunner(const GpufMHAConfig& config) { - switch (config.kind) { - case CudnnfMHAKind::kSoftmaxDropout: - case CudnnfMHAKind::kSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmax: - case CudnnfMHAKind::kScaleBiasSoftmaxDropout: - return std::make_unique>( - config.algorithm); - default: - LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in " - "FusedMultiHeadedAttentionRunner"; - } - } - - struct ToAlgorithmDescVisitor { - template - se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) { - return runner->ToAlgorithmDesc(); - } - - se::dnn::AlgorithmDesc operator()(const std::monostate&) { - CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc"; - } - }; - - Repr repr_; -}; - -class FusedMultiHeadedAttentionBackwardRunner { - public: - using Repr = std::variant< - std::monostate, // To allow XXX default ctor - std::unique_ptr>>; - - FusedMultiHeadedAttentionBackwardRunner() = default; - - explicit FusedMultiHeadedAttentionBackwardRunner( - std::unique_ptr> - runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionBackwardRunner(Repr runner) - : repr_(std::move(runner)) {} - - explicit FusedMultiHeadedAttentionBackwardRunner( - const GpufMHABackwardConfig& config) - : FusedMultiHeadedAttentionBackwardRunner(CreateRunner(config)) { - if (std::holds_alternative(repr_)) { - CHECK(false) - << "Cannot construct FusedMultiHeadedAttentionBackwardRunner with " - "std::monostate"; - } - } - - se::dnn::AlgorithmDesc ToAlgorithmDesc() const { - return std::visit(ToAlgorithmDescVisitor{}, repr_); - } - - se::dnn::LazyOpRunner* - AsFusedMHABackwardRunner() { - CHECK(std::holds_alternative< - std::unique_ptr>>( - repr_)); - return std::get>>(repr_) - .get(); - } - - private: - // The CreateRunner function is defined as static because it - // doesn't need access to any non-static member variables of the - // FusedMultiHeadedAttentionBackwardRunner class. Defining it static makes it - // easy to use and makes it clear that it is a utility function that doesn't - // rely on the state of any specific instance of the class. - static Repr CreateRunner(const GpufMHABackwardConfig& config) { - switch (config.kind) { - case CudnnfMHAKind::kBackwardSoftmaxDropout: - case CudnnfMHAKind::kBackwardSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmax: - case CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout: - return std::make_unique< - se::dnn::LazyOpRunner>( - config.algorithm); - default: - LOG(FATAL) << "Internal error: unsupported CUDNN MHA kind in " - "FusedMultiHeadedAttentionBackwardRunner"; - } - } - - struct ToAlgorithmDescVisitor { - template - se::dnn::AlgorithmDesc operator()(const RunnerPtr& runner) { - return runner->ToAlgorithmDesc(); - } - - se::dnn::AlgorithmDesc operator()(const std::monostate&) { - CHECK(false) << "Internal error: uninitialized runner in ToAlgorithmDesc"; - } - }; - - Repr repr_; -}; - -struct RunFusedMHAOptions { - // Nullable output-parameter pointer for profiling results. - // Profile results remain unused for now since cuDNN FMHA has only one - // algorithm for now. - se::dnn::ProfileResult* profile_result = nullptr; - - // Use this runner cache (and its configured algorithm), instead of the one - // from the instruction. - FusedMultiHeadedAttentionRunner* runner_cache; -}; - -struct RunFusedMHABackwardOptions { - // Nullable output-parameter pointer for profiling results. - // Profile results remain unused for now since cuDNN FMHA has only one - // algorithm for now. - se::dnn::ProfileResult* profile_result = nullptr; - - // Use this runner cache (and its configured algorithm), instead of the one - // from the instruction. - FusedMultiHeadedAttentionBackwardRunner* runner_cache; -}; - -absl::Status RunGpuFMHA(const GpufMHAConfig& fmha_config, - se::DeviceMemoryBase lhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm1_buffer, - se::DeviceMemoryBase rhs_bmm2_buffer, - se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase scratch_buffer, - std::optional bias_buffer, - std::optional activation_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, - se::Stream* stream, RunFusedMHAOptions = {}); - -absl::Status RunGpuFMHABackward( - const GpufMHABackwardConfig& fmha_config, - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer, - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer, - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer, - se::DeviceMemoryBase d_output_buffer, se::DeviceMemoryBase scratch_buffer, - se::DeviceMemoryBase d_bmm1_lhs_buffer, - se::DeviceMemoryBase d_bmm1_rhs_buffer, - se::DeviceMemoryBase d_bmm2_rhs_buffer, - std::optional d_s_buffer, - std::optional d_bias_buffer, - std::optional fwd_output_buffer, - std::optional bias_buffer, - std::optional seqlen_q_buffer, - std::optional seqlen_k_buffer, se::Stream* stream, - RunFusedMHABackwardOptions = {}); - -std::string ToString(const GpufMHAConfig& config); - -} // namespace gpu -} // namespace xla -#endif // XLA_SERVICE_GPU_GPU_FUSED_MHA_RUNNER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc b/third_party/xla/xla/service/gpu/gpu_fusible.cc index c40be168a7d44a..f637e67e562113 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -29,11 +29,13 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/synchronization/mutex.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/permutation_util.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/reduction_utils.h" @@ -57,6 +59,25 @@ bool HasAnyTiledTransposeRoot(const HloComputation& computation) { }); } +const Shape& GetElementShape(const HloFusionAnalysis& analysis) { + const Shape* shape = &analysis.fusion_root(0).shape(); + while (shape->IsTuple()) { + shape = &shape->tuple_shapes(0); + } + return *shape; +} + +// Computes the maximum valid unroll factor for a given instruction. +int ComputeMaxUnrollFactor(int64_t num_elements) { + constexpr int kMaxUnrollFactor = 4; + for (int i = kMaxUnrollFactor; i > 1; i /= 2) { + if (num_elements % i == 0) { + return i; + } + } + return 1; +} + } // namespace bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) { @@ -612,11 +633,16 @@ static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) { // from potential x-tiling). return 4 * 32 * 33 * primitive_size * num_variadic; } - } else if (GetDescriptionForTiledTransposeEmitter(instr, instr).has_value()) { + } else if (auto tr = GetDescriptionForTiledTransposeEmitter(instr, instr)) { // Tile size for transposition. int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type()); - return 32 * 33 * primitive_size; + int64_t bytes_required = 32 * 33 * primitive_size; + // If the last dimension is not changed, it becomes part of the tile. + if (tr->permutation.back() == tr->permutation.size() - 1) { + bytes_required *= tr->dimensions.back(); + } + return bytes_required; } // Other fused expressions for now don't need the shared memory budget. return 0; @@ -942,9 +968,6 @@ static void GetFusionRootsRec(const HloInstruction* root, GetFusionRootsRec(root->operand(i), out); } } else { - CHECK(!absl::c_linear_search(out, root)) - << "Fusion root contains instruction " << root->ToString() - << " multiple times"; out.push_back(root); } } @@ -1022,5 +1045,46 @@ std::vector GetFusibleComputations( return result; } +LaunchDimensionsConfig ComputeLoopFusionConfig( + const HloFusionAnalysis& analysis) { + return ComputeLoopFusionConfig(analysis, GetElementShape(analysis)); +} + +LaunchDimensionsConfig ComputeLoopFusionConfig( + const HloFusionAnalysis& analysis, const Shape& element_shape) { + int unroll_factor = 1; + // Unrolling is good to read large inputs with small elements + // due to vector loads, but increases the register pressure when one + // thread has to produce multiple output elements. + // Therefore for fusions with small outputs prefer to use one thread + // per output element = no unroll. + // Call 'small' fusions that use less threads than the GPU has. + int64_t num_elements = ShapeUtil::ElementsIn(element_shape); + int64_t n_threads_max = analysis.device_info().threads_per_core_limit() * + analysis.device_info().core_count(); + if (num_elements >= n_threads_max && + !MayPreventVectorization(analysis.fusion())) { + unroll_factor = ComputeMaxUnrollFactor(num_elements); + } + // CHECK that unroll_factor is a power-of-2, as needed by the logic below. + CHECK(absl::has_single_bit(static_cast(unroll_factor))); + // Ensure a single thread writes to a byte containing multiple values by + // setting unroll_factor to an appropriate number. Setting unroll_factor is + // safe even if the new unroll_factor doesn't divide the number of elements, + // as the parallel loop emitter will insert a bounds check in this case to + // ensure the out-of-bounds element is not computed and written. Setting + // unroll_factor is safe even if MayPreventVectorization returns false, as + // the MayPreventVectorization check is an optimization, not a correctness + // requirement. + unroll_factor = std::max( + unroll_factor, + CeilOfRatio(8, analysis.input_output_info().smallest_output_dtype_bits)); + CHECK(absl::has_single_bit(static_cast(unroll_factor))); + VLOG(2) << "Unroll factor: " << unroll_factor; + + LaunchDimensionsConfig launch_config{unroll_factor}; + return launch_config; +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h index 185c440603a6b2..0dadbfa36f5476 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.h +++ b/third_party/xla/xla/service/gpu/gpu_fusible.h @@ -27,12 +27,14 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/launch_dimensions.h" #include "xla/service/instruction_fusion.h" #include "xla/stream_executor/device_description.h" // TODO(b/112957171): Extract logic to determine fusibility of HLO ops from -// GpuInstructionFusion, FusionMerger, and GpuMultiOutputFusion. +// GpuInstructionFusion, FusionMerger, and MultiOutputFusion. namespace xla { namespace gpu { @@ -226,6 +228,12 @@ bool IsGenericTritonFusion(const HloInstruction& instr); // instructions it contains. bool MayPreventVectorization(const HloFusionAdaptor& fusion); +LaunchDimensionsConfig ComputeLoopFusionConfig( + const HloFusionAnalysis& analysis); + +LaunchDimensionsConfig ComputeLoopFusionConfig( + const HloFusionAnalysis& analysis, const Shape& shape); + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc index 874b9da3a0a8c0..735709cbd346f8 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc @@ -544,7 +544,9 @@ TEST_F(GpuFusibleTest, FusionHeroesAreCompatible_TransposeFusionNotCompatible) { fused_computation_1 { p0.1 = f32[64,32]{1,0} parameter(0) neg = f32[64,32]{1,0} negate(p0.1) - ROOT transpose = f32[32,64]{1,0} transpose(neg), dimensions={1,0} + bc = f32[1,64,32]{2,1,0} bitcast(neg) + transpose = f32[1,32,64]{2,1,0} transpose(bc), dimensions={0,2,1} + ROOT bc2 = f32[32,64]{1,0} bitcast(transpose) } fused_computation_2 { @@ -562,10 +564,12 @@ TEST_F(GpuFusibleTest, FusionHeroesAreCompatible_TransposeFusionNotCompatible) { const HloInstruction* fusion_1 = module->entry_computation()->root_instruction(); const HloInstruction* fusion_2 = fusion_1->operand(0); - EXPECT_FALSE(FusionHeroesAreCompatible(fusion_1->fused_expression_root(), - fusion_2->fused_expression_root())); - EXPECT_FALSE(FusionHeroesAreCompatible(fusion_2->fused_expression_root(), - fusion_1->fused_expression_root())); + EXPECT_FALSE( + FusionHeroesAreCompatible(fusion_1->fused_expression_root(), + fusion_2->fused_expression_root()->operand(0))); + EXPECT_FALSE( + FusionHeroesAreCompatible(fusion_2->fused_expression_root()->operand(0), + fusion_1->fused_expression_root())); } TEST_F(GpuFusibleTest, ShapesCompatibleForMultiOutputFusion_LoopFusions) { @@ -1520,9 +1524,9 @@ TEST_F(GpuFusibleTest, ChooseFusionKind) { HloModule module ENTRY computation { - p = f32[5000,6000]{1,0} parameter(0) - c = f32[6000,5000] transpose(p), dimensions={1,0} - ROOT r = f32[300,20,5000] reshape(c) + p = f32[1,5000,6000]{2,1,0} parameter(0) + c = f32[1,6000,5000]{2,1,0} transpose(p), dimensions={0,2,1} + ROOT r = f32[300,20,5000]{2,1,0} reshape(c) } )") .value(); @@ -1700,6 +1704,33 @@ TEST_F(GpuFusibleTest, GetFusionRootsWithMakeTupleGTESequence) { EXPECT_EQ(roots, expected_result); } +TEST_F(GpuFusibleTest, GetFusionRootsWithTupleMultipleSameOperands) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test_module + + fusion { + p1 = s32[32] parameter(0) + add0 = s32[32] add(p1, p1) + ROOT _ = (s32[32], s32[32]) tuple(add0, add0) + } + + ENTRY entry { + p0 = s32[32] parameter(0) + ROOT fusion = (s32[32], s32[32]) fusion(p0), kind=kCustom, calls=fusion + } + )") + .value(); + + auto called_computations = + module->entry_computation()->root_instruction()->called_computations(); + ASSERT_EQ(called_computations.size(), 1); + + auto fusion = called_computations.front(); + auto roots = GetFusionRoots(*fusion); + auto add0 = fusion->root_instruction()->operand(0); + EXPECT_THAT(GetFusionRoots(*fusion), ElementsAre(add0, add0)); +} + TEST_F(GpuFusibleTest, GetFusibleComputations) { auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( fused_reduce { @@ -1731,5 +1762,23 @@ TEST_F(GpuFusibleTest, GetFusibleComputations) { module->entry_computation())); } +TEST_F(GpuFusibleTest, GetSharedMemoryUsage) { + auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"( + wrapped_transpose { + p0 = f32[128,1024,2]{2,1,0} parameter(0) + ROOT transpose = f32[1024,128,2]{2,1,0} transpose(p0), dimensions={1,0,2} + } + ENTRY main { + p = f32[128,1024,2] parameter(0) + ROOT res = f32[1024,128,2]{2,1,0} fusion(p), kind=kInput, calls=wrapped_transpose + })")) + .value(); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + FusionInfoCache cache; + auto fusion = module->entry_computation()->root_instruction(); + EXPECT_EQ(cache.GetSharedMemoryUsage(*fusion), 32 * 33 * 2 * 4); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index 2504b431741f82..588abc6297fadd 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -46,9 +46,10 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_latency_hiding_scheduler.h" -#include "xla/service/gpu/gpu_schedule_postprocessing.h" #include "xla/service/gpu/model/analytical_latency_estimator.h" -#include "xla/service/gpu/scheduling_instruction_annotator.h" +#include "xla/service/gpu/transforms/pgle_accuracy_checker.h" +#include "xla/service/gpu/transforms/schedule_postprocessing.h" +#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/latency_hiding_scheduler.h" @@ -63,6 +64,7 @@ limitations under the License. #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" +#include "tsl/profiler/lib/traceme.h" namespace xla { namespace gpu { @@ -74,6 +76,9 @@ bool ShouldScheduleAsEarlyAsPossible(const HloInstruction& instr) { case HloOpcode::kAllReduceStart: case HloOpcode::kCollectivePermuteStart: return !IsSyncCollective(&instr); + case HloOpcode::kAsyncStart: + // Start async ops as early as possible to allow more concurrency. + return true; case HloOpcode::kCustomCall: return static_cast(instr) .custom_call_schedule() == @@ -95,6 +100,10 @@ bool ShouldScheduleAsLateAsPossible(const HloInstruction& instr) { case HloOpcode::kAllReduceDone: case HloOpcode::kCollectivePermuteDone: return ShouldScheduleAsEarlyAsPossible(*instr.operand(0)); + case HloOpcode::kAsyncDone: + // Schedule as many other ops as possible before blocking on the + // completion of async ops. + return true; case HloOpcode::kCustomCall: return static_cast(instr) .custom_call_schedule() == CustomCallSchedule::SCHEDULE_LATEST; @@ -423,6 +432,7 @@ static int64_t GetSchedulerMemoryLimit( absl::StatusOr ScheduleGpuModule( HloModule* module, int64_t pointer_size, const se::DeviceDescription& gpu_device_info) { + tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); int64_t memory_limit = GetSchedulerMemoryLimit(module, gpu_device_info, pointer_size); if (module->has_schedule()) { @@ -469,6 +479,7 @@ absl::StatusOr ScheduleGpuModule( module->config() .debug_options() .xla_gpu_enable_analytical_latency_estimator(); + HloPassPipeline pipeline("latency-hiding-scheduler"); if (profile.has_value()) { auto aggregator = std::make_unique(); auto pg_latency_estimator = std::make_unique( @@ -477,7 +488,7 @@ absl::StatusOr ScheduleGpuModule( LOG(INFO) << "Found profile, using profile guided latency estimator. Profile:\n" << profile->DebugString(); - TF_RETURN_IF_ERROR(pg_latency_estimator->CheckAccuracy(*module)); + pipeline.AddPass(*pg_latency_estimator); latency_estimator = std::move(pg_latency_estimator); } else if (enable_analytical_latency_estimator) { latency_estimator = std::make_unique( @@ -502,7 +513,6 @@ absl::StatusOr ScheduleGpuModule( auto shape_size_in_bytes = [pointer_size](const Shape& shape) { return GetSizeOfShape(shape, pointer_size); }; - HloPassPipeline pipeline("latency-hiding-scheduler"); auto scheduler_core = std::make_unique( shape_size_in_bytes, async_tracker.get(), latency_estimator.get(), config); @@ -513,8 +523,8 @@ absl::StatusOr ScheduleGpuModule( TF_RETURN_IF_ERROR(pipeline.Run(module).status()); - HloPassPipeline postprocessing_pipeline("gpu-schedule-postprocessing"); - postprocessing_pipeline.AddPass(); + HloPassPipeline postprocessing_pipeline("schedule-postprocessing"); + postprocessing_pipeline.AddPass(); TF_RETURN_IF_ERROR(postprocessing_pipeline.Run(module).status()); return ScheduleMetadata{memory_limit}; diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index 0304f358d4b132..60b80d656aa7fc 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -46,6 +46,8 @@ limitations under the License. #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" +#include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -54,7 +56,7 @@ limitations under the License. namespace xla { namespace gpu { -using ::testing::HasSubstr; +using ::testing::ElementsAre; using ::tsl::testing::StatusIs; class GpuHloScheduleTest : public HloTestBase { @@ -492,6 +494,112 @@ TEST_F(GpuHloScheduleTest, ProfileGuidedCostModel) { } } +TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelFailsWithIncompleteProfile) { + const absl::string_view kHloString = R"( + HloModule m + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32,32] parameter(1) + p2 = f32[32,32] parameter(2) + p3 = f32[32] parameter(3) + + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT t = (f32[32],f32[32],f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Profile string, cost does not matter. + const absl::string_view kProfile = R"pb( + costs { name: "dot0" cost_us: 100.0 } + costs { name: "add0" cost_us: 10.0 } + costs { name: "ar-start" cost_us: 1000.0 } + )pb"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + kHloString, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true, + /*fdo_profile=*/kProfile))); + + // `dot1` and `ar-start1` are missing from the profile. + EXPECT_THAT(ScheduleGpuModule( + module.get(), /*pointer_size=*/8, + backend().default_stream_executor()->GetDeviceDescription()) + .status(), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST_F( + GpuHloScheduleTest, + ProfileGuidedCostModelDoesNotFailWithIncompleteProfileIfAccuracyCheckerIsDisabled) { // NOLINT(whitespace/line_length) + const absl::string_view kHloString = R"( + HloModule m + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32,32] parameter(1) + p2 = f32[32,32] parameter(2) + p3 = f32[32] parameter(3) + + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT t = (f32[32],f32[32],f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Profile string, cost does not matter. + const absl::string_view kProfile = R"pb( + costs { name: "dot0" cost_us: 100.0 } + costs { name: "add0" cost_us: 10.0 } + costs { name: "ar-start" cost_us: 1000.0 } + )pb"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule( + kHloString, GetModuleConfig(/*enable_latency_hiding_scheduler=*/true, + /*enable_gpu_async_tracker=*/true, + /*fdo_profile=*/kProfile))); + + // `dot1` and `ar-start1` are missing from the profile but we disable the + // pass. + module->mutable_config().mutable_debug_options().add_xla_disable_hlo_passes( + "pgle-accuracy-checker"); + TF_EXPECT_OK(ScheduleGpuModule( + module.get(), /*pointer_size=*/8, + backend().default_stream_executor()->GetDeviceDescription()) + .status()); +} + TEST_F(GpuHloScheduleTest, ProfileGuidedCostModelWithRematData) { const char* hlo_text = R"( HloModule AsyncAR @@ -1480,5 +1588,48 @@ TEST_F(GpuHloSchedulePostProcessTest, PostProcessAsyncCollectives) { } } +TEST_F(GpuHloScheduleTest, AsyncOps) { + const char* hlo_text = R"( + HloModule m + + op1 { + p0 = f32[2,2] parameter(0) + ROOT add = f32[2,2] add(p0, p0) + } + + op2 { + p0 = f32[2,2] parameter(0) + ROOT add = f32[2,2] add(p0, p0) + } + + ENTRY main { + p0 = f32[2,2] parameter(0) + // The `async-start` blocks should be moved up, and the `async-done` blocks + // should be moved down. + acc1_start = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0), + kind=kLoop, calls=op1 + acc1_done = f32[2,2] fusion-done(acc1_start) + acc2_start = ((f32[2,2]), f32[2,2], s32[]) fusion-start(p0), + kind=kLoop, calls=op2 + acc2_done = f32[2,2] fusion-done(acc2_start) + ROOT done = f32[2,2] add(acc1_done, acc2_done) + })"; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_text, HloModuleConfig{})); + SequentialHloOrdering order = BuildHloOrdering(module.get()); + + std::vector opcodes; + for (HloInstruction* instruction : + order.SequentialOrder(*module->entry_computation())->instructions()) { + opcodes.push_back(instruction->opcode()); + } + EXPECT_THAT(opcodes, + ElementsAre(HloOpcode::kParameter, HloOpcode::kAsyncStart, + HloOpcode::kAsyncStart, HloOpcode::kAsyncDone, + HloOpcode::kAsyncDone, HloOpcode::kAdd)); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc index 590adffffbd077..24622c8d685265 100644 --- a/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_latency_hiding_scheduler_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/service/hlo_module_config.h" #include "xla/service/profile_guided_latency_estimator.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc index 3099b957575e6f..215609c7e288cb 100644 --- a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc @@ -31,15 +31,15 @@ limitations under the License. #include "xla/layout.h" #include "xla/service/buffer_value.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/stream_attribute_annotator.h" +#include "xla/service/gpu/transforms/stream_attribute_annotator.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_memory_scheduler.h" #include "xla/service/hlo_rematerialization.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc index 1dcab0c47b98c9..de3adb3f4d4885 100644 --- a/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_p2p_pipeliner_test.cc @@ -146,7 +146,7 @@ TEST_F(GpuP2PPipelinerTest, EXPECT_EQ(send1->channel_id(), send2->channel_id()); const char* kPeeledAttr = "_xla_send_recv_validation=\"invalid\""; - const char* kRotatedAttr = "_xla_send_recv_validation=\"{{0,6}}\""; + const char* kRotatedAttr = "_xla_send_recv_validation={{0,6}}"; EXPECT_THAT(send1->ToString(), ::testing::HasSubstr(kPeeledAttr)); EXPECT_THAT(recv1->ToString(), ::testing::HasSubstr(kPeeledAttr)); EXPECT_THAT(send2->ToString(), ::testing::HasSubstr(kRotatedAttr)); diff --git a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc index 4f7635813e28be..06e1e6fa1594b0 100644 --- a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc +++ b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/service/algebraic_simplifier.h" #include "xla/service/conditional_simplifier.h" #include "xla/service/gather_expander.h" -#include "xla/service/gpu/gpu_algebraic_simplifier.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_module_config.h" @@ -89,7 +89,7 @@ void AddSPMDPasses( const HloModuleConfig& config = hlo_module->config(); - if (config.debug_options().xla_use_shardy()) { + if (config.use_shardy_partitioner()) { spmd_pipeline.AddPass(); } else { spmd_pipeline.AddPass(); diff --git a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline_test.cc b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline_test.cc index 765b73b9590de6..42a9e7dcad49f9 100644 --- a/third_party/xla/xla/service/gpu/gpu_spmd_pipeline_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_spmd_pipeline_test.cc @@ -48,6 +48,7 @@ class GpuSpmdPartitioningTest : public HloTestBase, HloModuleConfig config = GetModuleConfigForTest( /*replica_count=*/1, /*num_partitions=*/num_devices); config.set_num_partitions(num_devices); + config.set_use_shardy_partitioner(UseShardy()); TF_ASSIGN_OR_RETURN(auto module, ParseAndReturnVerifiedModule(hlo_module, config)); @@ -67,7 +68,6 @@ class GpuSpmdPartitioningTest : public HloTestBase, DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_use_shardy(UseShardy()); return debug_options; } }; diff --git a/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc b/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc index 8b1efb2199501e..dcca8f318f4c85 100644 --- a/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc +++ b/third_party/xla/xla/service/gpu/hlo_algorithm_denylist.cc @@ -25,8 +25,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "xla/debug_options_flags.h" #include "xla/hlo/ir/backend_config.h" +#include "xla/service/gpu/autotuning/gpu_autotuning.pb.h" #include "xla/service/gpu/backend_configs.pb.h" -#include "xla/service/gpu/gpu_autotuning.pb.h" #include "xla/stream_executor/dnn.h" #include "tsl/platform/env.h" #include "tsl/platform/protobuf.h" diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index 345fd8c709ee49..e527a097497b7e 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -171,15 +171,40 @@ HloFusionAnalysis HloFusionAnalysis::Create( // static HloFusionAnalysis HloFusionAnalysis::Create( - const HloFusionInstruction* fusion, - const se::DeviceDescription* device_info) { - CHECK(device_info != nullptr); - FusionBackendConfig backend_config = - fusion->has_backend_config() - ? fusion->backend_config()->fusion_backend_config() - : FusionBackendConfig::default_instance(); - return Create(std::move(backend_config), - HloFusionAdaptor::ForInstruction(fusion), device_info); + const HloInstruction& instruction, + const se::DeviceDescription& device_info) { + absl::StatusOr gpu_backend_config = + instruction.backend_config(); + + FusionBackendConfig fusion_backend_config = + gpu_backend_config.ok() ? gpu_backend_config->fusion_backend_config() + : FusionBackendConfig::default_instance(); + return Create(std::move(fusion_backend_config), + HloFusionAdaptor::ForInstruction(&instruction), &device_info); +} + +// static +HloFusionAnalysis HloFusionAnalysis::Create( + const HloInstruction& producer, const HloInstruction& consumer, + const se::DeviceDescription& device_info) { + absl::StatusOr gpu_backend_config; + + if (consumer.has_backend_config()) { + gpu_backend_config = consumer.backend_config(); + } + + if (!gpu_backend_config.ok() && producer.has_backend_config()) { + gpu_backend_config = producer.backend_config(); + } + + FusionBackendConfig fusion_backend_config = + gpu_backend_config.ok() ? gpu_backend_config->fusion_backend_config() + : FusionBackendConfig::default_instance(); + + return HloFusionAnalysis::Create( + std::move(fusion_backend_config), + HloFusionAdaptor::ForProducerConsumer(&producer, &consumer), + &device_info); } // Returns true if the fusion has consistent transpose heros. @@ -264,7 +289,7 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() } // We expect that the last dimension is swapped with a different dimension. - if (HasConsistentTransposeHeros() && tiled_transpose_->permutation[2] != 2) { + if (HasConsistentTransposeHeros()) { return EmitterFusionKind::kTranspose; } @@ -305,24 +330,5 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { LOG(FATAL) << "Did not find a hero reduction"; } -HloFusionAnalysis AnalyzeProducerConsumerFusion( - const HloInstruction& producer, const HloInstruction& consumer, - const se::DeviceDescription& device_info) { - return HloFusionAnalysis::Create( - consumer.has_backend_config() - ? consumer.backend_config()->fusion_backend_config() - : producer.backend_config() - ->fusion_backend_config(), - HloFusionAdaptor::ForProducerConsumer(&producer, &consumer), - &device_info); -} - -HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer, - const se::DeviceDescription& device_info) { - return HloFusionAnalysis::Create( - consumer.backend_config()->fusion_backend_config(), - HloFusionAdaptor::ForInstruction(&consumer), &device_info); -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h index 146224b394579f..c1b7e5b986c1b0 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.h @@ -58,8 +58,17 @@ class HloFusionAnalysis { static HloFusionAnalysis Create(FusionBackendConfig backend_config, std::unique_ptr fusion, const se::DeviceDescription* device_info); - static HloFusionAnalysis Create(const HloFusionInstruction* fusion, - const se::DeviceDescription* device_info); + + // Creates a HloFusionAnalysis that analyzes just instruction as a standalone + // fusion. + static HloFusionAnalysis Create(const HloInstruction& instruction, + const se::DeviceDescription& device_info); + + // Creates a HloFusionAnalysis that analyzes a hypothetical fusion of producer + // into consumer. + static HloFusionAnalysis Create(const HloInstruction& producer, + const HloInstruction& consumer, + const se::DeviceDescription& device_info); const HloFusionAdaptor& fusion() const { return *fusion_; } @@ -131,17 +140,6 @@ class HloFusionAnalysis { InputOutputInfo input_output_info_; }; -// Creates a HloFusionAnalysis that analyzes a hypothetical fusion of producer -// into consumer. -HloFusionAnalysis AnalyzeProducerConsumerFusion( - const HloInstruction& producer, const HloInstruction& consumer, - const se::DeviceDescription& device_info); - -// Creates a HloFusionAnalysis that analyzes just consumer as a standalone -// fusion. -HloFusionAnalysis AnalyzeFusion(const HloInstruction& consumer, - const se::DeviceDescription& device_info); - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc index 04c58194c75846..7328bc6dad0ec9 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc @@ -15,9 +15,11 @@ limitations under the License. #include "xla/service/gpu/hlo_fusion_analysis.h" #include +#include "xla/protobuf_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" @@ -48,12 +50,12 @@ TEST_F(HloFusionAnalysisTest, DoesNotPeekOutsideBoundary) { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info); + auto analysis = HloFusionAnalysis::Create(*root, device_info); EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kLoop); auto analysis_fused = - AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + HloFusionAnalysis::Create(*root->operand(0), *root, device_info); EXPECT_EQ(analysis_fused.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -155,7 +157,7 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFused) { auto* root = module->entry_computation()->root_instruction(); auto analysis = - AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + HloFusionAnalysis::Create(*root->operand(0), *root, device_info); EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -186,7 +188,7 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInConsumer) { auto* root = module->entry_computation()->root_instruction(); auto analysis = - AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + HloFusionAnalysis::Create(*root->operand(0), *root, device_info); EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -223,7 +225,7 @@ TEST_F(HloFusionAnalysisTest, ReductionEpilogueFusionPartiallyFusedInBoth) { auto* root = module->entry_computation()->root_instruction(); auto analysis = - AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + HloFusionAnalysis::Create(*root->operand(0), *root, device_info); EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -255,7 +257,7 @@ TEST_F(HloFusionAnalysisTest, ReduceMultiOutputFusionWithTransposeBitcast) { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info); + auto analysis = HloFusionAnalysis::Create(*root, device_info); EXPECT_EQ(analysis.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -287,7 +289,7 @@ TEST_F(HloFusionAnalysisTest, InvalidReduceMultiOutputFusion) { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info); + auto analysis = HloFusionAnalysis::Create(*root, device_info); // We expect to fallback to the loop emitter, because the two reductions are // not compatible as they reduce over different dimensions. EXPECT_EQ(analysis.GetEmitterFusionKind(), @@ -319,7 +321,7 @@ TEST_F(HloFusionAnalysisTest, InvalidDevice) { auto* root = module->entry_computation()->root_instruction(); auto analysis_fused = - AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info); + HloFusionAnalysis::Create(*root->operand(0), *root, device_info); EXPECT_EQ(analysis_fused.GetEmitterFusionKind(), HloFusionAnalysis::EmitterFusionKind::kReduction); } @@ -352,5 +354,90 @@ TEST_F(HloFusionAnalysisTest, ConcatFusion) { HloFusionAnalysis::EmitterFusionKind::kConcatenate); } +TEST_F(HloFusionAnalysisTest, ExtractValidGpuBackendConfig) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation.1 { + %x = s32[64] parameter(0) + %y = s32[64] parameter(1) + ROOT %root = s32[64] add(%x, %y) + } + + fused_computation.2 { + %x = s32[64] parameter(0) + %y = s32[64] parameter(1) + ROOT %root = s32[64] add(%x, %y) + } + + ENTRY entry { + %x = s32[64] parameter(0) + %y = s32[64] parameter(1) + %fusion.1 = s32[64] fusion(%x, %y), kind=kLoop, calls=fused_computation.1, backend_config={"fusion_backend_config": {kind: "__triton"}} + ROOT %fusion.2 = s32[64] fusion(%fusion.1, %y), kind=kLoop, calls=fused_computation.2 + })")); + + auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + auto* consumer = module->entry_computation()->root_instruction(); + auto* producer = consumer->operand(0); + + auto producer_analysis = HloFusionAnalysis::Create(*producer, device_info); + EXPECT_EQ(producer_analysis.fusion_backend_config().kind(), + kTritonFusionKind); + + auto producer_consumer_analysis = + HloFusionAnalysis::Create(*producer, *consumer, device_info); + EXPECT_EQ(producer_consumer_analysis.fusion_backend_config().kind(), + kTritonFusionKind); +} + +TEST_F(HloFusionAnalysisTest, + InvalidGpuBackendConfig_SingleInstruction_Ignored) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + ENTRY entry { + %x = s32[64,64,64] parameter(0) + %y = s32[64,64,64] parameter(1) + ROOT %root = s32[64,128,64] concatenate(x, y), dimensions={1}, backend_config={"outer_dimension_partitions": ["1"]} + })")); + + auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + auto* root = module->entry_computation()->root_instruction(); + auto analysis = HloFusionAnalysis::Create(*root, device_info); + + EXPECT_TRUE( + protobuf_util::ProtobufEquals(analysis.fusion_backend_config(), + FusionBackendConfig::default_instance())); +} + +TEST_F(HloFusionAnalysisTest, + InvalidGpuBackendConfig_ProducerConsumer_Ignored) { + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( + HloModule module + + fused_computation { + %x = s32[64] parameter(0) + %y = s32[64] parameter(1) + ROOT %root = s32[64] add(%x, %y) + } + + ENTRY entry { + %x = s32[64] parameter(0) + %y = s32[64] parameter(1) + %fusion = s32[64] fusion(%x, %y), kind=kLoop, calls=fused_computation, backend_config={"invalid_field": "some_value"} + ROOT %root = s32[128] concatenate(fusion, y), dimensions={0}, backend_config={"invalid_field": "some_value"} + })")); + + auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + auto* consumer = module->entry_computation()->root_instruction(); + auto* producer = consumer->operand(0); + auto analysis = HloFusionAnalysis::Create(*producer, *consumer, device_info); + + EXPECT_TRUE( + protobuf_util::ProtobufEquals(analysis.fusion_backend_config(), + FusionBackendConfig::default_instance())); +} + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_stats_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_stats_test.cc index 0a19b213922b42..c2d33da7fcf408 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_stats_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_stats_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/strings/match.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.cc b/third_party/xla/xla/service/gpu/hlo_traversal.cc index 4394226dfadc0b..c2318ab10f1584 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.cc +++ b/third_party/xla/xla/service/gpu/hlo_traversal.cc @@ -500,13 +500,17 @@ bool operator==(const HloInstructionAdaptor& lhs, lhs.instruction_->unique_id() == rhs.instruction_->unique_id(); } +bool operator!=(const HloInstructionAdaptor& lhs, + const HloInstructionAdaptor& rhs) { + return !(lhs == rhs); +} + namespace { void HloBfsTraversal( absl::Span roots, const HloFusionAdaptor& fusion, const std::function& visit_node, - const std::function& visit_arg, bool visit_operands) { absl::flat_hash_set visited; std::queue q; @@ -514,12 +518,8 @@ void HloBfsTraversal( const auto& adjacent_nodes = visit_operands ? node.GetOperands() : node.GetUsers(); for (const auto& node : adjacent_nodes) { - if (visited.insert(node).second) { - if (fusion.ContainsInstruction(node)) { - q.push(node); - } else { - visit_arg(node); - } + if (fusion.ContainsInstruction(node) && visited.insert(node).second) { + q.push(node); } } }; @@ -548,9 +548,8 @@ void HloBfsConsumersFirstTraversal( absl::Span roots, const HloFusionAdaptor& fusion, const std::function& - visit_node, - const std::function& visit_arg) { - HloBfsTraversal(roots, fusion, visit_node, visit_arg, + visit_node) { + HloBfsTraversal(roots, fusion, visit_node, /*visit_operands=*/true); } @@ -559,9 +558,8 @@ void HloBfsProducersFirstTraversal( const HloFusionAdaptor& fusion, const std::function& visit_node) { - HloBfsTraversal( - producers, fusion, visit_node, [](HloInstructionAdaptor) {}, - /*visit_operands=*/false); + HloBfsTraversal(producers, fusion, visit_node, + /*visit_operands=*/false); } bool HloBfsAnyOf(absl::Span roots, @@ -592,7 +590,7 @@ std::optional HloBfsFindIf( } return TraversalResult::kAdvance; }, - [](HloInstructionAdaptor) {}, visit_operands); + visit_operands); return result; } diff --git a/third_party/xla/xla/service/gpu/hlo_traversal.h b/third_party/xla/xla/service/gpu/hlo_traversal.h index b49d4efc9377ce..b4a5859875ba24 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal.h +++ b/third_party/xla/xla/service/gpu/hlo_traversal.h @@ -53,6 +53,8 @@ class HloInstructionAdaptor { friend bool operator==(const HloInstructionAdaptor& lhs, const HloInstructionAdaptor& rhs); + friend bool operator!=(const HloInstructionAdaptor& lhs, + const HloInstructionAdaptor& rhs); template friend H AbslHashValue(H h, const HloInstructionAdaptor& m); @@ -147,9 +149,7 @@ void HloBfsConsumersFirstTraversal( absl::Span roots, const HloFusionAdaptor& fusion, const std::function& - visit_node, - const std::function& visit_arg = - [](HloInstructionAdaptor) {}); + visit_node); // Visit the HLO nodes starting from `producers` in BFS order following the // `user` edges. Each node will be visited exactly once. diff --git a/third_party/xla/xla/service/gpu/hlo_traversal_test.cc b/third_party/xla/xla/service/gpu/hlo_traversal_test.cc index ee3a4b7ad1239f..fcab8f47a3100f 100644 --- a/third_party/xla/xla/service/gpu/hlo_traversal_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_traversal_test.cc @@ -150,43 +150,31 @@ TEST_F(HloTraversalTest, AdaptorUsers) { TEST_F(HloTraversalTest, TraverseFusionConsumerFirst) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); std::vector visited_nodes; - std::vector visited_args; auto fusion = HloFusionAdaptor::ForInstruction( module->entry_computation()->GetInstructionWithName("fusion")); - HloBfsConsumersFirstTraversal( - fusion->GetRoots(), *fusion, - [&](HloInstructionAdaptor node) { - visited_nodes.emplace_back(node.name()); - return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor arg) { - visited_args.emplace_back(arg.name()); - }); + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { + visited_nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul")); - EXPECT_THAT(visited_args, ElementsAre("p0", "negate")); } TEST_F(HloTraversalTest, TraverseFusionConsumerFirstFromFusionRootAndInnerNode) { auto module = ParseAndReturnVerifiedModule(kTestModule).value(); std::vector visited_nodes; - std::vector visited_args; auto fusion = HloFusionAdaptor::ForInstruction( module->entry_computation()->GetInstructionWithName("fusion")); auto root = fusion->GetRoots()[0]; - HloBfsConsumersFirstTraversal( - {root, root.GetOperand(0)}, *fusion, - [&](HloInstructionAdaptor node) { - visited_nodes.emplace_back(node.name()); - return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor arg) { - visited_args.emplace_back(arg.name()); - }); + HloBfsConsumersFirstTraversal({root, root.GetOperand(0)}, *fusion, + [&](HloInstructionAdaptor node) { + visited_nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); EXPECT_THAT(visited_nodes, ElementsAre("reduce.1", "mul")); - EXPECT_THAT(visited_args, ElementsAre("p0", "negate")); } TEST_F(HloTraversalTest, TraverseFusionProducerFirst) { @@ -379,17 +367,13 @@ TEST_F(HloTraversalTest, FuseFusionConsumer) { EXPECT_TRUE(reduce_1.GetUsers().empty()); std::vector nodes; - std::vector params; - HloBfsConsumersFirstTraversal( - fusion->GetRoots(), *fusion, - [&](HloInstructionAdaptor node) { - nodes.emplace_back(node.name()); - return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor param) { params.emplace_back(param.name()); }); + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { + nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); EXPECT_THAT(nodes, ElementsAre("reduce.1", "mul", "negate")); - EXPECT_THAT(params, ElementsAre("p0", "sum")); } TEST_F(HloTraversalTest, FuseFusionProducer) { @@ -411,17 +395,13 @@ TEST_F(HloTraversalTest, FuseFusionProducer) { InstructionAdaptorName("fusion.1"))); std::vector nodes; - std::vector params; - HloBfsConsumersFirstTraversal( - fusion->GetRoots(), *fusion, - [&](HloInstructionAdaptor node) { - nodes.emplace_back(node.name()); - return TraversalResult::kAdvance; - }, - [&](HloInstructionAdaptor arg) { params.emplace_back(arg.name()); }); + HloBfsConsumersFirstTraversal(fusion->GetRoots(), *fusion, + [&](HloInstructionAdaptor node) { + nodes.emplace_back(node.name()); + return TraversalResult::kAdvance; + }); EXPECT_THAT(nodes, ElementsAre("difference", "reduce.2")); - EXPECT_THAT(params, ElementsAre("p0", "negate", "fusion.1")); } TEST_F(HloTraversalTest, FuseFusionConsumerAndProducer) { diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 81d05f4d1347fa..d7e7129b2f26ad 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" +#include #include #include #include @@ -28,7 +29,6 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/status/status.h" #include "absl/strings/escaping.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -42,6 +42,7 @@ limitations under the License. #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TargetParser/Triple.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -54,18 +55,16 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/target_util.h" -#include "xla/service/llvm_ir/buffer_assignment_util.h" #include "xla/service/llvm_ir/llvm_type_conversion_util.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" -#include "xla/translate/mhlo_to_hlo/location_exporter.h" -#include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/stream_executor/device_description.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/strings/proto_serialization.h" -#include "tsl/platform/errors.h" +#include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" namespace xla { @@ -83,6 +82,13 @@ bool IsRank1(const Shape& shape, int64_t batch_dimensions_size) { return shape.rank() == batch_dimensions_size + 1; } +bool IsMlirTransposeEmitterEnabled(const HloInstruction& hlo) { + return hlo.GetModule() + ->config() + .debug_options() + .xla_gpu_mlir_emitter_level() >= 3; +} + } // namespace bool IsMatrixMultiplication(const HloInstruction& dot) { @@ -369,16 +375,16 @@ absl::StatusOr GetAllocationSlice( return buffer_assignment.GetUniqueSlice(instr, index); } -std::vector GetOutputDefiningDynamicUpdateSlices( +std::vector GetOutputDefiningDynamicUpdateSlices( absl::Span roots) { - std::vector dus_ops; + std::vector dus_ops; for (HloInstructionAdaptor root : roots) { while (root.opcode() == HloOpcode::kBitcast) { root = root.GetOperand(0); } if (root.opcode() == HloOpcode::kDynamicUpdateSlice) { - dus_ops.push_back(&root.instruction()); + dus_ops.push_back(root); } } return dus_ops; @@ -396,109 +402,86 @@ absl::InlinedVector GetStartIndices(T instr) { } absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - const HloFusionInstruction* fusion, + const HloFusionAdaptor& fusion_adaptor, std::function( const HloInstruction* instr, const ShapeIndex& index)> get_allocation_slice, - absl::Span roots) { - std::vector dus_instrs = - GetOutputDefiningDynamicUpdateSlices(roots); - - // Get output buffers for fusion. - std::vector output_buffers; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - fusion->shape(), [&](const Shape& shape, const ShapeIndex index) { - if (shape.IsArray()) { - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice buffer, - get_allocation_slice(fusion, index)); - output_buffers.push_back(buffer); - } - return absl::OkStatus(); - })); + const HloInstruction* fusion) { + std::vector dus_instrs = + GetOutputDefiningDynamicUpdateSlices(fusion_adaptor.GetRoots()); // This check could probably be relaxed: if code generation is made to use a // separate parallel loop for each dynamic slice update, then it shouldn't be // necessary for every output to be a dynamic slice update, nor to have the // same shape. - if (dus_instrs.size() != output_buffers.size()) { + if (dus_instrs.size() != fusion_adaptor.GetRoots().size()) { return false; } - if (output_buffers.empty()) { - return Internal("Output buffers should not be empty"); - } - - Shape update_shape = dus_instrs[0]->operand(1)->shape(); + Shape update_shape = dus_instrs[0].GetOperand(1).shape(); for (int i = 0; i < dus_instrs.size(); ++i) { - auto* dus = Cast(dus_instrs[i]); + const auto& dus = dus_instrs[i]; - // Dynamic slice updates should have a single path to the root to avoid + // DynamicUpdateSlice ops should have a single path to the root to avoid // allowing a dynamic slice update to depend on another, as this would not // be guaranteed to work with the current codegen. - if (!dus->IsRoot() && dus->user_count() != 1) return false; - - // We follow DUS users until we find a root instruction. We support only - // few patterns: + // We follow DUS users until we find an instruction without users. We + // support only few patterns: // // (1) ROOT dynamic-update-slice // (2) ROOT tuple(dynamic-update-slice) // (3) ROOT bitcast(dynamic-update-slice) // (4) ROOT tuple(bitcast(dynamic-update-slice)) - HloInstruction* dus_user = dus->IsRoot() ? nullptr : dus->users().front(); - - // Since the direct consumer of an output dynamic slice update may be a - // bitcast, we also check that this bitcast is used a single time. - // This property is also important because reads and writes on the parameter - // to be updated are done using the shape and layout of the dynamic slice - // update. This is a valid approach only if a subsequent bitcast is not read - // by any other op within the fusion as this may result in codegen - // accessing elements using the wrong physical layout. - if (dus_user && dus_user->opcode() == HloOpcode::kBitcast) { - if (!dus_user->IsRoot() && dus_user->user_count() != 1) return false; - - // Stop following DUS users if we found a root. - dus_user = dus_user->IsRoot() ? nullptr : dus_user->users().front(); - } - - // Check that last DUS user is a tuple operation at ROOT position. - if (dus_user && dus_user->opcode() == HloOpcode::kTuple) { - if (!dus_user->IsRoot()) return false; - - // Stop following DUS users if we found a root. - dus_user = nullptr; + // + // In case there is a root tuple, the search will stop at the tuple operand, + // as the root tuple is not considered a real user by HloInstructionAdaptor. + // Note that due to AlgebraicSimplifier we will never have a chain of + // bitcasts. + HloInstructionAdaptor real_root = dus; + auto users = real_root.GetUsers(); + while (!users.empty()) { + if (users.size() > 1) { + return false; + } + real_root = users.front(); + if (real_root.opcode() != HloOpcode::kBitcast) { + return false; + } + users = real_root.GetUsers(); } - // We can't emit DUS fusion if we have unsupported DUS users. - if (dus_user != nullptr) return false; - // Find "real" DUS operand by skipping bitcasted operands. - const HloInstruction* operand = dus->operand(0); - if (operand->opcode() == HloOpcode::kBitcast) { - operand = operand->operand(0); + HloInstructionAdaptor operand = dus.GetOperand(0); + if (fusion_adaptor.ContainsInstruction(operand) && + operand.opcode() == HloOpcode::kBitcast) { + operand = operand.GetOperand(0); } // Operand to a DUS (or Bitcast) must be a fusion parameter. - auto* parameter = DynCast(operand); - if (!parameter) return false; + // HloInstructionAdaptor skips parameters, so we need to check whether + // 'operand' is outside of the fusion. + if (fusion_adaptor.ContainsInstruction(operand)) { + return false; + } // We require that the parameter being updated is only read at the same // index positions by all users, since we otherwise risk a race condition // when updating the parameter inplace. - std::queue q; + std::queue q; absl::flat_hash_set visited; - q.push(parameter); - visited.insert(parameter); + q.push(operand); + visited.insert(&operand.instruction()); // We have already checked above that the DUS only has one user. So we don't // need to visit it during the breadth-first search. - visited.insert(dus); + visited.insert(&dus.instruction()); while (!q.empty()) { - const HloInstruction* instr = q.front(); + HloInstructionAdaptor instr = q.front(); q.pop(); - for (const HloInstruction* user : instr->users()) { - if (user->opcode() == HloOpcode::kDynamicSlice && - dus->operand(0) == user->operand(0) && - update_shape == user->shape()) { + for (const HloInstructionAdaptor& user : instr.GetUsers()) { + if (user.opcode() == HloOpcode::kDynamicSlice && + dus.GetOperand(0) == user.GetOperand(0) && + update_shape == user.shape()) { // We can still emit in-place in this case if the same slice is // accessed by the DUS and the DS. If they don't access the same // slice, the two slices might partially overlap and read/write the @@ -506,19 +489,21 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( // read before it is overwritten. However if both access only a single // element, there also can be no race condition. absl::InlinedVector user_start_indices = - GetStartIndices(Cast(user)); + GetStartIndices( + Cast(&user.instruction())); absl::InlinedVector dus_start_indices = - GetStartIndices(dus); + GetStartIndices( + Cast(&dus.instruction())); if (ShapeUtil::ElementsIn(update_shape) != 1 && user_start_indices != dus_start_indices) { return false; } - } else if (user != dus && !user->IsElementwise() && - user->opcode() != HloOpcode::kBitcast && - user->opcode() != HloOpcode::kTuple) { + } else if (user != dus && !user.instruction().IsElementwise() && + user.opcode() != HloOpcode::kBitcast && + user.opcode() != HloOpcode::kTuple) { return false; } - if (visited.insert(user).second) { + if (visited.insert(&user.instruction()).second) { q.push(user); } } @@ -529,16 +514,26 @@ absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( // be necessary for the shape to be the same for all the dynamic slice // updates. Note that this equality check purposefully ignores the element // type. - if (dus->update()->shape() != update_shape) { + if (Cast(&dus.instruction()) + ->update() + ->shape() != update_shape) { return false; } - const HloInstruction* lhs = fusion->operand(parameter->parameter_number()); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, - get_allocation_slice(lhs, {})); - BufferAllocation::Slice rhs_buffer = output_buffers[i]; - if (lhs_buffer != rhs_buffer) { - return false; + if (fusion != nullptr) { + ShapeIndex root_index = {}; + if (fusion->IsMultiOutputFusion()) { + root_index = {i}; + } + // Get output buffer for the fusion root. + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer, + get_allocation_slice(fusion, root_index)); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_buffer, + get_allocation_slice(&operand.instruction(), {})); + if (lhs_buffer != output_buffer) { + return false; + } } } @@ -551,61 +546,95 @@ static std::optional FindTiledTranspose( return std::nullopt; } - if (std::optional tr = ShapeUtil::GetNormalizedTransposeShape( - instr.operand(0)->shape(), instr.shape(), Vector3{0, 2, 1})) { + absl::InlinedVector permutation; + auto tr = ShapeUtil::GetNormalizedTransposeShape(instr.operand(0)->shape(), + instr.shape(), permutation); + if (!tr.has_value()) { + return std::nullopt; + } + if (permutation == absl::InlinedVector{0, 2, 1}) { if ((tr->at(1) >= kMinDimensionToTransposeTiled && tr->at(2) >= kMinDimensionToTransposeTiled) || (tr->at(1) >= kMinDimensionToTransposeTiled2 && tr->at(2) >= kMinDimensionToTransposeTiled2 && tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{0, 2, 1}}; + return TransposeDescription{ + &instr, *tr, + /*permutation=*/absl::InlinedVector{0, 2, 1}}; } - } - if (std::optional tr = ShapeUtil::GetNormalizedTransposeShape( - instr.operand(0)->shape(), instr.shape(), Vector3{2, 1, 0})) { + } else if (permutation == absl::InlinedVector{2, 1, 0}) { if ((tr->at(0) >= kMinDimensionToTransposeTiled && tr->at(2) >= kMinDimensionToTransposeTiled) || (tr->at(0) >= kMinDimensionToTransposeTiled2 && tr->at(2) >= kMinDimensionToTransposeTiled2 && tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{2, 1, 0}}; + return TransposeDescription{ + &instr, *tr, + /*permutation=*/absl::InlinedVector{2, 1, 0}}; + } + } else if (IsMlirTransposeEmitterEnabled(instr)) { + if (permutation == absl::InlinedVector{1, 0, 2}) { + auto byte_width = primitive_util::ByteWidth(instr.shape().element_type()); + if (byte_width * tr->at(2) <= kMaxBytesInMostMinorDimension && + byte_width * tr->at(2) * std::min(tr->at(0), tr->at(1)) >= + kMinDimensionToTransposeTiled) { + return TransposeDescription{ + &instr, *tr, + /*permutation=*/absl::InlinedVector{1, 0, 2}}; + } } } return std::nullopt; } -// Find 021 or 210 transpose in logical + physical transposition. +// Find 021, 210 or 102 transpose in logical + physical transposition. static std::optional FindTiledLogicalTranspose( const HloInstruction& instr) { if (instr.opcode() != HloOpcode::kTranspose) { return std::nullopt; } - // TODO(cheshire): avoid code duplication. - if (std::optional tr = ShapeUtil::GetNormalizedLogicalTransposeShape( - instr.operand(0)->shape(), instr.shape(), instr.dimensions(), - Vector3{0, 2, 1})) { - if ((tr->at(1) >= kMinDimensionToTransposeTiled && - tr->at(2) >= kMinDimensionToTransposeTiled) || - (tr->at(1) >= kMinDimensionToTransposeTiled2 && - tr->at(2) >= kMinDimensionToTransposeTiled2 && - tr->at(1) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{0, 2, 1}}; - } + // We can assume that TransposeDimensionGrouper pass has run, so no need to + // call GetNormalizedLogicalTransposeShape here. + absl::InlinedVector permutation(instr.dimensions().begin(), + instr.dimensions().end()); + // A real transpose needs at least 2 transpose dimensions. + if (permutation.size() < 2) { + return std::nullopt; } - if (std::optional tr = ShapeUtil::GetNormalizedLogicalTransposeShape( - instr.operand(0)->shape(), instr.shape(), instr.dimensions(), - Vector3{2, 1, 0})) { - if ((tr->at(0) >= kMinDimensionToTransposeTiled && - tr->at(2) >= kMinDimensionToTransposeTiled) || - (tr->at(0) >= kMinDimensionToTransposeTiled2 && - tr->at(2) >= kMinDimensionToTransposeTiled2 && - tr->at(0) * tr->at(2) >= kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&instr, *tr, - /*permutation=*/Vector3{2, 1, 0}}; + absl::InlinedVector dimensions(instr.shape().dimensions().begin(), + instr.shape().dimensions().end()); + int64_t operand_most_minor_dim = + instr.operand(0)->shape().dimensions().back(); + if (permutation == absl::InlinedVector{0, 2, 1} || + permutation == absl::InlinedVector{2, 1, 0}) { + if ((dimensions.back() >= kMinDimensionToTransposeTiled && + operand_most_minor_dim >= kMinDimensionToTransposeTiled) || + (dimensions.back() >= kMinDimensionToTransposeTiled2 && + operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + dimensions.back() * operand_most_minor_dim >= + kMinTotalDimensionsToTransposeTiled)) { + return TransposeDescription{&instr, dimensions, permutation}; + } + } else if (IsMlirTransposeEmitterEnabled(instr)) { + if (permutation.back() == dimensions.size() - 1) { + operand_most_minor_dim = + instr.operand(0)->shape().dimensions(dimensions.size() - 2); + auto byte_width = primitive_util::ByteWidth(instr.shape().element_type()); + if (byte_width * dimensions.back() <= kMaxBytesInMostMinorDimension && + byte_width * dimensions.back() * + std::min(operand_most_minor_dim, + dimensions[dimensions.size() - 2]) >= + kMinDimensionToTransposeTiled) { + return TransposeDescription{&instr, dimensions, permutation}; + } + } else if ((operand_most_minor_dim >= kMinDimensionToTransposeTiled && + dimensions.back() >= kMinDimensionToTransposeTiled) || + (operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + dimensions.back() >= kMinDimensionToTransposeTiled2 && + operand_most_minor_dim * dimensions.back() >= + kMinTotalDimensionsToTransposeTiled)) { + return TransposeDescription{&instr, dimensions, permutation}; } } return std::nullopt; @@ -613,12 +642,6 @@ static std::optional FindTiledLogicalTranspose( std::optional GetDescriptionForTiledTransposeEmitter( const HloInstruction& root, const HloInstruction& hero) { - // TODO(b/284431534): Figure out how to make the shared memory transpose - // emitter faster for this case. - if (hero.shape().element_type() == F32 && root.shape().element_type() == S8) { - return std::nullopt; - } - if (auto d1 = FindTiledTranspose(hero)) { return d1; } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index 044a3537d90282..3dcf0bce20ae23 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -54,6 +55,10 @@ inline constexpr int64_t kMinDimensionToTransposeTiled = 16; // efficient. inline constexpr int64_t kMinDimensionToTransposeTiled2 = 8; inline constexpr int64_t kMinTotalDimensionsToTransposeTiled = 64 * 128; +// As the amount of shared memory is limited, we need to make sure that we don't +// detect 102 transposes that would require too much bytes for the most minor +// dimension. +inline constexpr int64_t kMaxBytesInMostMinorDimension = 8; // Matrix multiplication before the rewrite. bool IsMatrixMultiplication(const HloInstruction& dot); @@ -125,21 +130,25 @@ absl::StatusOr GetAllocationSlice( const BufferAssignment& buffer_assignment, const HloInstruction* instr, const ShapeIndex& index); -// Returns whether 'fusion' can be emitted with the dynamic update slice -// in-place emitter. +// Returns whether the fusion represented by 'fusion_adaptor' can be emitted +// with the dynamic update slice in-place emitter. If 'fusion_adaptor' +// represents a single fusion computation, 'fusion' should provide the fusion +// instruction corresponding to that fusion computation. 'get_allocation_slice' +// is a callback for getting the allocated buffer slice, given an instruction +// and a shape index. This is ignored in case 'fusion' is a nullptr. absl::StatusOr CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - const HloFusionInstruction* fusion, + const HloFusionAdaptor& fusion_adaptor, std::function( const HloInstruction* instr, const ShapeIndex& index)> get_allocation_slice, - absl::Span roots); + const HloInstruction* fusion = nullptr); // Returns the dynamic-update-slice instructions defining the results of a // fusion node. A dynamic slice update is said to be "defining" of a result if // that result is the output of a dynamic slice update, or if that result is the // output of a bitcast of a dynamic slice update---since such bitcast may be // handled as a no-op. -std::vector GetOutputDefiningDynamicUpdateSlices( +std::vector GetOutputDefiningDynamicUpdateSlices( absl::Span roots); // Returns the first hero instruction reachable from `instr` as root. Hero @@ -156,16 +165,18 @@ struct TransposeDescription { const HloInstruction* instr; // Normalized transpose dimensions. - Vector3 dimensions; + absl::InlinedVector dimensions; // Permutations of normalized transpose dimensions. - Vector3 permutation; + absl::InlinedVector permutation; - TransposeDescription(Vector3 dimensions, Vector3 permutation) + TransposeDescription(absl::InlinedVector dimensions, + absl::InlinedVector permutation) : TransposeDescription(/*instr=*/nullptr, dimensions, permutation) {} - TransposeDescription(const HloInstruction* instr, Vector3 dimensions, - Vector3 permutation) + TransposeDescription(const HloInstruction* instr, + absl::InlinedVector dimensions, + absl::InlinedVector permutation) : instr(instr), dimensions(dimensions), permutation(permutation) {} // Transpose instruction input shape. diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index 67ffe2c723fd4f..80407b0835d9eb 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -20,17 +20,19 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/strings/str_cat.h" #include "xla/hlo/ir/backend_config.h" -#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/types.h" -#include "xla/util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -41,26 +43,161 @@ namespace gpu { using ::tsl::testing::IsOkAndHolds; using IrEmissionUtilsTest = HloTestBase; +using InlinedVector = absl::InlinedVector; TEST_F(IrEmissionUtilsTest, FindTiledLogicalTranspose) { const char* hlo = R"( HloModule module ENTRY entry { - p = f32[32,48,64]{2,1,0} parameter(0) - ROOT t = f32[64,32,48]{2,1,0} transpose(p), dimensions={2,0,1} + p = f32[1536,64]{1,0} parameter(0) + ROOT t = f32[64,1536]{1,0} transpose(p), dimensions={1,0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, tr); + EXPECT_EQ(result->dimensions, InlinedVector({64, 1536})); + EXPECT_EQ(result->permutation, InlinedVector({1, 0})); +} + +TEST_F(IrEmissionUtilsTest, FindTiledLogical102Transpose) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f32[32,48,2]{2,1,0} parameter(0) + ROOT t = f32[48,32,2]{2,1,0} transpose(p), dimensions={1,0,2} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); HloInstruction* tr = module->entry_computation()->root_instruction(); auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, tr); - EXPECT_EQ(result->dimensions, Vector3({1, 64, 1536})); - EXPECT_EQ(result->permutation, Vector3({0, 2, 1})); + EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 2})); + EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2})); +} + +TEST_F(IrEmissionUtilsTest, FindTiledLogical102TransposeTooMuchMemoryRequired) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = s8[32,48,9]{2,1,0} parameter(0) + ROOT t = s8[48,32,9]{2,1,0} transpose(p), dimensions={1,0,2} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_FALSE(result.has_value()); +} + +TEST_F(IrEmissionUtilsTest, FindTiledLogical2103Transpose) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f32[33,48,32,2]{3,2,1,0} parameter(0) + ROOT t = f32[32,48,33,2]{3,2,1,0} transpose(p), dimensions={2,1,0,3} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, tr); + EXPECT_EQ(result->dimensions, InlinedVector({32, 48, 33, 2})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0, 3})); +} + +TEST_F(IrEmissionUtilsTest, FindTiledLogical1320Transpose) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = f32[33,48,32,34]{3,2,1,0} parameter(0) + ROOT t = f32[48,34,32,33]{3,2,1,0} transpose(p), dimensions={1,3,2,0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, tr); + EXPECT_EQ(result->dimensions, InlinedVector({48, 34, 32, 33})); + EXPECT_EQ(result->permutation, InlinedVector({1, 3, 2, 0})); +} + +TEST_F(IrEmissionUtilsTest, FindTiled102Transpose) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = s16[32,48,4]{2,1,0} parameter(0) + ROOT t = s16[32,48,4]{2,0,1} copy(p) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, tr); + EXPECT_EQ(result->dimensions, InlinedVector({48, 32, 4})); + EXPECT_EQ(result->permutation, InlinedVector({1, 0, 2})); +} + +TEST_F(IrEmissionUtilsTest, FindTiled102TransposeTooMuchMemoryRequired) { + const char* hlo = R"( +HloModule module + +ENTRY entry { + p = s8[32,48,9]{2,1,0} parameter(0) + ROOT t = s8[32,48,9]{2,0,1} copy(p) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo)); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + + HloInstruction* tr = module->entry_computation()->root_instruction(); + + auto result = GetDescriptionForTiledTransposeEmitter(*tr, *tr); + EXPECT_FALSE(result.has_value()); } TEST_F(IrEmissionUtilsTest, FindAnyTiledTranspose) { @@ -79,8 +216,8 @@ ENTRY entry { auto result = GetDescriptionForTiledTransposeEmitter(*r, *r); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, r); - EXPECT_EQ(result->dimensions, Vector3({64, 48, 32})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateUnaryOp) { @@ -100,8 +237,8 @@ ENTRY entry { auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, r->operand(0)); - EXPECT_EQ(result->dimensions, Vector3({64, 48, 32})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithIntermediateUnaryOpS8) { @@ -124,11 +261,11 @@ ENTRY main { HloInstruction* r = module->entry_computation()->root_instruction()->fused_expression_root(); - // TODO(b/284431534): Update this test when the shared memory transpose - // emitter is fast for S8 output. - EXPECT_FALSE( - GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)).has_value()); - EXPECT_EQ(FindNonTrivialHero(*r).name(), "t"); + auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result->instr, r->operand(0)); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindReduceHeroEpilogueFusion) { @@ -258,8 +395,8 @@ ENTRY entry { auto result = GetDescriptionForTiledTransposeEmitter(*r, *r->operand(0)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, r->operand(0)); - EXPECT_EQ(result->dimensions, Vector3({64, 48, 32})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindAnyTiledTransposeWithTwoIntermediateBinaryOps) { @@ -289,8 +426,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*r, FindNonTrivialHero(*r)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, r->operand(0)->operand(0)); - EXPECT_EQ(result->dimensions, Vector3({64, 48, 32})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({64, 48, 32})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, @@ -302,8 +439,10 @@ fusion { p = f32[32,48,64]{2,1,0} parameter(0) p2 = f32[48,32,64]{2,1,0} parameter(1) t = f32[64,48,32]{2,1,0} transpose(p), dimensions={2,1,0} - t2 = f32[64,48,32]{2,1,0} transpose(p2), dimensions={2,0,1} - ROOT add = f32[64,48,32]{2,1,0} add(t, t2) + bc = f32[1,1536,64]{2,1,0} bitcast(p2) + t2 = f32[1,64,1536]{2,1,0} transpose(bc), dimensions={0,2,1} + bc2 = f32[64,48,32]{2,1,0} bitcast(t2) + ROOT add = f32[64,48,32]{2,1,0} add(t, bc2) } ENTRY main { @@ -475,8 +614,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*copy, FindNonTrivialHero(*copy)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, copy); - EXPECT_EQ(result->dimensions, Vector3({8, 12, 1100})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOneSwapDimIsSmall) { @@ -484,13 +623,13 @@ TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOneSwapDimIsSmall) { HloModule module fusion { - p = f32[100,11,12,8]{3,2,1,0} parameter(0) - ROOT t = f32[8,12,100,11]{3,2,1,0} transpose(p), dimensions={3,2,0,1} + p = f32[1100,12,8]{2,1,0} parameter(0) + ROOT t = f32[8,12,1100]{2,1,0} transpose(p), dimensions={2,1,0} } ENTRY main { - param = f32[100,11,12,8]{3,2,1,0} parameter(0) - ROOT fusion = f32[8,12,100,11]{3,2,1,0} fusion(param), kind=kInput, calls=fusion + param = f32[1100,12,8]{2,1,0} parameter(0) + ROOT fusion = f32[8,12,1100]{2,1,0} fusion(param), kind=kInput, calls=fusion } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -502,8 +641,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*tr, FindNonTrivialHero(*tr)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, tr); - EXPECT_EQ(result->dimensions, Vector3({8, 12, 1100})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({8, 12, 1100})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindTiledTransposeOtherSwapDimIsSmall) { @@ -529,8 +668,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*copy, FindNonTrivialHero(*copy)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, copy); - EXPECT_EQ(result->dimensions, Vector3({1100, 12, 8})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOtherSwapDimIsSmall) { @@ -538,13 +677,13 @@ TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeOtherSwapDimIsSmall) { HloModule module fusion { - p = f32[8,12,100,11]{3,2,1,0} parameter(0) - ROOT t = f32[100,11,12,8]{3,2,1,0} transpose(p), dimensions={2,3,1,0} + p = f32[8,12,1100]{2,1,0} parameter(0) + ROOT t = f32[1100,12,8]{2,1,0} transpose(p), dimensions={2,1,0} } ENTRY main { - param = f32[8,12,100,11]{3,2,1,0} parameter(0) - ROOT fusion = f32[100,11,12,8]{3,2,1,0} fusion(param), kind=kInput, calls=fusion + param = f32[8,12,1100]{2,1,0} parameter(0) + ROOT fusion = f32[1100,12,8]{2,1,0} fusion(param), kind=kInput, calls=fusion } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -556,8 +695,8 @@ ENTRY main { GetDescriptionForTiledTransposeEmitter(*tr, FindNonTrivialHero(*tr)); EXPECT_TRUE(result.has_value()); EXPECT_EQ(result->instr, tr); - EXPECT_EQ(result->dimensions, Vector3({1100, 12, 8})); - EXPECT_EQ(result->permutation, Vector3({2, 1, 0})); + EXPECT_EQ(result->dimensions, InlinedVector({1100, 12, 8})); + EXPECT_EQ(result->permutation, InlinedVector({2, 1, 0})); } TEST_F(IrEmissionUtilsTest, IsContiguousSlice) { @@ -703,12 +842,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -742,12 +882,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(false)); } @@ -782,8 +923,9 @@ ENTRY main { BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); BufferAllocation::Slice slice1(&alloc, 10, 20); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [fusion, &slice0, &slice1](const HloInstruction* instr, const ShapeIndex&) { if (instr == fusion) { @@ -791,7 +933,7 @@ ENTRY main { } return slice1; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(false)); } @@ -825,12 +967,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(false)); } @@ -868,12 +1011,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -913,12 +1057,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -954,12 +1099,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -995,12 +1141,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } @@ -1038,12 +1185,13 @@ ENTRY main { auto fusion = module->entry_computation()->root_instruction(); BufferAllocation alloc(/*index=*/0, /*size=*/1024, /*color=*/0); BufferAllocation::Slice slice0(&alloc, 0, 10); + auto adaptor = HloFusionAdaptor::ForInstruction(fusion); EXPECT_THAT(CanEmitFusedDynamicUpdateSliceInPlaceForGpu( - Cast(fusion), + *adaptor, [&slice0](const HloInstruction*, const ShapeIndex&) { return slice0; }, - HloFusionAdaptor::ForInstruction(fusion)->GetRoots()), + fusion), IsOkAndHolds(true)); } diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index a964f6bbd9d72c..b73225a1bd3c56 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -26,7 +26,6 @@ limitations under the License. #include #include #include -#include #include #include "absl/container/flat_hash_map.h" @@ -99,7 +98,6 @@ limitations under the License. #include "xla/service/gpu/fusions/triton/triton_fusion_emitter.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_conv_runner.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/gpu_norm_runner.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" @@ -122,7 +120,6 @@ limitations under the License. #include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" #include "xla/service/gpu/runtime/fft_thunk.h" -#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/service/gpu/runtime/infeed_thunk.h" @@ -173,6 +170,7 @@ limitations under the License. #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #include "xla/service/gpu/runtime/cholesky_thunk.h" #include "xla/service/gpu/runtime/cub_sort_thunk.h" +#include "xla/service/gpu/runtime/cudnn_thunk.h" #include "xla/service/gpu/runtime/triangular_solve_thunk.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM @@ -955,221 +953,17 @@ absl::Status IrEmitterUnnested::EmitNormThunk( return absl::OkStatus(); } -absl::Status IrEmitterUnnested::EmitFusedMHAThunk( +absl::Status IrEmitterUnnested::EmitCuDnnThunk( const HloCustomCallInstruction* instr) { - const HloInstruction* lhs_bmm1 = instr->operand(0); - const HloInstruction* rhs_bmm1 = instr->operand(1); - const HloInstruction* rhs_bmm2 = instr->operand(2); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice lhs_bmm1_slice, - GetAllocationSliceForHlo(lhs_bmm1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm1_slice, - GetAllocationSliceForHlo(rhs_bmm1)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice rhs_bmm2_slice, - GetAllocationSliceForHlo(rhs_bmm2)); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, - GetAllocationSliceForHlo(instr, {0})); - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSliceForHlo( - instr, {instr->shape().tuple_shapes_size() - 1})); - BufferAllocation::Slice activation_slice; - bool has_activation = xla::ShapeUtil::TupleElementCount(instr->shape()) == 3; - if (has_activation) { - TF_ASSIGN_OR_RETURN(activation_slice, GetAllocationSliceForHlo(instr, {1})); - } - - TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, - xla::gpu::GetCudnnfMHAKind(instr)); - BufferAllocation::Slice mask_slice, bias_slice; - BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; - std::optional mask_shape, bias_shape; - { - bool has_bias = kind == CudnnfMHAKind::kScaleBiasSoftmax || - kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout; - - if (has_bias) { - const HloInstruction* bias = instr->operand(3); - TF_ASSIGN_OR_RETURN(bias_slice, GetAllocationSliceForHlo(bias)); - bias_shape = bias->shape(); - } - int64_t seqlen_qk_operand_index = 3 + has_bias; - bool has_seqlen_qk = seqlen_qk_operand_index == instr->operand_count() - 2; - if (has_seqlen_qk) { - const HloInstruction* seqlen_q = instr->operand(seqlen_qk_operand_index); - TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); - const HloInstruction* seqlen_k = - instr->operand(seqlen_qk_operand_index + 1); - TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); - } - } - - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - absl::InlinedVector output_shapes = { - ShapeUtil::GetSubshape(instr->shape(), {0})}; - if (has_activation) { - output_shapes.push_back(ShapeUtil::GetSubshape(instr->shape(), {1})); - } - TF_ASSIGN_OR_RETURN(const auto mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - GpufMHADescriptor descriptor = {kind, - config, - mask_type, - lhs_bmm1->shape(), - rhs_bmm1->shape(), - rhs_bmm2->shape(), - intermediate_tensor_shape, - output_shapes, - config.bmm1_dot_dimension_numbers(), - config.bmm2_dot_dimension_numbers(), - mask_shape, - bias_shape}; - - TF_ASSIGN_OR_RETURN(GpufMHAConfig fmha_config, - GpufMHAConfig::For(descriptor)); - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(fmha_config), - lhs_bmm1_slice, rhs_bmm1_slice, rhs_bmm2_slice, output_slice, - scratch_slice, mask_slice, bias_slice, activation_slice, seqlen_q_slice, - seqlen_k_slice)); - return absl::OkStatus(); -} - -absl::Status IrEmitterUnnested::EmitFusedMHABackwardThunk( - const HloCustomCallInstruction* instr) { - TF_ASSIGN_OR_RETURN(const auto gpu_config, - instr->backend_config()); - const xla::gpu::CudnnfMHABackendConfig& config = - gpu_config.cudnn_fmha_backend_config(); - - int input_index = 0; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm1_grad_gemm1_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm1_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm2_grad_gemm2_rhs_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape bmm2_grad_gemm1_lhs_shape; - - Shape intermediate_tensor_shape(config.intermediate_tensor_shape()); - bmm2_grad_gemm1_lhs_shape = intermediate_tensor_shape; - input_index++; - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_output_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - Shape d_output_shape = instr->operand(input_index++)->shape(); - - TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, GetCudnnfMHAKind(instr)); - BufferAllocation::Slice mask_slice; - std::optional mask_shape; - - bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || - kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); - BufferAllocation::Slice bias_slice; - std::optional bias_shape; - if (has_bias) { - TF_ASSIGN_OR_RETURN(bias_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - bias_shape = instr->operand(input_index++)->shape(); - } - - BufferAllocation::Slice fwd_output_slice; - std::optional fwd_output_shape; - - TF_ASSIGN_OR_RETURN(fwd_output_slice, - GetAllocationSliceForHlo(instr->operand(input_index))); - fwd_output_shape = instr->operand(input_index++)->shape(); - - BufferAllocation::Slice seqlen_q_slice, seqlen_k_slice; - bool has_seqlen_qk = input_index == instr->operand_count() - 2; - if (has_seqlen_qk) { - const HloInstruction* seqlen_q = instr->operand(input_index); - TF_ASSIGN_OR_RETURN(seqlen_q_slice, GetAllocationSliceForHlo(seqlen_q)); - const HloInstruction* seqlen_k = instr->operand(input_index + 1); - TF_ASSIGN_OR_RETURN(seqlen_k_slice, GetAllocationSliceForHlo(seqlen_k)); - input_index += 2; - } - TF_RET_CHECK(input_index == instr->operand_count()); - - int output_index = 0; - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_lhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm1_lhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm1_rhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm1_rhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice d_bmm2_rhs_slice, - GetAllocationSliceForHlo(instr, {output_index})); - Shape d_bmm2_rhs_shape = - ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - - BufferAllocation::Slice d_s_slice; - std::optional d_s_shape; - - bool has_dbias = instr->shape().tuple_shapes().size() == 5; - BufferAllocation::Slice d_bias_slice; - std::optional d_bias_shape; - if (has_dbias) { - TF_ASSIGN_OR_RETURN(d_bias_slice, - GetAllocationSliceForHlo(instr, {output_index})); - d_bias_shape = ShapeUtil::GetSubshape(instr->shape(), {output_index++}); - } - TF_ASSIGN_OR_RETURN(BufferAllocation::Slice scratch_slice, - GetAllocationSliceForHlo(instr, {output_index++})); - TF_RET_CHECK(output_index == instr->shape().tuple_shapes().size()); - TF_ASSIGN_OR_RETURN(const auto mask_type, - AsCudnnFmhaMaskKind(config.mask_type())); - bool force_deterministic = config.force_deterministic(); - GpufMHABackwardDescriptor descriptor = { - kind, - config, - mask_type, - bmm1_grad_gemm1_rhs_shape, - bmm1_grad_gemm2_rhs_shape, - bmm2_grad_gemm1_lhs_shape, - bmm2_grad_gemm2_rhs_shape, - d_output_shape, - d_bmm1_lhs_shape, - d_bmm1_rhs_shape, - d_bmm2_rhs_shape, - config.bmm1_grad_gemm1_dot_dimension_numbers(), - config.bmm1_grad_gemm2_dot_dimension_numbers(), - config.bmm2_grad_gemm1_dot_dimension_numbers(), - config.bmm2_grad_gemm2_dot_dimension_numbers(), - d_s_shape, - fwd_output_shape, - mask_shape, - d_bias_shape, - bias_shape, - force_deterministic}; - - TF_ASSIGN_OR_RETURN(GpufMHABackwardConfig fmha_backward_config, - GpufMHABackwardConfig::For(descriptor)); - - AddThunkToThunkSequence(std::make_unique( - Thunk::ThunkInfo::WithProfileAnnotation(instr), - std::move(fmha_backward_config), bmm1_grad_gemm1_rhs_slice, - bmm1_grad_gemm2_rhs_slice, bmm2_grad_gemm1_lhs_slice, - bmm2_grad_gemm2_rhs_slice, d_output_slice, scratch_slice, - d_bmm1_lhs_slice, d_bmm1_rhs_slice, d_bmm2_rhs_slice, d_s_slice, - mask_slice, d_bias_slice, fwd_output_slice, bias_slice, seqlen_q_slice, - seqlen_k_slice)); - + TF_ASSIGN_OR_RETURN( + auto kernel_arguments, + KernelArguments::Create(ir_emitter_context_->buffer_assignment(), instr, + instr->operands())); + TF_ASSIGN_OR_RETURN(const std::string fingerprint, + FingerprintWithBackendConfig(*instr)); + AddThunkToThunkSequence(std::make_unique( + fingerprint, Thunk::ThunkInfo::WithProfileAnnotation(instr), + kernel_arguments.args())); return absl::OkStatus(); } @@ -1698,7 +1492,7 @@ absl::Status IrEmitterUnnested::EmitFusion(const HloFusionInstruction* instr) { const se::DeviceDescription& device_info = ir_emitter_context_->gpu_device_info(); const HloFusionAnalysis fusion_analysis = - HloFusionAnalysis::Create(instr, &device_info); + HloFusionAnalysis::Create(*instr, device_info); std::unique_ptr emitter = GetFusionEmitter(HloFusionInfo( fusion_analysis, instr, &ir_emitter_context_->buffer_assignment())); @@ -2921,11 +2715,8 @@ absl::Status IrEmitterUnnested::EmitHloInstruction( if (IsCustomCallToDnnNorm(*instr)) { return EmitNormThunk(custom_call); } - if (IsFwdCustomCallTofMHA(*instr)) { - return EmitFusedMHAThunk(custom_call); - } - if (IsBwdCustomCallTofMHA(*instr)) { - return EmitFusedMHABackwardThunk(custom_call); + if (IsCustomCallTofMHA(*instr)) { + return EmitCuDnnThunk(custom_call); } #endif // GOOGLE_CUDA if (IsCustomCallToTopK(*instr)) { diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h index f97f106ddfc0df..d19dd5d9c4172c 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.h +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.h @@ -147,8 +147,7 @@ class IrEmitterUnnested : public IrEmitter { absl::Status EmitConvolutionReorderThunk( const HloCustomCallInstruction* instr); absl::Status EmitNormThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFusedMHAThunk(const HloCustomCallInstruction* instr); - absl::Status EmitFusedMHABackwardThunk(const HloCustomCallInstruction* instr); + absl::Status EmitCuDnnThunk(const HloCustomCallInstruction* instr); #endif // GOOGLE_CUDA #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM absl::Status EmitCubDeviceRadixSort(const HloCustomCallInstruction* instr); diff --git a/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc index 1d4ea628f5b832..3f32225a72759c 100644 --- a/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc +++ b/third_party/xla/xla/service/gpu/kernel_reuse_cache_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include #include "absl/log/check.h" #include "xla/service/gpu/executable.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/kernels/BUILD b/third_party/xla/xla/service/gpu/kernels/BUILD index 0479da5359dea8..737b47db113eba 100644 --- a/third_party/xla/xla/service/gpu/kernels/BUILD +++ b/third_party/xla/xla/service/gpu/kernels/BUILD @@ -8,7 +8,8 @@ load( load("//xla:xla.bzl", "xla_cc_binary") load("//xla/service/gpu:build_defs.bzl", "gpu_kernel_library") load("//xla/stream_executor:build_defs.bzl", "if_gpu_is_configured") -load("//xla/tests:build_defs.bzl", "xla_test") +load("//xla/tests:build_defs.bzl", "DEFAULT_DISABLED_BACKENDS", "xla_test") +load("//xla/tsl:tsl.bzl", "if_windows") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -72,14 +73,22 @@ cc_library( # a single dependency. cc_library( name = "custom_fusion_library", + tags = [ + "gpu", + "no_rocm", + ], visibility = [":friends"], - deps = [":cutlass_gemm_fusion"], + deps = if_cuda_is_configured([":cutlass_gemm_fusion"]), ) cc_library( name = "cutlass_gemm_fusion", srcs = ["cutlass_gemm_fusion.cc"], hdrs = ["cutlass_gemm_fusion.h"], + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":custom_kernel", ":custom_kernel_fusion", @@ -94,6 +103,7 @@ cc_library( "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", @@ -106,9 +116,10 @@ xla_test( srcs = ["cutlass_gemm_fusion_test.cc"], backends = ["gpu"], # TODO(b/332820384): Enable when it passes on H100. - disabled_backends = ["gpu_h100"], + disabled_backends = DEFAULT_DISABLED_BACKENDS + ["gpu_h100"], tags = ["no_rocm"], deps = [ + ":custom_kernel", ":custom_kernel_fusion_pattern", ":cutlass_gemm_custom_kernel", ":cutlass_gemm_fusion", @@ -118,9 +129,11 @@ xla_test( "//xla:error_spec", "//xla:literal_util", "//xla:types", - "//xla/service/gpu:custom_kernel_fusion_rewriter", + "//xla:xla_data_proto_cc", "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/transforms:custom_kernel_fusion_rewriter", "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], @@ -128,13 +141,12 @@ xla_test( cc_library( name = "topk_kernel", - srcs = if_gpu_is_configured(["topk_kernel.cc"]), - hdrs = if_gpu_is_configured(["topk_kernel.h"]), + srcs = ["topk_kernel.cc"], + hdrs = ["topk_kernel.h"], compatible_with = [], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), + tags = ["gpu"], deps = [ + ":topk_kernel_gpu", "//xla:shape_util", "//xla:types", "//xla:util", @@ -151,19 +163,17 @@ cc_library( "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:statusor", - ] + if_gpu_is_configured([ - ":topk_kernel_gpu", - ]), + ], ) gpu_kernel_library( name = "topk_kernel_gpu", - srcs = if_gpu_is_configured([ + srcs = [ + "topk_kernel.cu.h", "topk_kernel_bfloat16.cu.cc", "topk_kernel_float.cu.cc", - "topk_kernel.cu.h", - ]), - hdrs = if_gpu_is_configured(["topk_kernel_common.h"]), + ], + hdrs = ["topk_kernel_common.h"], compatible_with = [], deps = [ "//xla:types", @@ -174,7 +184,7 @@ gpu_kernel_library( xla_test( name = "topk_kernel_test", - srcs = if_gpu_is_configured(["topk_kernel_test.cc"]), + srcs = ["topk_kernel_test.cc"], backends = ["gpu"], deps = [ ":topk_kernel", @@ -223,7 +233,7 @@ cc_library( xla_test( name = "topk_custom_kernel_test", - srcs = if_gpu_is_configured(["topk_custom_kernel_test.cc"]), + srcs = ["topk_custom_kernel_test.cc"], backends = ["gpu"], deps = [ ":topk_custom_kernel", @@ -231,14 +241,13 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/service:platform_util", "//xla/stream_executor", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/random", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -252,11 +261,12 @@ xla_test( cc_library( name = "cutlass_gemm_custom_kernel", - srcs = if_cuda_is_configured( - ["cutlass_gemm_custom_kernel.cc"], - ["cutlass_gemm_custom_kernel_stub.cc"], - ), + srcs = ["cutlass_gemm_custom_kernel.cc"], hdrs = ["cutlass_gemm_custom_kernel.h"], + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":custom_kernel", ":cutlass_gemm", @@ -274,18 +284,18 @@ cc_library( xla_test( name = "cutlass_gemm_custom_kernel_test", - srcs = if_cuda_is_configured(["cutlass_gemm_custom_kernel_test.cc"]), + srcs = ["cutlass_gemm_custom_kernel_test.cc"], backends = ["gpu"], data = [":cutlass_gemm_kernel_f32xf32_to_f32.so"], + tags = ["no_rocm"], deps = [ ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", "//xla/stream_executor", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -296,13 +306,16 @@ xla_test( xla_cc_binary( name = "cutlass_gemm_custom_kernel_benchmarks", testonly = 1, - srcs = if_cuda_is_configured(["cutlass_gemm_custom_kernel_benchmarks.cc"]), + srcs = ["cutlass_gemm_custom_kernel_benchmarks.cc"], + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":cutlass_gemm_custom_kernel", "//xla:xla_data_proto_cc", "//xla/service:gpu_plugin", "//xla/stream_executor", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor/cuda:cuda_platform", @@ -327,19 +340,24 @@ cc_library( cuda_library( name = "cutlass_gemm_adaptor", - hdrs = if_cuda_is_configured(["cutlass_gemm_adaptor.cu.h"]), - copts = ["-Wno-unknown-attributes"], # __grid_constant__ is not supported by clang - deps = if_cuda_is_configured([ + hdrs = ["cutlass_gemm_adaptor.cu.h"], + copts = if_windows( + [], + ["-Wno-unknown-attributes"], + ), # __grid_constant__ is not supported by clang + tags = ["no_rocm"], + deps = [ ":cutlass_gemm", "@cutlass_archive//:cutlass", - ]), + ], ) cuda_library( name = "cutlass_gemm_epilogue", + tags = ["no_rocm"], # TODO(ezhulenev): Update to regular hdrs after fixing CUTLASS headers. - textual_hdrs = if_cuda_is_configured(["cutlass_gemm_epilogue.cu.h"]), - deps = if_cuda_is_configured(["@cutlass_archive//:cutlass"]), + textual_hdrs = ["cutlass_gemm_epilogue.cu.h"], + deps = ["@cutlass_archive//:cutlass"], ) #===--------------------------------------------------------------------------------------------===# @@ -351,9 +369,17 @@ cuda_library( cc_library( name = "cutlass_gemm_kernels", + tags = [ + "gpu", + "no_rocm", + ], deps = [ ":cutlass_gemm_kernel_bf16xbf16_to_bf16", ":cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", + ":cutlass_gemm_kernel_bf16xbf16_to_f32", + ":cutlass_gemm_kernel_bf16xf32_to_f32", + ":cutlass_gemm_kernel_bf16xs8_to_f32", + ":cutlass_gemm_kernel_f32xbf16_to_f32", ":cutlass_gemm_kernel_f32xf32_to_f32", ] + if_cuda_newer_than( "12_0", @@ -369,47 +395,144 @@ cc_library( cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"]), - copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], - deps = if_cuda_is_configured([ + srcs = ["cutlass_gemm_kernel_bf16xbf16_to_bf16.cu.cc"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + ["-Wno-unknown-attributes"], + ), + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"]), - copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], - deps = if_cuda_is_configured([ + srcs = ["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm80.cu.cc"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + ["-Wno-unknown-attributes"], + ), + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"]), - copts = ["-Wno-ctad-maybe-unsupported -Wno-unknown-attributes -mllvm -unroll-threshold=100000"], - deps = if_cuda_is_configured([ + srcs = ["cutlass_gemm_kernel_bf16xbf16_to_bf16_sm90.cu.cc"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + [ + "-Wno-ctad-maybe-unsupported", + "-Wno-unknown-attributes", + ], + ), + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", ":cutlass_gemm_epilogue", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], ) cuda_library( name = "cutlass_gemm_kernel_f32xf32_to_f32", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"]), - copts = ["-Wno-unknown-attributes"], - deps = if_cuda_is_configured([ + srcs = ["cutlass_gemm_kernel_f32xf32_to_f32.cu.cc"], + copts = if_windows( + [], + ["-Wno-unknown-attributes"], + ), + tags = ["no_rocm"], + deps = [ ":cutlass_gemm_adaptor", "@cutlass_archive//:cutlass", "@local_config_cuda//cuda:cuda_headers", - ]), + ], +) + +cuda_library( + name = "cutlass_gemm_kernel_bf16xbf16_to_f32", + srcs = ["cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + ["-Wno-unknown-attributes"], + ), + tags = ["no_rocm"], + deps = [ + ":cutlass_gemm_adaptor", + "@cutlass_archive//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ], +) + +cuda_library( + name = "cutlass_gemm_kernel_bf16xf32_to_f32", + srcs = ["cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + ["-Wno-unknown-attributes"], + ), + tags = ["no_rocm"], + deps = [ + ":cutlass_gemm_adaptor", + "@cutlass_archive//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ], +) + +cuda_library( + name = "cutlass_gemm_kernel_f32xbf16_to_f32", + srcs = ["cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc"], + copts = [ + "-mllvm", + "-unroll-threshold=100000", + ] + if_windows( + [], + ["-Wno-unknown-attributes"], + ), + tags = ["no_rocm"], + deps = [ + ":cutlass_gemm_adaptor", + "@cutlass_archive//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ], +) + +cuda_library( + name = "cutlass_gemm_kernel_bf16xs8_to_f32", + srcs = ["cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc"], + copts = ["-Wno-unknown-attributes -mllvm -unroll-threshold=100000"], + tags = [ + "gpu", + "no_rocm", + ], + deps = [ + ":cutlass_gemm_adaptor", + "@cutlass_archive//:cutlass", + "@local_config_cuda//cuda:cuda_headers", + ], ) #===--------------------------------------------------------------------------------------------===# @@ -418,8 +541,12 @@ cuda_library( cc_binary( name = "cutlass_gemm_kernel_f32xf32_to_f32.so", - srcs = if_cuda_is_configured(["cutlass_gemm_kernel_f32xf32_to_f32.cc"]), + srcs = ["cutlass_gemm_kernel_f32xf32_to_f32.cc"], linkshared = True, linkstatic = False, + tags = [ + "gpu", + "no_rocm", + ], deps = [":cutlass_gemm"], ) diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h index 37fb0ad8486aee..963b80c406dfa4 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm.h @@ -46,12 +46,27 @@ namespace xla::gpu::kernel::gemm_universal { enum class Arch { kDefault, kSm80, kSm90 }; +// Keep in sync with cutlass::gemm::GemmUniversalMode. +enum class GemmMode { kGemm, kGemmSplitKParallel, kBatched, kArray, kInvalid }; + template struct Bf16xBf16ToBf16 {}; template struct F32xF32ToF32 {}; +template +struct Bf16xBf16ToF32 {}; + +template +struct Bf16xF32ToF32 {}; + +template +struct F32xBf16ToF32 {}; + +template +struct Bf16xS8ToF32 {}; + // A tag to specialize CUTLASS kernel adaptors for loading kernels from shared // libraries using dlopen. struct DlOpenedKernel {}; @@ -132,6 +147,12 @@ struct DynamicSliceArguments { // Type-erased CUTLASS gemm arguments structure that has all of the details // required for packing CUTLASS kernel parameters. struct Arguments { + GemmMode mode; + + // Number of batches when mode is `kBatched`. + // Number of k-slices when mode is `kGemmSplitKParallel`. + int32_t batch_count; + int32_t m; int32_t n; int32_t k; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h index b8171d615dcfeb..1478dc8f312206 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h @@ -19,13 +19,17 @@ limitations under the License. #include #include #include +#include +#include "third_party/gpus/cuda/include/vector_types.h" #include "cute/layout.hpp" #include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/gemm_enumerated_types.h" #include "cutlass/gemm_coord.h" +#include "cutlass/kernel_hardware_info.h" #include "cutlass/layout/matrix.h" #include "cutlass/util/packed_stride.hpp" #include "xla/service/gpu/kernels/cutlass_gemm.h" @@ -137,6 +141,21 @@ static bool CanImplement(const Arguments &args) { cutlass::Status::kSuccess; } +inline cutlass::gemm::GemmUniversalMode ToGemmUniversalMode(GemmMode mode) { + switch (mode) { + case GemmMode::kGemm: + return cutlass::gemm::GemmUniversalMode::kGemm; + case GemmMode::kGemmSplitKParallel: + return cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel; + case GemmMode::kBatched: + return cutlass::gemm::GemmUniversalMode::kBatched; + case GemmMode::kArray: + return cutlass::gemm::GemmUniversalMode::kArray; + case GemmMode::kInvalid: + return cutlass::gemm::GemmUniversalMode::kInvalid; + } +} + // Converts type-erased gemm arguments to the underlying CUTLASS operation // arguments. template @@ -148,7 +167,7 @@ static typename Traits::Arguments OpArguments(const Arguments &args) { auto ldb = LdB::Operation>(problem_size); auto ldc = LdC::Operation>(problem_size); - auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + cutlass::gemm::GemmUniversalMode mode = ToGemmUniversalMode(args.mode); // TODO(ezhulenev): We hardcode parameters for `LinearCombination` // epilogue, however `Gemm` template can be compiled with arbitrary @@ -160,7 +179,7 @@ static typename Traits::Arguments OpArguments(const Arguments &args) { return typename Traits::Arguments( // CUTLASS Operation arguments mode, problem_size, // - 1, // batch + args.batch_count, // batch or k-split slices {alpha, beta}, // epilogue args.lhs, args.rhs, args.out, args.out, // pointers 0, 0, 0, 0, // batch strides @@ -199,8 +218,9 @@ namespace adaptor_3x { template static std::optional ClusterDim() { typename Traits::Kernel::DispatchPolicy::ClusterShape cluster; - return Dim3{cute::get<0>(cluster), cute::get<1>(cluster), - cute::get<2>(cluster)}; + return Dim3{static_cast(cute::get<0>(cluster)), + static_cast(cute::get<1>(cluster)), + static_cast(cute::get<2>(cluster))}; } template @@ -236,7 +256,9 @@ static typename Traits::Arguments OpArguments(const Arguments &args) { // TODO(ezhulenev): Pass device id and sm_count in arguments. cutlass::KernelHardwareInfo hw_info{/*device_id=*/0, /*sm_count=*/128}; - auto mode = cutlass::gemm::GemmUniversalMode::kGemm; + cutlass::gemm::GemmUniversalMode mode = + static_cast( + static_cast(args.mode)); typename Kernel::ProblemShape problem_shape = {args.m, args.n, args.k, /*batch=*/1}; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc index ae39cfbe293d1d..a97fe047345c35 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.cc @@ -90,8 +90,8 @@ static int32_t* SlicePtr(const se::KernelArgsDeviceMemoryArray* args, } template -KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k, - const ArgsIndices& indices, +KernelArgsPacking ArgsPacking(GemmMode mode, int32_t batch_count, int32_t m, + int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices, int32_t device_sms, Adaptor adaptor) { using Packed = absl::StatusOr>; @@ -101,13 +101,17 @@ KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k, // object constructed in the storage. For now we ignore it, and it's textbook // definition of UB, but for CUTLASS kernels we use today it's perfectly safe. struct Params { +#if defined(_MSC_VER) + alignas(64) std::byte storage[1024]; +#else alignas(128) std::byte storage[1024]; +#endif }; return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed { auto* mem_args = se::Cast(&args); - Arguments arguments = {m, n, k}; + Arguments arguments = {mode, batch_count, m, n, k}; arguments.lhs = const_cast(mem_args->device_memory_ptr(indices.lhs)); arguments.rhs = const_cast(mem_args->device_memory_ptr(indices.rhs)); arguments.out = const_cast(mem_args->device_memory_ptr(indices.out)); @@ -172,7 +176,8 @@ KernelArgsPacking ArgsPacking(int32_t m, int32_t n, int32_t k, //===----------------------------------------------------------------------===// template -static CustomKernel Load(std::string name, int32_t m, int32_t n, int32_t k, +static CustomKernel Load(std::string name, GemmMode mode, int32_t batch_count, + int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices, const se::DeviceDescription& device, @@ -184,8 +189,8 @@ static CustomKernel Load(std::string name, int32_t m, int32_t n, int32_t k, auto thread_dim = As(adaptor.ThreadDim()); auto shared_memory_bytes = adaptor.SharedMemoryBytes(); - auto packing = - ArgsPacking(m, n, k, indices, slices, device.core_count(), adaptor); + auto packing = ArgsPacking(mode, batch_count, m, n, k, indices, slices, + device.core_count(), adaptor); se::MultiKernelLoaderSpec spec(/*arity=*/2, std::move(packing)); spec.AddInProcessSymbol(kernel.symbol(), name); @@ -200,33 +205,83 @@ static CustomKernel Load(std::string name, int32_t m, int32_t n, int32_t k, } absl::StatusOr> GetCutlassGemmKernels( - std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, + std::string name, PrimitiveType dot_type, PrimitiveType lhs_type, + PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices, const se::DeviceDescription& device) { auto& cuda_cc = std::get(device.gpu_compute_capability()); - switch (dtype) { - case PrimitiveType::F32: - return {{Load>(std::move(name), m, n, k, indices, - slices, device)}}; - case PrimitiveType::BF16: + if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::F32 && + rhs_type == PrimitiveType::F32) { + return {{Load>(std::move(name), GemmMode::kGemm, + /*batch_count=*/1, m, n, k, indices, + slices, device)}}; + } + + if (dot_type == PrimitiveType::BF16 && lhs_type == PrimitiveType::BF16 && + rhs_type == PrimitiveType::BF16) { #if CUDA_VERSION >= 12000 if (cuda_cc.IsAtLeastHopper()) { - return {{Load>(std::move(name), m, n, k, indices, - slices, device)}}; + return {{Load>(std::move(name), GemmMode::kGemm, + /*batch_count=*/1, m, n, k, + indices, slices, device)}}; } #endif if (cuda_cc.IsAtLeastAmpere()) { - return {{Load>(std::move(name), m, n, k, indices, - slices, device)}}; + return {{Load>( + std::move(name), GemmMode::kGemm, /*batch_count=*/1, m, n, k, + indices, slices, device)}}; } - return {{Load>(std::move(name), m, n, k, indices, - slices, device)}}; + return {{Load>(std::move(name), GemmMode::kGemm, + /*batch_count=*/1, m, n, k, + indices, slices, device)}}; + } - default: - return absl::InvalidArgumentError("Unsupported CUTLASS gemm data type"); + if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 && + rhs_type == PrimitiveType::BF16) { + return {{Load>(std::move(name), GemmMode::kGemm, + /*batch_count=*/1, m, n, k, indices, + slices, device)}}; } + + if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 && + rhs_type == PrimitiveType::F32) { + return {{Load>(name, GemmMode::kGemm, + /*batch_count=*/1, m, n, k, indices, + slices, device), + Load>(name, GemmMode::kGemmSplitKParallel, + /*batch_count=*/16, m, n, k, indices, + slices, device)}}; + } + + if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::F32 && + rhs_type == PrimitiveType::BF16) { + return {{Load>(name, GemmMode::kGemm, + /*batch_count=*/1, m, n, k, indices, + slices, device), + Load>(name, GemmMode::kGemmSplitKParallel, + /*batch_count=*/16, m, n, k, indices, + slices, device)}}; + } + + if (dot_type == PrimitiveType::F32 && lhs_type == PrimitiveType::BF16 && + rhs_type == PrimitiveType::S8) { + return {{ + Load>(name, GemmMode::kGemm, + /*batch_count=*/1, m, n, k, indices, slices, + device), + Load>(name, GemmMode::kGemmSplitKParallel, + /*batch_count=*/16, m, n, k, indices, + slices, device), + }}; + } + + std::string kernel_name = PrimitiveType_Name(lhs_type) + "x" + + PrimitiveType_Name(rhs_type) + "To" + + PrimitiveType_Name(dot_type); + return absl::InvalidArgumentError(absl::StrCat( + "Unsupported CUTLASS gemm data type for kernel: ", kernel_name)); } absl::StatusOr LoadCutlassGemmKernel( @@ -246,8 +301,9 @@ absl::StatusOr LoadCutlassGemmKernel( "Failed to load CUTLASS kernel from a shared library: ", library_path)); } - return Load(std::move(name), m, n, k, indices, slices, device, - *adaptor, *kernel); + return Load(std::move(name), GemmMode::kGemm, + /*batch_count=*/1, m, n, k, indices, slices, + device, *adaptor, *kernel); } } // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h index 37531ef0038f31..04b09251bbae5b 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h @@ -30,7 +30,8 @@ namespace xla::gpu::kernel::gemm_universal { // Returns pre-compiled custom kernels for a given data type and problem size. absl::StatusOr> GetCutlassGemmKernels( - std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, + std::string name, PrimitiveType dot_type, PrimitiveType lhs_type, + PrimitiveType rhs_type, int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, const DynamicSliceIndices& slices, const se::DeviceDescription& device); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc index 8d44bb024294e3..124569ea5461bc 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_benchmarks.cc @@ -21,7 +21,6 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" @@ -55,13 +54,13 @@ static void BM_RowMajorGemm(benchmark::State& state) { TF_ASSERT_OK_AND_ASSIGN( auto custom_kernels, - GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::BF16, m, n, k, + GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::BF16, + PrimitiveType::BF16, PrimitiveType::BF16, m, n, k, /*indices=*/{0, 1, 2}, /*slices=*/{}, device)); const auto& custom_kernel = custom_kernels[0]; - TF_ASSERT_OK_AND_ASSIGN( - auto gemm, - se::KernelFactory::Create(executor, custom_kernel.kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto gemm, + executor->LoadKernel(custom_kernel.kernel_spec())); // Prepare arguments: a=1.1, b=1.2, c=0.0 se::DeviceMemory a = executor->AllocateArray(m * k, 0); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc deleted file mode 100644 index 8e231ee3b8e6e9..00000000000000 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_stub.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include - -#include "absl/status/statusor.h" -#include "xla/service/gpu/kernels/custom_kernel.h" -#include "xla/service/gpu/kernels/cutlass_gemm.h" -#include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" -#include "xla/stream_executor/device_description.h" -#include "xla/xla_data.pb.h" - -namespace xla::gpu::kernel::gemm_universal { - -absl::StatusOr> GetCutlassGemmKernels( - std::string name, PrimitiveType dtype, int32_t m, int32_t n, int32_t k, - const ArgsIndices& indices, const DynamicSliceIndices& slices, - const se::DeviceDescription& device) { - return absl::InternalError("XLA compiled without CUDA support"); -} - -absl::StatusOr LoadCutlassGemmKernel( - std::string name, const std::string& library_path, PrimitiveType dtype, - int32_t m, int32_t n, int32_t k, const ArgsIndices& indices, - const DynamicSliceIndices& slices, const se::DeviceDescription& device) { - return absl::InternalError("XLA compiled without CUDA support"); -} - -} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc index 458f31ae88a836..7cdc9507e3e7f0 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_custom_kernel_test.cc @@ -22,13 +22,12 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -45,14 +44,14 @@ TEST(CutlassGemmKernelTest, SimpleGemm) { // Load [4, 4] x [4, 4] gemm kernel written in CUDA C++ with CUTLASS. TF_ASSERT_OK_AND_ASSIGN( auto custom_kernels, - GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::F32, 4, 4, 4, + GetCutlassGemmKernels("cutlass_gemm", PrimitiveType::F32, + PrimitiveType::F32, PrimitiveType::F32, 4, 4, 4, /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription())); auto custom_kernel = custom_kernels[0]; - TF_ASSERT_OK_AND_ASSIGN( - auto gemm, - se::KernelFactory::Create(executor, custom_kernel.kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto gemm, + executor->LoadKernel(custom_kernel.kernel_spec())); int64_t length = 4 * 4; int64_t byte_length = sizeof(float) * length; @@ -101,9 +100,8 @@ TEST(CutlassGemmKernelTest, LoadFromSharedLibrary) { "cutlass_gemm", kernel_lib_path, PrimitiveType::F32, 4, 4, 4, /*indices=*/{0, 1, 2}, /*slices=*/{}, executor->GetDeviceDescription()); - TF_ASSERT_OK_AND_ASSIGN( - auto gemm, - se::KernelFactory::Create(executor, custom_kernel->kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto gemm, + executor->LoadKernel(custom_kernel->kernel_spec())); int64_t length = 4 * 4; int64_t byte_length = sizeof(float) * length; diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc index a392801e25c578..946e1f85fa74a2 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/kernels/cutlass_gemm_fusion.h" +#include #include #include #include @@ -23,6 +24,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" @@ -136,14 +138,22 @@ static absl::Status MatchSimpleGemm( return absl::InternalError("unsupported operands type"); } -// Returns matched GEMM with one of the operands upcasted to the accumulator -// data type with an HLO convert instruction. +// Returns matched GEMM with one or both the operands upcasted to the +// accumulator data type with an HLO convert instruction. static absl::StatusOr MatchGemmWithUpcast( HloDotInstruction* dot) { TF_RETURN_IF_ERROR(MatchRowMajorGemm(dot)); GemmWithUpcast match(dot); + // C <- convert(A) * convert(B) + if (Match(const_cast(dot->operand(0)), + m::Convert(&match.lhs_upcast, m::Op())) && + Match(const_cast(dot->operand(1)), + m::Convert(&match.rhs_upcast, m::Op()))) { + return match; + } + // C <- convert(A) * B if (Match(const_cast(dot->operand(0)), m::Convert(&match.lhs_upcast, m::Op()))) { @@ -254,16 +264,19 @@ CutlassGemmWithUpcastPattern::TryMatch(const se::DeviceDescription& device, if (!dot) return std::nullopt; auto matched = MatchGemmWithUpcast(dot); - if (!matched.ok()) return std::nullopt; - // Only one operand can be upcasted. - DCHECK(matched->lhs_upcast == nullptr || matched->rhs_upcast == nullptr); + if (!matched.ok()) return std::nullopt; CustomFusionConfig config; config.set_name("cutlass_gemm_with_upcast"); - return matched->lhs_upcast ? Match{config, {matched->lhs_upcast, instr}} - : Match{config, {matched->rhs_upcast, instr}}; + if (matched->lhs_upcast != nullptr && matched->rhs_upcast == nullptr) { + return Match{config, {matched->lhs_upcast, instr}}; + } else if (matched->rhs_upcast != nullptr && matched->lhs_upcast == nullptr) { + return Match{config, {matched->rhs_upcast, instr}}; + } else { + return Match{config, {matched->lhs_upcast, matched->rhs_upcast, instr}}; + } } //===----------------------------------------------------------------------===// @@ -283,7 +296,7 @@ class CutlassGemmFusion : public CustomKernelFusion { TF_RETURN_IF_ERROR(MatchSimpleGemm(dot, {PrimitiveType::F32})); - auto dtype = dot->shape().element_type(); + PrimitiveType dot_type = dot->shape().element_type(); auto* lhs = Cast(dot->operand(0)); auto* rhs = Cast(dot->operand(1)); @@ -293,15 +306,19 @@ class CutlassGemmFusion : public CustomKernelFusion { lhs->parameter_number(), rhs->parameter_number(), computation->num_parameters()}; - auto& lhs_shape = lhs->shape(); - auto& rhs_shape = rhs->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); size_t m = lhs_shape.dimensions(0); size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); - return kernel::gemm_universal::GetCutlassGemmKernels( - "cutlass_gemm", dtype, m, n, k, indices, /*slices=*/{}, device); + PrimitiveType lhs_type = lhs->shape().element_type(); + PrimitiveType rhs_type = rhs->shape().element_type(); + + return GetCutlassGemmKernels("cutlass_gemm", dot_type, lhs_type, rhs_type, + m, n, k, indices, + /*slices=*/{}, device); } }; @@ -313,23 +330,44 @@ class CutlassGemmWithUpcastFusion : public CustomKernelFusion { auto* dot = DynCast(computation->root_instruction()); if (dot == nullptr) { return absl::InternalError( - "cutlass_gemm requires ROOT operation to be a dot"); + "cutlass_gemm_with_upcast requires ROOT operation to be a dot"); + } + + TF_ASSIGN_OR_RETURN(GemmWithUpcast matched, MatchGemmWithUpcast(dot)); + + const HloParameterInstruction* lhs; + const HloParameterInstruction* rhs; + + if (matched.lhs_upcast == nullptr && matched.rhs_upcast != nullptr) { + lhs = Cast(matched.dot->operand(0)); + rhs = Cast(matched.rhs_upcast->operand(0)); + } else if (matched.lhs_upcast != nullptr && matched.rhs_upcast == nullptr) { + lhs = Cast(matched.lhs_upcast->operand(0)); + rhs = Cast(matched.dot->operand(1)); + } else { + lhs = Cast(matched.lhs_upcast->operand(0)); + rhs = Cast(matched.rhs_upcast->operand(0)); } - TF_ASSIGN_OR_RETURN(auto matched, MatchGemmWithUpcast(dot)); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); - // We only support upcasting of rhs operand. - if (matched.lhs_upcast != nullptr) - return absl::InternalError("only rhs upcasting is implemented"); + size_t m = lhs_shape.dimensions(0); + size_t k = lhs_shape.dimensions(1); + size_t n = rhs_shape.dimensions(1); - auto dot_dtype = dot->shape().element_type(); - auto upcast_dtype = matched.rhs_upcast->shape().element_type(); + PrimitiveType dot_type = dot->shape().element_type(); + PrimitiveType lhs_type = lhs_shape.element_type(); + PrimitiveType rhs_type = rhs_shape.element_type(); - // We only support BF16 <- BF16 x S8 upcasted gemm. - if (dot_dtype != PrimitiveType::BF16 || upcast_dtype != PrimitiveType::S8) - return absl::InternalError("unsupported upcasting pattern"); + // Mapping from fusion arguments to gemm kernel arguments. + kernel::gemm_universal::ArgsIndices args_indices = { + lhs->parameter_number(), rhs->parameter_number(), + computation->num_parameters()}; - return absl::UnimplementedError("requires CUTLASS 3.3.0"); + return GetCutlassGemmKernels("cutlass_gemm_with_upcast", dot_type, lhs_type, + rhs_type, m, n, k, args_indices, /*slices=*/{}, + device); } }; @@ -353,7 +391,7 @@ class CutlassGemmWithDynamicUpdateSliceFusion : public CustomKernelFusion { MatchSimpleGemm(Cast(matched.dot), {PrimitiveType::F32, PrimitiveType::BF16})); - auto dtype = matched.dot->shape().element_type(); + auto dot_type = matched.dot->shape().element_type(); auto* lhs = Cast(matched.dot->operand(0)); auto* rhs = Cast(matched.dot->operand(1)); @@ -370,21 +408,25 @@ class CutlassGemmWithDynamicUpdateSliceFusion : public CustomKernelFusion { kernel::gemm_universal::DynamicSliceIndices slices; slices.out = offset->parameter_number(); - auto& lhs_shape = lhs->shape(); - auto& rhs_shape = rhs->shape(); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); size_t m = lhs_shape.dimensions(0); size_t k = lhs_shape.dimensions(1); size_t n = rhs_shape.dimensions(1); - return kernel::gemm_universal::GetCutlassGemmKernels( - "cutlass_gemm_with_dynamic_update_slice", dtype, m, n, k, args_indices, - slices, device); + PrimitiveType lhs_type = lhs->shape().element_type(); + PrimitiveType rhs_type = rhs->shape().element_type(); + + return GetCutlassGemmKernels("cutlass_gemm_with_dynamic_update_slice", + dot_type, lhs_type, rhs_type, m, n, k, + args_indices, slices, device); } }; } // namespace xla::gpu +XLA_REGISTER_CUSTOM_FUSION_PATTERN(::xla::gpu::CutlassGemmWithUpcastPattern); XLA_REGISTER_CUSTOM_FUSION_PATTERN( ::xla::gpu::CutlassGemmWithDynamicUpdateSlicePattern); diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc index a6488e9602045b..768feafb80b790 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_fusion_test.cc @@ -17,18 +17,21 @@ limitations under the License. #include #include +#include +#include #include "xla/array.h" #include "xla/array2d.h" #include "xla/array3d.h" #include "xla/error_spec.h" #include "xla/literal_util.h" -#include "xla/service/gpu/custom_kernel_fusion_rewriter.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/kernels/custom_kernel_fusion_pattern.h" #include "xla/service/gpu/kernels/cutlass_gemm_custom_kernel.h" +#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include "xla/tests/hlo_test_base.h" #include "xla/types.h" +#include "xla/xla_data.pb.h" #include "tsl/platform/test.h" namespace xla::gpu { @@ -41,10 +44,12 @@ class CutlassFusionTest : public HloTestBase { ->GetDeviceDescription() .shared_memory_per_block_optin(); } - int CutlassGemmKernelSharedMemorySize(PrimitiveType dtype, int m, int n, + int CutlassGemmKernelSharedMemorySize(PrimitiveType dot_type, + PrimitiveType lhs_type, + PrimitiveType rhs_type, int m, int n, int k) { return kernel::gemm_universal::GetCutlassGemmKernels( - "cutlass_gemm", dtype, m, n, k, + "cutlass_gemm", dot_type, lhs_type, rhs_type, m, n, k, /*indices=*/{0, 1, 2}, /*slices=*/{}, backend().default_stream_executor()->GetDeviceDescription()) ->at(0) @@ -134,6 +139,48 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcast) { RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); } +TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastOfBothOperands) { + const char* hlo = R"( + HloModule test + + ENTRY %main (p0: bf16[15,19], p1: bf16[19,17]) -> f32[15,17] { + %p0 = bf16[15,19]{1,0} parameter(0) + %c1 = f32[15,19]{1,0} convert(%p0) + %p1 = bf16[19,17]{1,0} parameter(1) + %c2 = f32[19,17]{1,0} convert(%p1) + ROOT %r = f32[15,17]{1,0} dot(%c1, %c2), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + + const char* expected = R"( + ; CHECK: %cutlass_gemm_with_upcast {{.*}} { + ; CHECK-DAG: [[P0:%[^ ]+]] = bf16[15,19]{1,0} parameter + ; CHECK: [[C1:%[^ ]+]] = f32[15,19]{1,0} convert([[P0]]) + ; CHECK-DAG: [[P1:%[^ ]+]] = bf16[19,17]{1,0} parameter + ; CHECK: [[C2:%[^ ]+]] = f32[19,17]{1,0} convert([[P1]]) + ; CHECK: ROOT [[DOT:%[^ ]+]] = f32[15,17]{1,0} dot([[C1]], [[C2]]), + ; CHECK: lhs_contracting_dims={1}, rhs_contracting_dims={0} + ; CHECK: } + + ; CHECK: ENTRY %main {{.*}} { + ; CHECK: ROOT [[FUSION:%[^ ]+]] = f32[15,17]{1,0} fusion + ; CHECK: kind=kCustom, calls=%cutlass_gemm_with_upcast, + ; CHECK: backend_config={ + ; CHECK: "kind":"__custom_fusion", + ; CHECK: "custom_fusion_config":{"name":"cutlass_gemm_with_upcast","kernel_index":0} + ; CHECK: } + ; CHECK: } + )"; + + CustomKernelFusionPatternRegistry patterns; + patterns.Emplace(); + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + CustomKernelFusionRewriter pass(&device, &patterns); + RunAndFilecheckHloRewrite(hlo, std::move(pass), expected); +} + TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSlice) { const char* hlo = R"( HloModule test @@ -329,9 +376,86 @@ TEST_F(CutlassFusionTest, RowMajorGemmKernel) { error_spec, /*run_hlo_passes=*/false)); } -TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { - GTEST_SKIP() << "Requires CUTLASS 3.3.0+"; +TEST_F(CutlassFusionTest, GemmWithLeftHandSideUpcastKernel) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_text_cublas = R"( + HloModule cublas + + ENTRY e { + p0 = bf16[16,32]{1,0} parameter(0) + c0 = f32[16,32]{1,0} convert(p0) + p1 = f32[32,8]{1,0} parameter(1) + gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(c0, p1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0 + })"; + + const char* hlo_text_custom_fusion = R"( + HloModule cutlass + + cutlass_gemm_with_upcast { + p0 = bf16[16,32]{1,0} parameter(0) + c0 = f32[16,32]{1,0} convert(p0) + p1 = f32[32,8]{1,0} parameter(1) + ROOT dot = f32[16,8]{1,0} dot(c0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY e { + p0 = bf16[16,32]{1,0} parameter(0) + p1 = f32[32,8]{1,0} parameter(1) + ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast, + backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast", "kernel_index":0}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, + error_spec, /*run_hlo_passes=*/false)); +} + +TEST_F(CutlassFusionTest, GemmWithRightHandSideUpcastKernel) { + ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; + + const char* hlo_text_cublas = R"( + HloModule cublas + + ENTRY e { + p0 = f32[16,32]{1,0} parameter(0) + p1 = bf16[32,8]{1,0} parameter(1) + c1 = f32[32,8]{1,0} convert(p1) + gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} + ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0 + })"; + + const char* hlo_text_custom_fusion = R"( + HloModule cutlass + + cutlass_gemm_with_upcast { + p0 = f32[16,32]{1,0} parameter(0) + p1 = bf16[32,8]{1,0} parameter(1) + c1 = f32[32,8]{1,0} convert(p1) + ROOT dot = f32[16,8]{1,0} dot(p0, c1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + + ENTRY e { + p0 = f32[16,32]{1,0} parameter(0) + p1 = bf16[32,8]{1,0} parameter(1) + ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom, + calls=cutlass_gemm_with_upcast, + backend_config={"fusion_backend_config":{kind: "__custom_fusion", + custom_fusion_config: {"name":"cutlass_gemm_with_upcast", + "kernel_index":0}}} + })"; + + EXPECT_TRUE(RunAndCompareTwoModules(hlo_text_cublas, hlo_text_custom_fusion, + error_spec, /*run_hlo_passes=*/false)); +} +TEST_F(CutlassFusionTest, GemmWithLeftHandAndRightHandSideUpcastKernel) { ErrorSpec error_spec{/*aabs=*/1e-3, /*arel=*/1e-3}; const char* hlo_text_cublas = R"( @@ -339,12 +463,13 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { ENTRY e { p0 = bf16[16,32]{1,0} parameter(0) + c0 = f32[16,32]{1,0} convert(p0) p1 = s8[32,8]{1,0} parameter(1) - c1 = bf16[32,8]{1,0} convert(p1) - gemm = (bf16[16,8]{1,0}, s8[0]{0}) custom-call(p0, c1), + c1 = f32[32,8]{1,0} convert(p1) + gemm = (f32[16,8]{1,0}, s8[0]{0}) custom-call(c0, c1), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":[1],"rhs_contracting_dimensions":[0],"lhs_batch_dimensions":[],"rhs_batch_dimensions":[]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"]},"epilogue":"DEFAULT"}} - ROOT get-tuple-element = bf16[16,8]{1,0} get-tuple-element(gemm), index=0 + ROOT get-tuple-element = f32[16,8]{1,0} get-tuple-element(gemm), index=0 })"; const char* hlo_text_custom_fusion = R"( @@ -352,16 +477,17 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { cutlass_gemm_with_upcast { p0 = bf16[16,32]{1,0} parameter(0) + c0 = f32[16,32]{1,0} convert(p0) p1 = s8[32,8]{1,0} parameter(1) - c1 = bf16[32,8]{1,0} convert(p1) - ROOT dot = bf16[16,8]{1,0} dot(p0, c1), + c1 = f32[32,8]{1,0} convert(p1) + ROOT dot = f32[16,8]{1,0} dot(c0, c1), lhs_contracting_dims={1}, rhs_contracting_dims={0} } ENTRY e { p0 = bf16[16,32]{1,0} parameter(0) p1 = s8[32,8]{1,0} parameter(1) - ROOT _ = bf16[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast, + ROOT _ = f32[16,8]{1,0} fusion(p0, p1), kind=kCustom, calls=cutlass_gemm_with_upcast, backend_config={"fusion_backend_config":{kind: "__custom_fusion", custom_fusion_config: {"name":"cutlass_gemm_with_upcast", "kernel_index":0}}} })"; @@ -371,7 +497,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithUpcastKernel) { TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) { if (GpuSharedMemorySize() < - CutlassGemmKernelSharedMemorySize(BF16, 8, 8, 8)) { + CutlassGemmKernelSharedMemorySize(BF16, BF16, BF16, 8, 8, 8)) { GTEST_SKIP_("The GPU does not have sufficient shared memory"); } @@ -445,7 +571,7 @@ TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernel) { TEST_F(CutlassFusionTest, RowMajorGemmWithDynamicUpdateSliceKernelWithoutBitcast) { if (GpuSharedMemorySize() < - CutlassGemmKernelSharedMemorySize(BF16, 8, 8, 8)) { + CutlassGemmKernelSharedMemorySize(BF16, BF16, BF16, 8, 8, 8)) { GTEST_SKIP_("The GPU does not have sufficient shared memory"); } diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc new file mode 100644 index 00000000000000..ec08008ce73317 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xbf16_to_f32.cu.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +namespace { + +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::bfloat16_t; +using ElementOutput = float; +using ElementAccumulator = float; + +} // namespace + +using GemmOperation = cutlass::gemm::device::GemmUniversal< + ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, // stages + 1, // A alignment + 1, // B alignment + cutlass::arch::OpMultiplyAdd>; + +XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xBf16ToF32, + GemmOperation); +template struct Adaptor>; +template struct DeviceKernel>; + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc new file mode 100644 index 00000000000000..e117b1a410e4e9 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xf32_to_f32.cu.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +namespace { + +using ElementA = cutlass::bfloat16_t; +using ElementB = float; +using ElementOutput = float; +using ElementAccumulator = float; + +} // namespace + +using GemmOperation = cutlass::gemm::device::GemmUniversal< + ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, // stages + 1, // A alignment + 1, // B alignment + cutlass::arch::OpMultiplyAdd>; + +XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xF32ToF32, + GemmOperation); +template struct Adaptor>; +template struct DeviceKernel>; + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc new file mode 100644 index 00000000000000..527d3692548355 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_bf16xs8_to_f32.cu.cc @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +namespace { + +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::int8_t; +using ElementOutput = float; +using ElementAccumulator = float; + +} // namespace + +using GemmOperation = cutlass::gemm::device::GemmUniversal< + ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + cutlass::gemm::GemmShape<64, 128, 8>, cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, // stages + 1, // A alignment + 1, // B alignment + cutlass::arch::OpMultiplyAdd>; + +XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(Bf16xS8ToF32, GemmOperation); +template struct Adaptor>; +template struct DeviceKernel>; + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc new file mode 100644 index 00000000000000..6ec6963a24f1b3 --- /dev/null +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xbf16_to_f32.cu.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "xla/service/gpu/kernels/cutlass_gemm_adaptor.cu.h" + +namespace xla::gpu::kernel::gemm_universal { + +namespace { + +using ElementA = float; +using ElementB = cutlass::bfloat16_t; +using ElementOutput = float; +using ElementAccumulator = float; + +} // namespace + +using GemmOperation = cutlass::gemm::device::GemmUniversal< + ElementA, cutlass::layout::RowMajor, ElementB, cutlass::layout::RowMajor, + ElementOutput, cutlass::layout::RowMajor, ElementAccumulator, + cutlass::arch::OpClassSimt, cutlass::arch::Sm70, + cutlass::gemm::GemmShape<128, 32, 8>, cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 2, // stages + 1, // A alignment + 1, // B alignment + cutlass::arch::OpMultiplyAdd>; + +XLA_GPU_DEFINE_CUTLASS_GEMM_TRAITS(F32xBf16ToF32, + GemmOperation); +template struct Adaptor>; +template struct DeviceKernel>; + +} // namespace xla::gpu::kernel::gemm_universal diff --git a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc index 5aff534c351aea..119d724f4b486d 100644 --- a/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc +++ b/third_party/xla/xla/service/gpu/kernels/cutlass_gemm_kernel_f32xf32_to_f32.cc @@ -51,14 +51,14 @@ extern "C" int32_t xla_cutlass_kernel_shared_memory_bytes() { extern "C" bool xla_cutlass_kernel_can_implement(int32_t m, int32_t n, int32_t k) { Adaptor adaptor; - Arguments arguments = {m, n, k}; + Arguments arguments = {GemmMode::kGemm, /*batch_count=*/1, m, n, k}; return adaptor.CanImplement(arguments); } extern "C" int64_t xla_cutlass_kernel_workspace_size(int32_t m, int32_t n, int32_t k) { Adaptor adaptor; - Arguments arguments = {m, n, k}; + Arguments arguments = {GemmMode::kGemm, /*batch_count=*/1, m, n, k}; return adaptor.WorkspaceSize(arguments); } @@ -67,7 +67,9 @@ extern "C" void xla_cutlass_kernel_initialize( void* out, void* workspace, int32_t* out_offset, int32_t device_sms, int32_t sm_occupancy) { Adaptor adaptor; - Arguments arguments = {m, n, k, lhs, rhs, out, workspace, {out_offset}}; + Arguments arguments = { + GemmMode::kGemm, /*batch_count=*/1, m, n, k, lhs, rhs, out, + workspace, {out_offset}}; adaptor.Initialize(params, arguments, device_sms, sm_occupancy); } diff --git a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc index 0a8a4d9342b81d..4f6f62605996a6 100644 --- a/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc +++ b/third_party/xla/xla/service/gpu/kernels/topk_custom_kernel_test.cc @@ -28,14 +28,13 @@ limitations under the License. #include "absl/strings/substitute.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -111,9 +110,8 @@ TEST_P(TopKKernelTest, TopKFloat) { auto custom_kernel = GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); - TF_ASSERT_OK_AND_ASSIGN( - auto kernel, - se::KernelFactory::Create(executor, custom_kernel->kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto kernel, + executor->LoadKernel(custom_kernel->kernel_spec())); // Launch topk kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( @@ -166,9 +164,8 @@ TEST_P(TopKKernelTest, TopKPackedNegative) { auto custom_kernel = GetTopKKernel("topk", PrimitiveType::F32, n, k, batch_size); - TF_ASSERT_OK_AND_ASSIGN( - auto kernel, - se::KernelFactory::Create(executor, custom_kernel->kernel_spec())); + TF_ASSERT_OK_AND_ASSIGN(auto kernel, + executor->LoadKernel(custom_kernel->kernel_spec())); // Launch topk kernel with device memory arguments. se::KernelArgsDeviceMemoryArray arr( diff --git a/third_party/xla/xla/service/gpu/launch_dimensions.cc b/third_party/xla/xla/service/gpu/launch_dimensions.cc index 89b322f6708556..f9e28995d09960 100644 --- a/third_party/xla/xla/service/gpu/launch_dimensions.cc +++ b/third_party/xla/xla/service/gpu/launch_dimensions.cc @@ -16,13 +16,8 @@ limitations under the License. #include "xla/service/gpu/launch_dimensions.h" #include -#include #include -#include -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/strings/str_format.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" @@ -32,139 +27,6 @@ limitations under the License. namespace xla { namespace gpu { -static int64_t ThreadsPerBlockLimit( - const se::DeviceDescription& gpu_device_info) { - int64_t threads_per_block = gpu_device_info.threads_per_block_limit(); - if (threads_per_block <= 0) { - static std::atomic log_count{0}; - if (log_count.fetch_add(1) < 8) { - LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " - "without full information about its capabilities. " - "StreamExecutor's PopulateDeviceDescription should be " - "updated for this device."; - } - threads_per_block = gpu_device_info.threads_per_warp(); - if (threads_per_block == 0) { - // Fall back to *something* if we can't even get num threads per warp. - threads_per_block = 32; - } - } - return threads_per_block; -} - -int64_t ThreadsPerBlockRowVectorized( - const Shape& shape, const se::DeviceDescription& gpu_device_info, - LaunchDimensionsConfig dim_config) { - if (shape.dimensions().empty()) { - return -1; - } - int64_t threads_per_block_row_vectorized = - shape.dimensions().back() / dim_config.unroll_factor; - if (dim_config.row_vectorized && - shape.dimensions().back() % dim_config.unroll_factor == 0 && - // If the row size is a multiple of 256, then use the old code - // path that use a block size of 256. This give small speed up on V100. - // Vectorization of the row load was already happening. - (shape.dimensions().back() % 256) != 0 && - // We do not support row that do not fit in one block. - threads_per_block_row_vectorized <= - gpu_device_info.threads_per_block_limit()) { - return threads_per_block_row_vectorized; - } - return -1; -} - -namespace { - -struct BlockSizes { - int64_t threads_per_block_x; - int64_t threads_per_block_y; - int64_t block_count; -}; - -BlockSizes GetBlockSizes(LaunchDimensionsConfig dim_config, - const se::DeviceDescription& gpu_device_info, - const Shape& shape, int64_t num_elements) { - if (!dim_config.row_vectorized && !dim_config.few_waves) { - BlockSizes result; - const int kWarpSchedulers = 4; - result.threads_per_block_x = std::min( - gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); - result.threads_per_block_y = 1; - result.block_count = CeilOfRatio( - num_elements, result.threads_per_block_x * result.threads_per_block_y); - return result; - } - - int64_t threads_per_block_row_vectorized = - ThreadsPerBlockRowVectorized(shape, gpu_device_info, dim_config); - // If row vectorized, threads_per_block_x is the vectorized size. - // Otherwise, we unroll kernels to make use of vectorized - // loads/stores. This means we need more registers to hold - // intermediate values. Reduce the number of threads per block to - // increase the number of registers available to ptxas. Make sure - // we still have a multiple of 32. - BlockSizes result; - int64_t max_threads_per_block_x = - threads_per_block_row_vectorized > 0 - ? threads_per_block_row_vectorized - : RoundUpTo(ThreadsPerBlockLimit(gpu_device_info) / - dim_config.unroll_factor, - int64_t{32}); - result.threads_per_block_x = std::min(num_elements, max_threads_per_block_x); - // threads_per_block_y > 1 when we row vectorize and have small row size. - result.threads_per_block_y = - threads_per_block_row_vectorized > 0 && - threads_per_block_row_vectorized < 128 && num_elements > 128 - ? CeilOfRatio(static_cast(128), - threads_per_block_row_vectorized) - : 1; - VLOG(2) << "Set # of threads per block to (.x=" << result.threads_per_block_x - << ", .y=" << result.threads_per_block_y << ")"; - - result.block_count = CeilOfRatio( - num_elements, result.threads_per_block_x * result.threads_per_block_y); - if (dim_config.few_waves) { - if (dim_config.row_vectorized) { - // This multiple of 32 was tuned to not cause regression on multiple - // benchmarks. It isn't a value that is optimal for all kernels. Maybe - // looking at the arithmetic intensity of the kernels can specialize the - // multiple per kernel. - int64_t max_block_count = - 32 * gpu_device_info.core_count() * - (gpu_device_info.threads_per_core_limit() / - (result.threads_per_block_x * result.threads_per_block_y)); - int64_t capped_block_count = result.block_count; - while (capped_block_count > max_block_count) { - capped_block_count /= 2; - } - if (capped_block_count < result.block_count) { - result.block_count = capped_block_count; - VLOG(2) << "Update # of blocks to " << result.block_count - << " as few_waves is enabled."; - } - } else { - int64_t capped_threads_per_block_x = - std::min(result.threads_per_block_x, 128); - int64_t capped_block_count = - gpu_device_info.core_count() * - (gpu_device_info.threads_per_core_limit() / - (capped_threads_per_block_x * result.threads_per_block_y)); - if (capped_block_count < result.block_count) { - result.threads_per_block_x = capped_threads_per_block_x; - result.block_count = capped_block_count; - VLOG(2) << "Update the # of blocks to " << result.block_count - << " and the # of threads per blocks to " - << result.threads_per_block_x - << " as the few_waves mode is enabled."; - } - } - } - return result; -} - -} // namespace - LaunchDimensions CalculateLaunchDimensions( const Shape& shape, const se::DeviceDescription& gpu_device_info, LaunchDimensionsConfig dim_config) { @@ -173,12 +35,13 @@ LaunchDimensions CalculateLaunchDimensions( return LaunchDimensions(); } num_elements = CeilOfRatio(num_elements, int64_t{dim_config.unroll_factor}); - BlockSizes sizes = - GetBlockSizes(dim_config, gpu_device_info, shape, num_elements); - return LaunchDimensions( - se::BlockDim(sizes.block_count, 1, 1), - se::ThreadDim(sizes.threads_per_block_x, sizes.threads_per_block_y, 1)); + const int kWarpSchedulers = 4; + int64_t threads_per_block = std::min( + gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); + int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block); + return LaunchDimensions(se::BlockDim(num_blocks, 1, 1), + se::ThreadDim(threads_per_block, 1, 1)); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/launch_dimensions.h b/third_party/xla/xla/service/gpu/launch_dimensions.h index e0c53f9b266f4c..7295048fcdd45c 100644 --- a/third_party/xla/xla/service/gpu/launch_dimensions.h +++ b/third_party/xla/xla/service/gpu/launch_dimensions.h @@ -85,17 +85,6 @@ struct LaunchDimensionsConfig { // The kernel implementation will be unrolled if `unroll_factor` is // greater than one. int unroll_factor = 1; - // A wave is a group of blocks that execute at the same time on the - // GPU. If there are more blocks then the number that can run - // concurrently, there are multiple waves of blocks running - // sequentially. If `few_waves` is true, each thread will loop over - // a block of unroll_factor elements. Otherwise each thread will - // handle only unroll_factor. - bool few_waves = false; - // If `row_vectorized` is true, then the block size will equal to - // `hlo.shape().dimensions().back()/unroll_factor`. - // Currently few_waves and row_vectorized do not work together. - bool row_vectorized = false; }; // Returns -1 if the shape doesn't allow the row vectorization code path. diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD index 8951c2719cb290..8fc3db56945a8e 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/BUILD @@ -2,6 +2,10 @@ load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured", ) +load( + "@local_config_sycl//sycl:build_defs.bzl", + "if_sycl_is_configured", +) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", @@ -88,6 +92,8 @@ cc_library( "@local_config_rocm//rocm:rocm_headers", "@llvm-project//llvm:AMDGPUCodeGen", "@llvm-project//llvm:AMDGPUAsmParser", + ]) + if_sycl_is_configured([ + "@spirv_llvm_translator//:spirv_llvm_translator", ]), ) @@ -106,3 +112,16 @@ xla_cc_test( "@local_tsl//tsl/platform:test", ], ) + +xla_cc_test( + name = "gpu_backend_lib_test", + size = "small", + srcs = ["gpu_backend_lib_test.cc"], + deps = [ + ":llvm_gpu_backend", + "//xla/stream_executor:device_description", + "//xla/tests:xla_internal_test_main", + "@llvm-project//llvm:Core", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index c64c549a730b77..fb0d9c65f5bf5f 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -105,6 +106,11 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_asm_compiler.h" #endif +#if TENSORFLOW_USE_SYCL +#include "LLVMSPIRVLib.h" +#include "LLVMSPIRVOpts.h" +#endif // TENSORFLOW_USE_SYCL + namespace xla { namespace gpu { namespace { @@ -117,41 +123,6 @@ const int kAMDGPUInlineThreshold = 0x100000; // Default inline threshold value to use in llvm. const int kDefaultInlineThreshold = 1100; -// Gets the GPU name as it's known to LLVM for a given compute -// capability. If we see an unrecognized compute capability, we -// return the highest one that is known and below the selected device. -static std::string GetSmName(se::CudaComputeCapability compute_capability) { - int compute_capability_version = - compute_capability.major * 10 + compute_capability.minor; - int sm_version = 30; - // If the current compute capability isn't known, fallback to the - // most recent version before it. - int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62, - 61, 60, 53, 52, 50, 37, 35, 32, 30}; - for (int v : supported_versions) { - if (v <= compute_capability_version) { - sm_version = v; - break; - } - } - - // If the current CC isn't supported by LLVM and it is newer then - // the max supported LLVM version, do not warn about it. The end - // user can't do anything about this. E.g., PTX compiled for SM75 will - // run on SM80 too. - if (sm_version != compute_capability_version && - compute_capability_version < supported_versions[0]) { - LOG(WARNING) << "Unknown compute capability " - << compute_capability.ToString() - << ". Defaulting to telling LLVM that we're compiling for sm_" - << sm_version; - } - // If the target is sm_90, hard code it to sm_90a so that all instructions - // can be used. We don't need the portability that sm_90 gives. - std::string_view extension = sm_version == 90 ? "a" : ""; - return absl::StrCat("sm_", sm_version, extension); -} - // NOLINTBEGIN: clang-diagnostic-unused-function // Convenience function for producing a name of a temporary compilation product // from the input filename. @@ -378,7 +349,7 @@ std::unique_ptr NVPTXGetTargetMachine( #else std::string feature_str; #endif // GOOGLE_CUDA - return GetTargetMachine(target_triple, GetSmName(compute_capability), + return GetTargetMachine(target_triple, nvptx::GetSmName(compute_capability), debug_options, feature_str); } @@ -452,7 +423,9 @@ absl::Status LinkAndOptimizeModule( llvm::CGSCCAnalysisManager cgam; llvm::ModuleAnalysisManager mam; - fam.registerPass([&] { return target_machine->getTargetIRAnalysis(); }); + if (target_machine) { + fam.registerPass([&] { return target_machine->getTargetIRAnalysis(); }); + } llvm::PipelineTuningOptions pto; pto.SLPVectorization = true; @@ -569,6 +542,40 @@ void NVPTXBackendInit(const DebugOptions& debug_options) { namespace nvptx { +std::string GetSmName(se::CudaComputeCapability compute_capability) { + int compute_capability_version = + compute_capability.major * 10 + compute_capability.minor; + int sm_version = 30; + // If the current compute capability isn't known, fallback to the + // most recent version before it. + int supported_versions[] = {90, 89, 87, 86, 80, 75, 72, 70, 62, + 61, 60, 53, 52, 50, 37, 35, 32, 30}; + for (int v : supported_versions) { + if (v <= compute_capability_version) { + sm_version = v; + break; + } + } + + // If the current CC isn't supported by LLVM and it is newer then + // the max supported LLVM version, do not warn about it. The end + // user can't do anything about this. E.g., PTX compiled for SM75 will + // run on SM80 too. + if (sm_version != compute_capability_version && + compute_capability_version < supported_versions[0]) { + LOG(WARNING) << "Unknown compute capability " + << compute_capability.ToString() + << ". Defaulting to telling LLVM that we're compiling for sm_" + << sm_version; + } + // On Hopper, default to sm_90a so that all instructions can be used. But + // only sm_90 is forward compatible, so don't use sm_90a with newer hardware: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility + std::string_view extension = + (compute_capability.major == 9 && sm_version == 90) ? "a" : ""; + return absl::StrCat("sm_", sm_version, extension); +} + std::string CantFindCudaMessage(absl::string_view msg, absl::string_view xla_gpu_cuda_data_dir) { return absl::StrCat( @@ -857,15 +864,10 @@ absl::StatusOr> EmitModuleToHsaco( // Locate lld. std::string lld_path; if (std::getenv("LLVM_PATH")) { - lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin"); - } - else if (std::getenv("ROCM_PATH")) { - lld_path = tsl::io::JoinPath(std::getenv("ROCM_PATH"), "llvm/bin"); - } - else { - lld_path = tsl::io::JoinPath("/opt/rocm", "llvm/bin"); + lld_path = tsl::io::JoinPath(std::getenv("LLVM_PATH"), "bin"); + } else { + lld_path = tsl::io::JoinPath(tsl::RocmRoot(), "llvm/bin"); } - auto lld_program = llvm::sys::findProgramByName("ld.lld", {lld_path}); if (!lld_program) { return xla::Internal("unable to find ld.lld in PATH: %s", @@ -1151,5 +1153,95 @@ absl::StatusOr> CompileToHsaco( } // namespace amdgpu +namespace { + +std::unique_ptr SPIRGetTargetMachine( + llvm::Triple target_triple, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options) { + return nullptr; +} + +absl::Status SPIRTargetModuleLinker( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options, + const std::string& device_bitcode_dir_path) { + return absl::OkStatus(); +} + +absl::StatusOr EmitModuleToSpir( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options) { +#if TENSORFLOW_USE_SYCL + SPIRV::TranslatorOpts::ExtensionsStatusMap ExtensionsStatus; + SPIRV::TranslatorOpts opts(SPIRV::VersionNumber::MaximumVersion, + ExtensionsStatus); + opts.enableAllExtensions(); // enable all SPIR-V extension first + + std::ostringstream oss; + std::string err; + bool success = llvm::writeSpirv(module, opts, oss, err); + if (!success) { + return xla::Internal("Fails to convert LLVM as SPIR-V: %s", err); + } + return oss.str(); +#else + return absl::UnimplementedError("Not implemented for SYCL"); +#endif +} + +void SPIRBackendInit(const DebugOptions& debug_options) { + FeedLLVMWithFlags({ + "-slp-vectorize-hor=false", + "-slp-min-reg-size=64", + "-slp-max-reg-size=64", + }); + + llvm_ir::InitializeLLVMCommandLineOptions( + debug_options.xla_backend_extra_options()); + + llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry(); + InitializePasses(registry); +} + +} // namespace + +namespace spir { + +absl::StatusOr> CompileToSpir( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options) { + std::string libdevice_dir_path; + static absl::once_flag backend_init_flag; + absl::call_once(backend_init_flag, SPIRBackendInit, debug_options); + + std::string spir; + { + XLA_SCOPED_LOGGING_TIMER("Compile module " + module->getName().str()); + + // If the module has no functions or globals, there's nothing to compile. + if (module->empty() && module->global_empty()) { + VLOG(2) << "Module '" << module->getName().str() + << "' is empty. Skipping compilation."; + return std::vector(); + } + + llvm::Triple default_target_triple("spir64-unknown-unknown"); + std::unique_ptr target_machine = + SPIRGetTargetMachine(default_target_triple, gpu_version, debug_options); + + TF_RETURN_IF_ERROR(LinkAndOptimizeModule( + module, gpu_version, debug_options, libdevice_dir_path, + SPIRTargetModuleLinker, default_target_triple, target_machine.get(), + kDefaultInlineThreshold)); + + // Lower optimized LLVM module to SPIR. + TF_ASSIGN_OR_RETURN(spir, + EmitModuleToSpir(module, gpu_version, debug_options)); + } + return std::vector(spir.begin(), spir.end()); +} + +} // namespace spir + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h index 3ab5d6d84db1b3..1814291beae184 100644 --- a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -37,6 +37,11 @@ namespace gpu { namespace nvptx { +// Gets the GPU name as it's known to LLVM for a given compute +// capability. If we see an unrecognized compute capability, we +// return the highest one that is known and below the selected device. +std::string GetSmName(se::CudaComputeCapability compute_capability); + std::string CantFindCudaMessage(absl::string_view msg, absl::string_view xla_gpu_cuda_data_dir); @@ -73,6 +78,13 @@ absl::StatusOr> CompileToHsaco( const std::string& module_config_cache_key); } // namespace amdgpu +namespace spir { +// Compiles the argument module and returns it. +absl::StatusOr> CompileToSpir( + llvm::Module* module, se::GpuComputeCapability gpu_version, + const DebugOptions& debug_options); +} // namespace spir + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc new file mode 100644 index 00000000000000..9e65f34a296cb6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib_test.cc @@ -0,0 +1,38 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" + +#include "xla/stream_executor/device_description.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { +namespace se = ::stream_executor; + +TEST(UtilsTest, TestGetSmName) { + se::CudaComputeCapability cc_hopper(9, 0); + ASSERT_EQ(nvptx::GetSmName(cc_hopper), "sm_90a"); + // Do not default to sm90_a after Hopper, because it is not forward + // compatible. + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#ptx-compatibility + se::CudaComputeCapability cc_next(10, 0); + ASSERT_EQ(nvptx::GetSmName(cc_next), "sm_90"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/matmul_utils.cc b/third_party/xla/xla/service/gpu/matmul_utils.cc index fe4982e9a223b9..49270de65ecd3f 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.cc +++ b/third_party/xla/xla/service/gpu/matmul_utils.cc @@ -456,7 +456,11 @@ bool IsTf32Allowed(PrecisionConfig::Algorithm algorithm, const HloInstruction* gemm) { TF_ASSIGN_OR_RETURN(GpuBackendConfig gpu_config, gemm->backend_config()); - const GemmBackendConfig& config = gpu_config.gemm_backend_config(); + return For(gemm, gpu_config.gemm_backend_config()); +} + +/*static*/ absl::StatusOr GemmConfig::For( + const HloInstruction* gemm, const GemmBackendConfig& config) { std::optional algorithm; if (config.algorithm_case() != GemmBackendConfig::ALGORITHM_NOT_SET) { algorithm = config.selected_algorithm(); diff --git a/third_party/xla/xla/service/gpu/matmul_utils.h b/third_party/xla/xla/service/gpu/matmul_utils.h index 22d7f178133835..5f128e418af58c 100644 --- a/third_party/xla/xla/service/gpu/matmul_utils.h +++ b/third_party/xla/xla/service/gpu/matmul_utils.h @@ -108,6 +108,11 @@ struct GemmConfig : public se::gpu::GemmConfig { static absl::StatusOr For(const HloInstruction* gemm); + // Gets the GemmConfig of the `gemm` instruction with overridden + // GemmBackendConfig. + static absl::StatusOr For(const HloInstruction* gemm, + const GemmBackendConfig& config); + static absl::StatusOr For( const Shape& lhs_shape, absl::Span lhs_batch_dims, absl::Span lhs_contracting_dims, const Shape& rhs_shape, diff --git a/third_party/xla/xla/service/gpu/model/BUILD b/third_party/xla/xla/service/gpu/model/BUILD index 3c9a4d23d023df..48e3b1ccafb363 100644 --- a/third_party/xla/xla/service/gpu/model/BUILD +++ b/third_party/xla/xla/service/gpu/model/BUILD @@ -350,6 +350,7 @@ cc_library( ":indexing_analysis", ":symbolic_tile_analysis", ":tiled_hlo_computation", + ":triton_emitter_constraints", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", @@ -475,7 +476,6 @@ cc_library( "//xla/service:gather_simplifier", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:matmul_utils", - "//xla/service/gpu/fusions:tiling_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", @@ -549,7 +549,6 @@ xla_cc_test( ":indexing_test_utils", "//xla/hlo/ir:hlo", "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu/fusions:tiling_util", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings:string_view", @@ -734,6 +733,7 @@ xla_cc_test( ":indexing_test_utils", ":symbolic_tile", ":symbolic_tile_analysis", + ":symbolic_tiled_hlo_instruction", ":tiled_hlo_computation", ":tiled_hlo_instruction", "//xla:util", @@ -742,6 +742,7 @@ xla_cc_test( "//xla/service/gpu:hlo_traversal", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -749,11 +750,44 @@ xla_cc_test( "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:IR", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", ], ) +cc_library( + name = "triton_emitter_constraints", + srcs = ["triton_emitter_constraints.cc"], + hdrs = ["triton_emitter_constraints.h"], + deps = [ + ":affine_map_evaluator", + ":symbolic_tile_analysis", + ":symbolic_tiled_hlo_instruction", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +xla_cc_test( + name = "triton_emitter_constraints_test", + srcs = ["triton_emitter_constraints_test.cc"], + deps = [ + ":symbolic_tile_analysis", + ":triton_emitter_constraints", + "//xla/hlo/ir:hlo", + "//xla/service:instruction_fusion", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + cc_library( name = "coalescing_analysis", srcs = ["coalescing_analysis.cc"], diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index 11ebb82a94476a..aefe84294472a2 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -52,7 +52,7 @@ class CoalescingTest : public HloTestBase { std::vector IsReadCoalescedPerOperand(const HloInstruction* root) { auto fusion_adaptor = HloFusionAdaptor::ForInstruction(root); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis}); auto fusion = dynamic_cast(emitter.get()); EXPECT_NE(fusion, nullptr); @@ -71,7 +71,7 @@ class CoalescingTest : public HloTestBase { bool IsReadCoalescedHeuristic(absl::string_view hlo_string) { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); HloInstruction* root = module->entry_computation()->root_instruction(); - auto analysis = AnalyzeFusion(*root, device_info_); + auto analysis = HloFusionAnalysis::Create(*root, device_info_); return xla::gpu::IsReadCoalescedHeuristic(analysis.GetEmitterFusionKind(), root->operand(0), root); } @@ -167,13 +167,13 @@ TEST_F(CoalescingTest, Transpose) { HloModule module fusion { - %input = f32[100, 64, 32] parameter(0) - ROOT transpose = f32[32, 100, 64] transpose(%input), dimensions={2, 0, 1} + %input = f32[1, 6400, 32] parameter(0) + ROOT transpose = f32[1, 32, 6400] transpose(%input), dimensions={0, 2, 1} } ENTRY entry { - %input = f32[100, 64, 32] parameter(0) - ROOT %fusion = f32[32, 100, 64] fusion(%input), kind=kLoop, calls=fusion + %input = f32[1, 6400, 32] parameter(0) + ROOT %fusion = f32[1, 32, 6400] fusion(%input), kind=kLoop, calls=fusion })"; // thread_x to linearized input mapping for thread_x in [0, 31]: // Operand 1: (thread_x)[s0] -> (thread_x + s0 * 128) for s0 in [0, 7] @@ -185,15 +185,15 @@ TEST_F(CoalescingTest, TransposeOfBroadcastHeuristic) { HloModule module fusion { - input = f32[32, 100, 64] parameter(0) - ROOT slice = f32[32, 100, 1] slice(input), slice={[0:32:1], [0:100:1], [0:1:1]} + input = f32[1, 32, 6400] parameter(0) + ROOT slice = f32[1, 32, 100] slice(input), slice={[0:1:1], [0:32:1], [0:6400:64]} } ENTRY entry { p0 = f32[32] parameter(0) - broadcast = f32[100, 64, 32] broadcast(p0), dimensions={2} - transpose = f32[32, 100, 64] transpose(broadcast), dimensions={2, 0, 1} - ROOT %fusion = f32[32, 100, 1] fusion(transpose), kind=kLoop, calls=fusion + broadcast = f32[1, 6400, 32] broadcast(p0), dimensions={2} + transpose = f32[1, 32, 6400] transpose(broadcast), dimensions={0, 2, 1} + ROOT %fusion = f32[1, 32, 100] fusion(transpose), kind=kLoop, calls=fusion })"; EXPECT_TRUE(IsReadCoalescedHeuristic(ir)); } diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc index ba033fb74f9d89..5e80fe7bca8b7e 100644 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.cc @@ -33,7 +33,8 @@ const HloFusionAnalysis& HloFusionAnalysisCache::Get( } } - HloFusionAnalysis analysis = AnalyzeFusion(instruction, device_info_); + HloFusionAnalysis analysis = + HloFusionAnalysis::Create(instruction, device_info_); absl::MutexLock lock(&mutex_); // If some other thread created an entry for this key concurrently, return @@ -59,7 +60,7 @@ const HloFusionAnalysis& HloFusionAnalysisCache::Get( } HloFusionAnalysis analysis = - AnalyzeProducerConsumerFusion(producer, consumer, device_info_); + HloFusionAnalysis::Create(producer, consumer, device_info_); absl::MutexLock lock(&mutex_); // If some other thread created an entry for this key concurrently, return @@ -78,7 +79,6 @@ const HloFusionAnalysis& HloFusionAnalysisCache::Get( } void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); analyses_.erase(instruction.unique_id()); if (auto consumers = @@ -96,8 +96,6 @@ void HloFusionAnalysisCache::Invalidate(const HloInstruction& instruction) { } void HloFusionAnalysisCache::Clear() { - absl::MutexLock lock(&mutex_); - analyses_.clear(); producer_consumer_analyses_.clear(); consumers_for_producers_.clear(); diff --git a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h index 4cf6053e03fed6..9eacee0a933aad 100644 --- a/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h +++ b/third_party/xla/xla/service/gpu/model/fusion_analysis_cache.h @@ -28,9 +28,9 @@ limitations under the License. namespace xla::gpu { -// Caches HloFusionAnalyses. Thread-compatible, if no threads concurrently `Get` -// and `Invalidate` the same key. Analyses are cached based on unique_ids, no -// checking or tracking of changes is done. +// Caches HloFusionAnalyses. `Get` can be called concurrently, but `Invalidate` +// and `Clear` shouldn't. Analyses are cached based on unique_ids, no checking +// or tracking of changes is done. class HloFusionAnalysisCache { public: explicit HloFusionAnalysisCache( diff --git a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc index f04771f789691a..aad3343260c945 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.cc @@ -111,8 +111,11 @@ float GetMaxSysBwFromGpu(const se::CudaComputeCapability cc, return bandwidths_table[1]; case se::CudaComputeCapability::HOPPER: return bandwidths_table[2]; + case se::CudaComputeCapability::BLACKWELL: + return bandwidths_table[3]; + default: + return bandwidths_table[4]; } - return -1; } } // namespace @@ -133,7 +136,7 @@ float GpuPerformanceWithCollectiveModel::GetNvlinkBw( } /*static*/ bool GpuPerformanceWithCollectiveModel::InitNvml() { -#if GOOGLE_CUDA +#if GOOGLE_CUDA && (defined(PLATFORM_POSIX) || defined(PLATFORM_GOOGLE)) void* libhandle = dlopen("libnvidia-ml.so.1", RTLD_NOW); CHECK(libhandle != nullptr) << "Failed to open libnvidia-ml.so.1"; @@ -189,7 +192,8 @@ GpuPerformanceWithCollectiveModel::CheckIfNvlinkSupportsP2P() { nvmlReturn_t nvlink_cap_result = xla_nvmlDeviceGetNvLinkCapability( nvml_device, /*nvlink link number*/ 0, NVML_NVLINK_CAP_P2P_SUPPORTED, &supported_p2p); - CHECK(nvlink_cap_result == NVML_SUCCESS); + CHECK(nvlink_cap_result == NVML_SUCCESS || + nvlink_cap_result == NVML_ERROR_NOT_SUPPORTED); CHECK(ShutdownNvml()) << "NVML shutdown failed."; return supported_p2p; #else diff --git a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.h index c11a78c684e80d..e1bcff0b5023dd 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_collective_performance_model.h @@ -26,7 +26,9 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #if GOOGLE_CUDA +#if defined(PLATFORM_POSIX) || defined(PLATFORM_GOOGLE) #include +#endif #include "third_party/gpus/cuda/nvml/include/nvml.h" // Below is a list of function pointers to be used @@ -57,16 +59,16 @@ class GpuPerformanceWithCollectiveModel : public GpuPerformanceModelBase { // Table for max system bandwidths GB/s for using NCCL's low latency // algorithm. This is used for intra-node estimate. - static constexpr std::array kLowLatencyMaxBandwidths = { - 39.0 /* Volta*/, 87.7 /* Ampere*/, 87.7 /* Hopper*/ + static constexpr std::array kLowLatencyMaxBandwidths = { + 39.0 /* Volta */, 87.7 /* Ampere */, 141.0 /* Hopper */, + 141.0 /* Blackwell */, 141.0 /* next-gen */, }; // Max bandwidth in GB/s for ring low latency 128 algorithm per channel on a // single-node - static constexpr std::array kPerChannelMaxRingLL128Bandwidths = { - 20.0 /* Volta */, - 20.0 /* Ampere */, - 36.7 /* Hopper */, + static constexpr std::array kPerChannelMaxRingLL128Bandwidths = { + 20.0 /* Volta */, 20.0 /* Ampere */, 36.7 /* Hopper */, + 36.7 /* Blackwell */, 36.7 /* next-gen */, }; // Nvlink unidirectional bandwidth for different compute cap. Note this is per diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 49b914eb19cc17..3b0cae33ea7acc 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/gpu/model/indexing_map.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -250,7 +251,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForInstruction( /*exec_time=*/absl::ZeroDuration()}; } - auto fusion_analysis = AnalyzeFusion(*producer, *device_info_); + auto fusion_analysis = HloFusionAnalysis::Create(*producer, *device_info_); bool is_coalesced = IsReadCoalescedHeuristic( fusion_analysis.GetEmitterFusionKind(), producer); @@ -261,7 +262,7 @@ EstimateRunTimeData GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForProducerConsumer( const HloInstruction* producer, const HloInstruction* consumer) { auto fusion_analysis = - AnalyzeProducerConsumerFusion(*producer, *consumer, *device_info_); + HloFusionAnalysis::Create(*producer, *consumer, *device_info_); bool is_coalesced = IsReadCoalescedHeuristic( fusion_analysis.GetEmitterFusionKind(), producer, consumer); @@ -369,7 +370,9 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTiledFusion( absl::Span tile_sizes) { // TODO(b/332714755): Add caching for SymbolicTileAnalysis. SymbolicTileAnalysisOrError analysis_or_error = - SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); + SymbolicTileAnalysis::AnalyzeFusion( + fusion_adaptor, mlir_context_, + TritonEmitterConstraints::GetBuilder()); if (const auto* fusion_decision = std::get_if(&analysis_or_error)) { return absl::FailedPreconditionError(absl::StrCat( @@ -429,7 +432,9 @@ absl::StatusOr GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( const HloFusionAdaptor& fusion_adaptor) { SymbolicTileAnalysisOrError analysis_or_error = - SymbolicTileAnalysis::AnalyzeFusion(fusion_adaptor, mlir_context_); + SymbolicTileAnalysis::AnalyzeFusion( + fusion_adaptor, mlir_context_, + TritonEmitterConstraints::GetBuilder()); if (const auto* fusion_decision = std::get_if(&analysis_or_error)) { diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc index c2057e0da5d59c..6bb4071bcc1fab 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model.cc @@ -56,7 +56,7 @@ GpuPerformanceModel::EstimateRunTimeForInstruction( // TODO(jreiffers): Remove this once all callers use a cache. std::optional local_analysis; if (!config.fusion_analysis_cache) { - local_analysis = AnalyzeFusion(*instr, device_info); + local_analysis = HloFusionAnalysis::Create(*instr, device_info); } const auto& fusion_analysis = config.fusion_analysis_cache ? config.fusion_analysis_cache->Get(*instr) @@ -144,7 +144,7 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( // TODO(jreiffers): Remove this once all callers use a cache. std::optional local_analysis; if (!config.fusion_analysis_cache) { - local_analysis = AnalyzeFusion(*fused_consumer, device_info); + local_analysis = HloFusionAnalysis::Create(*fused_consumer, device_info); } const auto& analysis_unfused = config.fusion_analysis_cache @@ -193,7 +193,7 @@ absl::Duration GpuPerformanceModel::EstimateUnfusedExecTime( std::optional local_analysis_fused; if (!config.fusion_analysis_cache) { local_analysis_fused = - AnalyzeProducerConsumerFusion(*producer, *consumer, device_info); + HloFusionAnalysis::Create(*producer, *consumer, device_info); } const auto& fusion_analysis = config.fusion_analysis_cache @@ -296,8 +296,8 @@ absl::Duration GpuPerformanceModel::EstimateFusedExecTime( std::optional local_analysis_fused; if (!config.fusion_analysis_cache) { - local_analysis_fused = AnalyzeProducerConsumerFusion( - *producer, *fused_consumer, device_info); + local_analysis_fused = + HloFusionAnalysis::Create(*producer, *fused_consumer, device_info); } const auto& analysis_fused = config.fusion_analysis_cache @@ -345,8 +345,9 @@ GpuPerformanceModel::EstimateRunTimesForPriorityFusion( const GpuPerformanceModelOptions& config, absl::Span fused_consumers, bool multi_output) { - EstimateRunTimeData producer_runtime = EstimateRunTimeForInstructionCached( - producer, device_info, cost_analysis, config); + auto cache_result = config.gpu_performance_model_cache->Get(*producer); + CHECK(cache_result.has_value()); + EstimateRunTimeData producer_runtime = *cache_result; absl::Duration time_unfused = kKernelLaunchOverhead * (fused_consumers.size() + 1) + @@ -357,8 +358,10 @@ GpuPerformanceModel::EstimateRunTimesForPriorityFusion( for (auto fused_consumer : fused_consumers) { VLOG(8) << "Fused consumer: " << fused_consumer->name(); - EstimateRunTimeData consumer_runtime = EstimateRunTimeForInstructionCached( - fused_consumer, device_info, cost_analysis, config); + auto cache_result = + config.gpu_performance_model_cache->Get(*fused_consumer); + CHECK(cache_result.has_value()); + EstimateRunTimeData consumer_runtime = *cache_result; time_unfused += consumer_runtime.exec_time; diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc index 56e34c30b963dc..08ae2e5d173b64 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base.cc @@ -88,8 +88,6 @@ float AdjustBandwidth(const se::DeviceDescription& gpu_device_info, std::optional GpuPerformanceModelCache::Get( const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); - auto it = instruction_runtime_data_.find(&instruction); if (it != instruction_runtime_data_.end()) { return it->second; @@ -113,8 +111,6 @@ std::optional GpuPerformanceModelCache::Get( void GpuPerformanceModelCache::Set(const HloInstruction& instruction, const EstimateRunTimeData& runtime_data) { - absl::MutexLock lock(&mutex_); - instruction_runtime_data_[&instruction] = runtime_data; } @@ -126,8 +122,6 @@ void GpuPerformanceModelCache::Set(const HloInstruction& producer, } void GpuPerformanceModelCache::Invalidate(const HloInstruction& instruction) { - absl::MutexLock lock(&mutex_); - // Remove runtime data for the instruction. instruction_runtime_data_.erase(&instruction); diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc index a3eb17042cb1e7..0ece419e1b009f 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_base_test.cc @@ -211,13 +211,13 @@ ENTRY entry_computation { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - auto fusion_analysis = AnalyzeFusion( + auto fusion_analysis = HloFusionAnalysis::Create( *module->entry_computation()->root_instruction(), device_info_); auto launch_dimensions = GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis); - EXPECT_EQ(launch_dimensions.num_blocks(), 16); - EXPECT_EQ(launch_dimensions.num_threads_per_block(), 1024); + EXPECT_EQ(launch_dimensions.num_blocks(), 128); + EXPECT_EQ(launch_dimensions.num_threads_per_block(), 128); } TEST_F(GpuPerformanceModelBaseTest, @@ -247,7 +247,7 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - auto fusion_analysis = AnalyzeFusion( + auto fusion_analysis = HloFusionAnalysis::Create( *module->entry_computation()->root_instruction(), device_info_); auto launch_dimensions = GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis); @@ -276,7 +276,7 @@ ENTRY e { TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_string)); - auto fusion_analysis = AnalyzeFusion( + auto fusion_analysis = HloFusionAnalysis::Create( *module->entry_computation()->root_instruction(), device_info_); auto launch_dimensions = GpuPerformanceModelBase::EstimateFusionLaunchDimensions(fusion_analysis); diff --git a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc index 4c0c35e1a9e285..2335906b13f544 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_performance_model_test.cc @@ -68,9 +68,19 @@ class GpuPerformanceModelTest : public HloTestBase { GpuPerformanceModel::RunTimes EstimateRunTimesForPriorityFusion( const HloInstruction* producer, std::vector fused_consumers = {}) { + auto config = GpuPerformanceModelOptions::PriorityFusion( + &fusion_analysis_cache_, &gpu_performance_model_cache_); + + auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction( + producer, device_info_, &analysis_, config); + gpu_performance_model_cache_.Set(*producer, runtime_data); + for (auto consumer : fused_consumers) { + auto runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction( + consumer, device_info_, &analysis_, config); + gpu_performance_model_cache_.Set(*consumer, runtime_data); + } return GpuPerformanceModel::EstimateRunTimesForPriorityFusion( - producer, device_info_, &analysis_, - GpuPerformanceModelOptions::PriorityFusion(), fused_consumers); + producer, device_info_, &analysis_, config, fused_consumers); } mlir::MLIRContext mlir_context_; @@ -82,6 +92,7 @@ class GpuPerformanceModelTest : public HloTestBase { se::DeviceDescription device_info_{TestGpuDeviceInfo::RTXA6000DeviceInfo()}; HloFusionAnalysisCache fusion_analysis_cache_{device_info_}; GpuHloCostAnalysis analysis_{options_, device_info_}; + GpuPerformanceModelCache gpu_performance_model_cache_; GpuPerformanceModelWithIndexingAnalysis indexing_cost_model_{ &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), @@ -150,7 +161,7 @@ ENTRY e { auto reification_cost = root->backend_config() ->fusion_backend_config() .reification_cost(); - EXPECT_NEAR(reification_cost.end_to_end_cycles(), 257.7, 0.1); + EXPECT_NEAR(reification_cost.end_to_end_cycles(), 38.4, 0.1); EXPECT_NEAR(reification_cost.exec_time_us(), 0, 1); auto indexing_t = indexing_cost_model_.EstimateRunTimes(root); @@ -674,16 +685,16 @@ add { } fused_computation.0 { - p0 = f32[4,28672,32] parameter(0) - tanh = f32[4,28672,32] tanh(p0) + p0 = f32[4,256,32] parameter(0) + tanh = f32[4,256,32] tanh(p0) c1 = f32[] constant(72) - broadcast = f32[4,28672,32] broadcast(c1), dimensions={} - ROOT mul = f32[4,28672,32] multiply(tanh, broadcast) + broadcast = f32[4,256, 32] broadcast(c1), dimensions={} + ROOT mul = f32[4,256,32] multiply(tanh, broadcast) } ENTRY fusion { - p0 = f32[4,28672,32] parameter(0) - fusion = f32[4,28672,32] fusion(p0), kind=kLoop, calls=fused_computation.0 + p0 = f32[4,256,32] parameter(0) + fusion = f32[4,256,32] fusion(p0), kind=kLoop, calls=fused_computation.0 c0 = f32[] constant(0) ROOT reduce = f32[4,32] reduce(fusion, c0), to_apply=add, dimensions={1} })"; diff --git a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc index 6a8ed6538e8edb..550e9e7a31ffdf 100644 --- a/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc +++ b/third_party/xla/xla/service/gpu/model/hlo_op_profiler_test.cc @@ -39,7 +39,7 @@ TEST_F(HloOpProfilerTest, BasicMeasurementsAreCorrect) { EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kDivide, F64) .value() .clock_cycles(), - 400); + 300); // c128 sqrt is slow. EXPECT_GT(profiler.MeasureClockCyclesPerOp(HloOpcode::kSqrt, C128) .value() diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc index 89124182909aca..e8842b35cdd70c 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.cc @@ -48,7 +48,6 @@ limitations under the License. #include "xla/layout.h" #include "xla/permutation_util.h" #include "xla/service/gather_simplifier.h" -#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/model/affine_map_printer.h" @@ -346,6 +345,7 @@ HloInstructionIndexing ComputeOutputToInputDynamicUpdateSliceOpIndexing( // operand: (d0, ... d_{N-1}) -> (d0, ... d_{N-1}) std::vector identity; + identity.reserve(rank); for (int64_t dim = 0; dim < rank; ++dim) { identity.push_back(getAffineDimExpr(dim, mlir_context)); } @@ -1141,13 +1141,6 @@ std::vector ToTransposeDimensions(const Layout& l) { return out; } -AffineMap GetTilingAffineMap(llvm::ArrayRef exprs, - int64_t num_symbols) { - return AffineMap::get( - /*dimCount=*/6, /*symbolCount=*/num_symbols, exprs, - exprs[0].getContext()); -} - } // namespace IndexingMap CreateIdentityMap(const Shape& shape, MLIRContext* mlir_context) { @@ -1218,83 +1211,6 @@ IndexingMap GetIndexingMapFromLogicalToPhysicalLayout( shape.dimensions(), {}); } -AffineMap GetBlockOffsetsForTiling( - absl::Span num_blocks, - absl::Span tile_sizes_per_block, int64_t rank, - MLIRContext* mlir_context) { - auto offsets = - DelinearizeInBoundsIndex(getAffineDimExpr(3, mlir_context), num_blocks); - for (auto&& [offset, tile_size] : llvm::zip(offsets, tile_sizes_per_block)) { - offset = offset * tile_size; - } - return GetTilingAffineMap(offsets, rank); -} - -AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, - MLIRContext* mlir_context) { - return GetBlockOffsetsForTiling(tiling.GetBlockCounts(), - tiling.GetBlockTileSize(), - tiling.GetShape().size(), mlir_context); -} - -AffineMap GetThreadOffsetsForTiling( - absl::Span num_threads, - absl::Span tile_sizes_per_thread, int64_t rank, - MLIRContext* mlir_context) { - auto offsets = - DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), num_threads); - for (int dim = 0; dim < rank; ++dim) { - if (tile_sizes_per_thread[dim] > 1) { - offsets[dim] = offsets[dim] + - getAffineSymbolExpr(dim, mlir_context) * num_threads[dim]; - } - } - return GetTilingAffineMap(offsets, rank); -} - -AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, - MLIRContext* mlir_context) { - return GetThreadOffsetsForTiling(tiling.GetThreadsPerBlock(), - tiling.GetThreadTileSize(), - tiling.GetShape().size(), mlir_context); -} - -IndexingMap GetIndexingMapForTiling(const Tiling& tiling, - MLIRContext* mlir_context) { - return GetIndexingMapForTiling( - GetBlockOffsetsForTiling(tiling, mlir_context), - GetThreadOffsetsForTiling(tiling, mlir_context), - tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(), - tiling.GetThreadTileSize(), tiling.GetShape()); -} - -IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, - AffineMap thread_offsets, - int64_t threads_per_block, - int64_t num_blocks, - absl::Span thread_tile_sizes, - absl::Span tiled_shape) { - auto* mlir_context = block_offsets.getContext(); - llvm::SmallVector offsets; - offsets.reserve(block_offsets.getNumResults()); - for (auto [block, thread] : - llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) { - offsets.push_back(block + thread); - } - std::vector dimension_ranges{ - {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {}, - }; - auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(), - block_offsets.getNumSymbols(), offsets, - mlir_context); - IndexingMap map{affine_map, dimension_ranges, - RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}}; - for (int i = 0; i < tiled_shape.size(); ++i) { - map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1}); - } - return map; -} - bool HloInstructionIndexing::Simplify() { bool any_simplified = false; for (auto& operand_indexing : indexing_maps) { diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis.h b/third_party/xla/xla/service/gpu/model/indexing_analysis.h index 201b8f66e636af..e475b5a6e0f95f 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis.h @@ -18,7 +18,6 @@ limitations under the License. #define XLA_SERVICE_GPU_MODEL_INDEXING_ANALYSIS_H_ #include -#include #include #include #include @@ -31,7 +30,6 @@ limitations under the License. #include "mlir/IR/AffineMap.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/affine_map_printer.h" #include "xla/service/gpu/model/indexing_map.h" @@ -145,35 +143,6 @@ IndexingMap GetIndexingMapFromPhysicalLayoutToLogical( IndexingMap GetIndexingMapFromLogicalToPhysicalLayout( const Shape& shape, mlir::MLIRContext* mlir_context); -// Creates an indexing map from thread and block IDs to elements of the tiled -// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2 -// are thread indices (currently only 0 is used), dimensions 3 to 5 are block -// indices (currently only 3 is used). -mlir::AffineMap GetBlockOffsetsForTiling( - absl::Span num_blocks, - absl::Span tile_sizes_per_block, int64_t rank, - mlir::MLIRContext* mlir_context); -mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, - mlir::MLIRContext* mlir_context); -mlir::AffineMap GetThreadOffsetsForTiling( - absl::Span num_threads, - absl::Span tile_sizes_per_thread, int64_t rank, - mlir::MLIRContext* mlir_context); -mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, - mlir::MLIRContext* mlir_context); - -// Convenience functions for the two functions above -// (`GetBlockOffsestsForTiling` + `GetThreadOffsetsForTiling`). Also sets up -// the ranges of dimensions and symbols. -IndexingMap GetIndexingMapForTiling(const Tiling& tiling, - mlir::MLIRContext* mlir_context); -IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets, - mlir::AffineMap thread_offsets, - int64_t threads_per_block, - int64_t num_blocks, - absl::Span thread_tile_sizes, - absl::Span tiled_shape); - // Returns the shape of the output of the instruction. const Shape& GetOutputShape(const HloInstruction* instr, int64_t output_id); diff --git a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc index 30fd8056697498..d30c963023b542 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_analysis_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "mlir/IR/MLIRContext.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/gpu/fusions/tiling_util.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/tests/hlo_test_base.h" @@ -2564,32 +2563,6 @@ TEST_F(IndexingAnalysisTest, FusionWithUnsupportedOp) { ElementsAre(UndefinedMap()), ElementsAre(UndefinedMap()))); } -TEST_F(IndexingAnalysisTest, TilingIndexing) { - Tiling tiling{/*shape=*/{1022, 256, 16}, - /*tile_sizes=*/{8, 1, 4}, - /*num_threads=*/{1, 4, 4}}; - auto indexing_map = GetIndexingMapForTiling(tiling, &mlir_context_); - indexing_map.Simplify(); - EXPECT_THAT(indexing_map.ToString(), MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - (d3 floordiv 64) * 8 + s0, - (d3 mod 64) * 4 + d0 floordiv 4, - d0 mod 4 + s2 * 4 - ) - domain: - d0 in [0, 15] - d1 in [0, 0] - d2 in [0, 0] - d3 in [0, 8191] - d4 in [0, 0] - d5 in [0, 0] - s0 in [0, 7] - s1 in [0, 0] - s2 in [0, 3] - (d3 floordiv 64) * 8 + s0 in [0, 1021] - )")); -} - TEST_F(IndexingAnalysisTest, EpilogueIndexing) { auto module = ParseAndReturnVerifiedModule(R"( HloModule m diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.cc b/third_party/xla/xla/service/gpu/model/indexing_map.cc index f0d34e8871e8f2..da21c3464b16b5 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.cc +++ b/third_party/xla/xla/service/gpu/model/indexing_map.cc @@ -1113,6 +1113,20 @@ SmallVector IndexingMap::Evaluate( return eval.getConstantResults(); } +bool IndexingMap::IsSymbolConstrained(int64_t symbol_id) const { + for (const auto& [expr, _] : constraints_) { + bool result = false; + expr.walk([&](mlir::AffineExpr leaf) { + auto sym = mlir::dyn_cast(leaf); + if (sym && sym.getPosition() == symbol_id) { + result = true; + } + }); + if (result) return true; + } + return false; +} + RangeEvaluator::RangeEvaluator(const IndexingMap& indexing_map, MLIRContext* mlir_context, bool use_constraints) : mlir_context_(mlir_context), diff --git a/third_party/xla/xla/service/gpu/model/indexing_map.h b/third_party/xla/xla/service/gpu/model/indexing_map.h index 478e0ecd371bc5..2e6cc1374f505b 100644 --- a/third_party/xla/xla/service/gpu/model/indexing_map.h +++ b/third_party/xla/xla/service/gpu/model/indexing_map.h @@ -369,6 +369,9 @@ class IndexingMap { llvm::ArrayRef dim_const_exprs, llvm::ArrayRef symbol_const_exprs) const; + // Returns true if there is a constraint on the given symbol. + bool IsSymbolConstrained(int64_t symbol_id) const; + // Returns true if the domain is empty. If it returns false, that does not // mean that the domain is not effectively empty. // For example, if there are two constraints 0 <= d0 mod 7 <= 0 and diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc index dc5c2f28c6f3f1..1ea107ce9903a8 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile.cc @@ -602,23 +602,8 @@ std::optional ExtractSizeAndStride( AffineExpr strided_indexing, absl::Span dimension_intervals, absl::Span symbol_intervals) { MLIRContext* ctx = strided_indexing.getContext(); - // Deal with the symbol case (capturing a whole untiled dimension). - // TODO(b/330906085): concatenating across a reduction dimension needs to be - // handled by this code. - if (auto symbol = llvm::dyn_cast(strided_indexing)) { - const Interval& symbol_interval = symbol_intervals[symbol.getPosition()]; - if (symbol_interval.lower != 0) { - return std::nullopt; - } - - return SizeAndStrideExpression( - /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx), - /*stride=*/getAffineConstantExpr(1, ctx)); - } - AffineMapPrinter printer; - // TODO(b/328427138): support multivariate size expressions. switch (strided_indexing.getKind()) { case AffineExprKind::DimId: return SizeAndStrideExpression(/*size=*/strided_indexing, @@ -626,23 +611,15 @@ std::optional ExtractSizeAndStride( case mlir::AffineExprKind::Mul: { const auto mul = llvm::cast(strided_indexing); AffineExpr lhs = mul.getLHS(); - // The stride may not be fully collapsed if it is negative; in that case, - // we need to extract the negative multiplier first. - if (const auto rhs = llvm::dyn_cast(mul.getRHS()); - rhs && rhs.getValue() == -1) { - std::optional maybe_size_and_stride = - ExtractSizeAndStride(lhs, dimension_intervals, symbol_intervals); - if (!maybe_size_and_stride.has_value()) { - return std::nullopt; - } - - return SizeAndStrideExpression( - /*size=*/maybe_size_and_stride->size, - /*stride=*/maybe_size_and_stride->stride * rhs); + std::optional maybe_size_and_stride = + ExtractSizeAndStride(lhs, dimension_intervals, symbol_intervals); + if (!maybe_size_and_stride.has_value()) { + return std::nullopt; } - CHECK(lhs.getKind() == AffineExprKind::DimId); - return SizeAndStrideExpression(/*size=*/lhs, - /*stride=*/mul.getRHS()); + + return SizeAndStrideExpression( + /*size=*/maybe_size_and_stride->size, + /*stride=*/maybe_size_and_stride->stride * mul.getRHS()); } case mlir::AffineExprKind::Mod: { auto mod = llvm::cast(strided_indexing); @@ -656,15 +633,18 @@ std::optional ExtractSizeAndStride( case mlir::AffineExprKind::Constant: return SizeAndStrideExpression(/*size=*/getAffineConstantExpr(1, ctx), /*stride=*/getAffineConstantExpr(0, ctx)); - case mlir::AffineExprKind::SymbolId: - VLOG(1) << "Encountered complex size expression involving symbol " - << printer.ToString(strided_indexing); - // It's currently not checked separately, but RTVars shouldn't appear in - // the strided indexing expressions. - return std::nullopt; + case mlir::AffineExprKind::SymbolId: { + auto symbol = llvm::cast(strided_indexing); + const Interval& symbol_interval = symbol_intervals[symbol.getPosition()]; + if (symbol_interval.lower != 0) { + return std::nullopt; + } + + return SizeAndStrideExpression( + /*size=*/getAffineConstantExpr(symbol_interval.upper + 1, ctx), + /*stride=*/getAffineConstantExpr(1, ctx)); + } case mlir::AffineExprKind::Add: { - // TODO(b/328427138): this should only be necessary in the multivariate - // case, and will be implemented later. std::optional> maybe_sizes_and_strides = ExtractSizesAndStridesFromMultivariateSummation( diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc index 6aed0c540ba910..7025fe46fc4c47 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.cc @@ -300,13 +300,16 @@ void SortTiledHloInstructionsInPostOrder( } // namespace /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeComputation( - const HloComputation& computation, MLIRContext* ctx) { + const HloComputation& computation, MLIRContext* ctx, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder) { auto fusion = HloFusionAdaptor::ForComputation(&computation); - return SymbolicTileAnalysis::AnalyzeFusion(*fusion, ctx); + return SymbolicTileAnalysis::AnalyzeFusion( + *fusion, ctx, emitter_specific_constraints_builder); } /*static*/ SymbolicTileAnalysisOrError SymbolicTileAnalysis::AnalyzeFusion( - const HloFusionAdaptor& fusion, MLIRContext* ctx) { + const HloFusionAdaptor& fusion, MLIRContext* ctx, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder) { OrderedUniquePtrValueHashSet tiled_hlo_instructions_set; @@ -383,12 +386,20 @@ void SortTiledHloInstructionsInPostOrder( return std::get(constraints_or); } + // Create emitter-specific constraints if a builder was provided. + std::unique_ptr emitter_specific_constraints; + if (emitter_specific_constraints_builder != nullptr) { + emitter_specific_constraints = + emitter_specific_constraints_builder(tiled_hlo_instructions); + } + // Order instructions in def-before-use order. SortTiledHloInstructionsInPostOrder(tiled_hlo_instructions, root_tiled_hlo); return SymbolicTileAnalysis( std::move(tiled_hlo_instructions), - std::get(std::move(constraints_or)), ctx); + std::get(std::move(constraints_or)), + std::move(emitter_specific_constraints), ctx); } absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( @@ -399,11 +410,6 @@ absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( "This should never happen."); } - // Handle the unconstrained case. - if (constraints_.IsAlwaysSatisfied()) { - return true; - } - if (tile_parameters.size() != num_tile_parameters()) { return absl::InvalidArgumentError(absl::StrFormat( "Failed to check if tile parameters satisfy constraints. Number of " @@ -412,6 +418,21 @@ absl::StatusOr SymbolicTileAnalysis::ParametersSatisfyConstraints( tile_parameters.size(), num_tile_parameters())); } + if (emitter_specific_constraints_ != nullptr) { + TF_ASSIGN_OR_RETURN( + bool constraints_are_satisfied, + emitter_specific_constraints_->ParametersSatisfyConstraints( + tile_parameters)); + if (!constraints_are_satisfied) { + return false; + } + } + + // Handle the unconstrained case. + if (constraints_.IsAlwaysSatisfied()) { + return true; + } + // TODO(bchetioui): replace with convenience methods in // `ConstraintExpression`. bool constraints_are_satisfied = false; @@ -443,9 +464,9 @@ SymbolicTileAnalysis::ComputeTiledHloInstructions( TF_ASSIGN_OR_RETURN(bool constraints_are_satisfied, ParametersSatisfyConstraints(tile_parameters)); if (!constraints_are_satisfied) { - return absl::InvalidArgumentError(absl::StrCat( - "Tile parameters ", absl::StrJoin(tile_parameters, ", "), - " do not satisfy the SymbolicTileAnalysis's constraints.")); + return absl::InvalidArgumentError( + absl::StrCat("Tile parameters ", absl::StrJoin(tile_parameters, ", "), + " do not satisfy constraints.")); } } diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h index df56d2325dd641..692e88db11b998 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis.h @@ -44,6 +44,21 @@ class SymbolicTileAnalysis; using SymbolicTileAnalysisOrError = std::variant; +// An interface to implement additional emitter-specific constraints. This +// interface can be used as an extension point to further constrain the set of +// given limitations of a particular codegen solution. +class EmitterSpecificConstraints { + public: + virtual ~EmitterSpecificConstraints() = default; + + virtual absl::StatusOr ParametersSatisfyConstraints( + absl::Span tile_parameters) const = 0; +}; + +using EmitterSpecificConstraintsBuilder = + std::function( + const std::vector>&)>; + // Constructs and holds symbolic tiles for all the instructions within a // computation. We may hold several different symbolic tiles for the same // instruction if the instruction is indexed in several different ways in order @@ -59,10 +74,17 @@ class SymbolicTileAnalysis { // Tries to construct a symbolic tile analysis from a computation. Returns // a diagnostic if the construction fails for any reason. + // + // If `emitter_specific_constraints_builder` is provided, it will be used to + // construct emitter-specific constraints for the analysis. static SymbolicTileAnalysisOrError AnalyzeComputation( - const HloComputation& computation, mlir::MLIRContext* ctx); + const HloComputation& computation, mlir::MLIRContext* ctx, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder = + nullptr); static SymbolicTileAnalysisOrError AnalyzeFusion( - const HloFusionAdaptor& fusion, mlir::MLIRContext* ctx); + const HloFusionAdaptor& fusion, mlir::MLIRContext* ctx, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder = + nullptr); // Returns a graph of HLO instructions tiled with the given tile parameters. // The provided tile parameters must satisfy the analysis's constraints. @@ -101,7 +123,8 @@ class SymbolicTileAnalysis { const ConstraintExpression& GetConstraints() const { return constraints_; } // Returns true if a list of tile parameters satisfies the symbolic tile - // analysis's constraints. + // analysis's constraints. If provided, also checks the emitter-specific + // constraints. // // Returns false if the constraints are not satisfied but can be evaluated // correctly. Returns an error if the constraints cannot be evaluated @@ -127,13 +150,16 @@ class SymbolicTileAnalysis { absl::StatusOr> GetGoodTilings() const; private: - SymbolicTileAnalysis(std::vector> - symbolic_tiled_hlo_instructions, - ConstraintExpression constraints, - mlir::MLIRContext* context) + SymbolicTileAnalysis( + std::vector> + symbolic_tiled_hlo_instructions, + ConstraintExpression constraints, + std::unique_ptr emitter_specific_constraints, + mlir::MLIRContext* context) : symbolic_tiled_hlo_instructions_( std::move(symbolic_tiled_hlo_instructions)), constraints_(std::move(constraints)), + emitter_specific_constraints_(std::move(emitter_specific_constraints)), context_(context) {} // The tiled HLO instructions in def-before-use order. @@ -143,6 +169,10 @@ class SymbolicTileAnalysis { // See the documentation of GetConstraints(). ConstraintExpression constraints_; + // Additional emitter-specific constraints on tile parameters. May be null if + // no builder was provided when constructing the analysis. + std::unique_ptr emitter_specific_constraints_; + mlir::MLIRContext* context_; }; diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc index a9680f0f5fdb07..c7fecd6525e2f2 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_analysis_test.cc @@ -35,13 +35,14 @@ limitations under the License. #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/model/indexing_test_utils.h" #include "xla/service/gpu/model/symbolic_tile.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" #include "xla/service/gpu/model/tiled_hlo_instruction.h" #include "xla/service/instruction_fusion.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -78,15 +79,44 @@ Matcher MatchTiledHloInstruction( tile_offsets_indexing); } +// Fake emitter-specific constraints for testing. Requires that the tile size +// along the first dimension is exactly half the size of the axis. +class FakeEmitterSpecificConstraints : public EmitterSpecificConstraints { + public: + absl::StatusOr ParametersSatisfyConstraints( + absl::Span tile_parameters) const override { + return tile_parameters[0] == dim0_tile_size_; + } + + static EmitterSpecificConstraintsBuilder GetBuilder() { + return [](const std::vector>& + instructions) { + const SymbolicTiledHloInstruction* root = instructions[0].get(); + int64_t dim0_size = root->hlo()->shape().dimensions(0); + return std::make_unique( + /*dim0_tile_size=*/dim0_size / 2); + }; + } + + explicit FakeEmitterSpecificConstraints(int64_t dim0_tile_size) + : dim0_tile_size_(dim0_tile_size) {} + + private: + int64_t dim0_tile_size_; +}; + class SymbolicTileAnalysisTest : public HloTestBase { public: - std::optional TryAnalyzeModule(HloModule* module) { + std::optional TryAnalyzeModule( + HloModule* module, + EmitterSpecificConstraintsBuilder emitter_specific_constraints_builder = + nullptr) { SymbolicTileAnalysisOrError analysis_or_error = SymbolicTileAnalysis::AnalyzeComputation( *module->entry_computation() ->root_instruction() ->fused_instructions_computation(), - &mlir_context_); + &mlir_context_, emitter_specific_constraints_builder); if (std::holds_alternative(analysis_or_error)) { return std::get(std::move(analysis_or_error)); @@ -507,6 +537,35 @@ ENTRY main { impossible_tile_parameters, /*constraints_are_known_satisfied=*/true)); } +TEST_F(SymbolicTileAnalysisTest, EmitterSpecificConstraintsAreUsedCorrectly) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( + fusion { + p0 = f32[16,32] parameter(0) + ROOT add = f32[16,32] add(p0, p0) + } + + ENTRY main { + p0 = f32[16,32] parameter(0) + ROOT fusion = f32[16,32] fusion(p0), kind=kLoop, calls=fusion + })")); + + std::optional analysis = TryAnalyzeModule( + module.get(), FakeEmitterSpecificConstraints::GetBuilder()); + + ASSERT_TRUE(analysis.has_value()); + + // FakeEmitterSpecificConstraints require that the tile size along the first + // dimension is exactly half the size of the axis. Tile sizes {5, 32} do not + // satisfy emitter-specific constraints. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({5, 32}), + IsOkAndHolds(false)); + + // However, tile sizes {8, 32} do satisfy emitter-specific constraints. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({8, 32}), + IsOkAndHolds(true)); +} + TEST_F(SymbolicTileAnalysisTest, ConstraintsAreAggregatedCorrectly) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(R"( diff --git a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc index 1db55375c0cc84..92c851d94ea929 100644 --- a/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc +++ b/third_party/xla/xla/service/gpu/model/symbolic_tile_test.cc @@ -549,6 +549,61 @@ TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReshapeOfReverse) { )"))); } +TEST_F(SymbolicTileTest, CanPropagateTileThroughSplitReductionOfSplittedAxis) { + // A split reshape of a reverse creates a sum of strided symbols. + auto input_indexing = GetOutputToInputIndexing(ParseAndGetRoot(R"( + HloModule m + add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) + } + + computation { + p0 = f32[18] parameter(0) + bitcast = f32[9,2] bitcast(p0) + c0 = f32[] constant(0) + reduce_0 = f32[9] reduce(bitcast, c0), dimensions={1}, to_apply=add + ROOT reduce_1 = f32[] reduce(reduce_0, c0), dimensions={0}, to_apply=add + } + + ENTRY e { + p0 = f32[18] parameter(0) + ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=computation + } + )")); + + EXPECT_THAT( + SymbolicTile::FromIndexingMap(*input_indexing.indexing_maps[0].begin()), + Optional(MatchSymbolicTileString(R"( + Symbolic tile with + offset_map: () -> (0) + size_map: () -> (18) + stride_map: () -> (1) + )"))); +} + +TEST_F(SymbolicTileTest, CanPropagateTileThroughSummationOfSymbols) { + // Such an indexing map is representative of a sequence of HLOs containing a + // bitcast followed by two sequential reductions of the split axis, i.e. + // something like + // p0 = f32[18] parameter(0) + // bitcast = f32[9,2] bitcast(p0) + // reduce_0 = f32[9] reduce(bitcast), dimensions={1} + // reduce_1 = f32[] reduce(reduce_0), dimensions={0} + IndexingMap indexing_map = IndexingMap::FromTensorSizes( + ParseAffineMap("()[s0, s1] -> (s1 * 2 + s0)", &mlir_context_), {}, + {2, 9}); + + EXPECT_THAT(SymbolicTile::FromIndexingMap(indexing_map), + Optional(MatchSymbolicTileString(R"( + Symbolic tile with + offset_map: () -> (0) + size_map: () -> (18) + stride_map: () -> (1) + )"))); +} + TEST_F(SymbolicTileTest, FailsGracefullyAtPropagatingTileThroughSliceOfSplitReshape) { // TODO(b/349487906): constraints should allow us to unblock this use case. diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc new file mode 100644 index 00000000000000..6ccc3db120c697 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/triton_emitter_constraints.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/IR/AffineMap.h" +#include "xla/service/gpu/model/affine_map_evaluator.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/symbolic_tiled_hlo_instruction.h" + +namespace xla { +namespace gpu { + +namespace { + +// Triton enforces that all tensors in the program have less than 1048576 +// elements, otherwise it will fail to compile. +constexpr int64_t kMaxTensorNumElements = 1048576; + +} // namespace + +/*static*/ EmitterSpecificConstraintsBuilder +TritonEmitterConstraints::GetBuilder() { + return [](const std::vector>& + instructions) { + llvm::DenseSet unique_tile_size_maps; + for (const auto& tiled_hlo_instruction : instructions) { + unique_tile_size_maps.insert( + tiled_hlo_instruction->symbolic_tile().size_map()); + } + + return std::make_unique( + llvm::SmallVector(unique_tile_size_maps.begin(), + unique_tile_size_maps.end())); + }; +} + +absl::StatusOr TritonEmitterConstraints::ParametersSatisfyConstraints( + absl::Span tile_parameters) const { + // Verify that the tile sizes are not too big. + for (const auto& tile_size_map : tile_size_maps_) { + int64_t tile_size = 1; + for (auto expr : tile_size_map.getResults()) { + tile_size *= llvm::PowerOf2Ceil( + EvaluateAffineExpr(expr, /*dim_values=*/tile_parameters)); + } + + if (tile_size > kMaxTensorNumElements) { + return false; + } + } + return true; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h new file mode 100644 index 00000000000000..d5281bd12f0e98 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints.h @@ -0,0 +1,54 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/AffineMap.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" + +#ifndef XLA_SERVICE_GPU_MODEL_TRITON_EMITTER_CONSTRAINTS_H_ +#define XLA_SERVICE_GPU_MODEL_TRITON_EMITTER_CONSTRAINTS_H_ + +namespace xla { +namespace gpu { + +// Triton-specific constraints on tile sizes. +class TritonEmitterConstraints : public EmitterSpecificConstraints { + public: + static EmitterSpecificConstraintsBuilder GetBuilder(); + + explicit TritonEmitterConstraints( + llvm::SmallVector tile_size_maps) + : tile_size_maps_(std::move(tile_size_maps)) {} + + absl::StatusOr ParametersSatisfyConstraints( + absl::Span tile_parameters) const override; + + private: + // A collection of unique size maps from all the SymbolicTiledHloInstructions. + // + // Different TiledHloInstructions often have the same size map, so we keep a + // collection of unique maps to improve compilation time. + llvm::SmallVector tile_size_maps_; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_MODEL_TRITON_EMITTER_CONSTRAINTS_H_ diff --git a/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc new file mode 100644 index 00000000000000..827c2fa488a307 --- /dev/null +++ b/third_party/xla/xla/service/gpu/model/triton_emitter_constraints_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/model/triton_emitter_constraints.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/log/log.h" +#include "mlir/IR/MLIRContext.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/instruction_fusion.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace gpu { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +class TritonEmitterConstraintsTest : public HloTestBase { + public: + std::optional TryAnalyzeModule(HloModule* module) { + SymbolicTileAnalysisOrError analysis_or_error = + SymbolicTileAnalysis::AnalyzeComputation( + *module->entry_computation() + ->root_instruction() + ->fused_instructions_computation(), + &mlir_context_, TritonEmitterConstraints::GetBuilder()); + + if (std::holds_alternative(analysis_or_error)) { + return std::get(std::move(analysis_or_error)); + } + VLOG(1) << "Cannot analyze module: " + << std::get(analysis_or_error).Explain(); + return std::nullopt; + } + + mlir::MLIRContext mlir_context_; +}; + +TEST_F(TritonEmitterConstraintsTest, TritonSpecificConstraintsAreEnforced) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(R"( +HloModule m + +max_computation { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT maximum = f32[] maximum(param_0, param_1) +} + +fused_computation { + param_0 = f32[8192,50304] parameter(0) + constant = f32[] constant(-inf) + reduce = f32[8192] reduce(param_0, constant), dimensions={1}, to_apply=max_computation + broadcast = f32[8192,50304] broadcast(reduce), dimensions={0} + ROOT subtract = f32[8192,50304] subtract(param_0, broadcast) +} + +ENTRY entry_computation { + param_0 = f32[8192,50304] parameter(0) + ROOT fusion = f32[8192,50304] fusion(param_0), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config":{"kind":"__triton"}} +} +)")); + + std::optional analysis = TryAnalyzeModule(module.get()); + ASSERT_TRUE(analysis.has_value()); + + // The biggest tile in the program has 8 * 65536 = 524288 elements. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({8, 128}), + IsOkAndHolds(true)); + + // The biggest tile in the program is 18 * 50304 = 905472 elements which is + // smaller than the limit of 1048576, but since Triton requires all tile sizes + // to be a power of 2, the actual tile will be 32 * 65536 = 2097152 elements. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({18, 50304}), + IsOkAndHolds(false)); + + // Because of reduce, we need to load full rows from param_0 and the load tile + // will be 1024 * 65536 = 67108864 elements, that is larger than the limit of + // 1048576. + EXPECT_THAT(analysis->ParametersSatisfyConstraints({1024, 1}), + IsOkAndHolds(false)); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index 40eff87caf8c06..2044115f3bbc44 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/nvptx_compiler.h" +#include #include #include #include @@ -52,34 +53,35 @@ limitations under the License. #include "xla/service/dump.h" #include "xla/service/float_normalization.h" #include "xla/service/float_support.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" +#include "xla/service/gpu/autotuning/conv_algorithm_picker.h" +#include "xla/service/gpu/autotuning/gemm_algorithm_picker.h" +#include "xla/service/gpu/autotuning/gemm_fusion_autotuner.h" #include "xla/service/gpu/buffer_sharing.h" -#include "xla/service/gpu/conv_algorithm_picker.h" -#include "xla/service/gpu/cublas_pad_for_gemms.h" #include "xla/service/gpu/cublas_padding_requirements.h" -#include "xla/service/gpu/cudnn_fused_conv_rewriter.h" -#include "xla/service/gpu/cudnn_fused_mha_rewriter.h" -#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h" -#include "xla/service/gpu/cudnn_fusion_compiler.h" -#include "xla/service/gpu/cudnn_norm_rewriter.h" -#include "xla/service/gpu/cudnn_pad_for_convolutions.h" -#include "xla/service/gpu/cudnn_simplify_padding.h" -#include "xla/service/gpu/cudnn_vectorize_convolutions.h" -#include "xla/service/gpu/cudnn_workspace_rewriter.h" -#include "xla/service/gpu/cusolver_rewriter.h" -#include "xla/service/gpu/dot_sparsity_rewriter.h" -#include "xla/service/gpu/gemm_algorithm_picker.h" -#include "xla/service/gpu/gemm_fusion_autotuner.h" -#include "xla/service/gpu/gpu_algebraic_simplifier.h" #include "xla/service/gpu/gpu_asm_opts_util.h" #include "xla/service/gpu/gpu_compiler.h" -#include "xla/service/gpu/gpu_conv_padding_legalization.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" -#include "xla/service/gpu/gpu_sort_rewriter.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "xla/service/gpu/metrics.h" #include "xla/service/gpu/target_constants.h" -#include "xla/service/gpu/triangular_solve_rewriter.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" +#include "xla/service/gpu/transforms/conv_padding_legalization.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" +#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h" +#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h" +#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h" +#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h" +#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h" +#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" +#include "xla/service/gpu/transforms/cudnn_norm_rewriter.h" +#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h" +#include "xla/service/gpu/transforms/cudnn_simplify_padding.h" +#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h" +#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h" +#include "xla/service/gpu/transforms/gpusolver_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" +#include "xla/service/gpu/transforms/triangular_solve_rewriter.h" #include "xla/service/hlo_constant_folding.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dataflow_analysis.h" @@ -94,6 +96,8 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_asm_compiler.h" #include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" +#include "xla/stream_executor/cuda/nvjitlink.h" +#include "xla/stream_executor/cuda/nvjitlink_support.h" #include "xla/stream_executor/cuda/ptx_compilation_method.h" #include "xla/stream_executor/cuda/ptx_compiler.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" @@ -187,7 +191,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( auto cuda_compute_capability = std::get(gpu_version); // Convert convolutions into CustomCalls to cudnn, then canonicalize them - // (GpuConvPaddingLegalization). Also expand cuSolver calls. + // (ConvPaddingLegalization). Also expand cuSolver calls. HloPassPipeline pipeline("conv_canonicalization"); pipeline.AddInvariantCheckerDebug( /*layout_sensitive=*/false, @@ -202,10 +206,10 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(&matmul_bf16_support); pipeline.AddPass(); - pipeline.AddPass(cuda_compute_capability); + pipeline.AddPass(cuda_compute_capability); pipeline.AddPass(cuda_compute_capability, dnn_version, GetToolkitVersion()); - pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(cuda_compute_capability); pipeline.AddPass(cuda_compute_capability, dnn_version); @@ -230,7 +234,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( // e.g. clean up unnecessary nop `convert`s. pipeline.AddPass(); - // tf2xla bridge, DepthwiseConvolutionConverter, GpuConvRewriter, and + // tf2xla bridge, DepthwiseConvolutionConverter, ConvRewriter, and // CudnnSimplifyPadding introduce reshapes and transposes. Run ReshapeMover // to a fixed point. Include algsimp because ReshapeMover relies on it. [&, &pipeline = pipeline.AddPass>( @@ -252,7 +256,7 @@ absl::Status NVPTXCompiler::OptimizeHloConvolutionCanonicalization( pipeline.AddPass(algsimp_options, gpu_version); }(); - // GpuConvRewriter, GpuConvPaddingLegalization and + // ConvRewriter, ConvPaddingLegalization and // CudnnConvPadForTensorCores may add instructions which can be simplified // by constant folding. pipeline.AddPass(); @@ -338,9 +342,6 @@ absl::Status NVPTXCompiler::OptimizeHloPostLayoutAssignment( // Transform TriangularSolve ops into custom-calls, so we can add temp // memory. post_pipeline.AddPass(); - if (stream_exec) { - post_pipeline.AddPass(*stream_exec); - } TF_RETURN_IF_ERROR(post_pipeline.Run(hlo_module).status()); return absl::OkStatus(); @@ -386,20 +387,22 @@ absl::Status NVPTXCompiler::AddGemmFusionAutotuningPasses( absl::Status NVPTXCompiler::AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) { if (debug_options.xla_gpu_enable_cub_radix_sort()) { - pipeline->AddPass(); + pipeline->AddPass(); } return absl::OkStatus(); } -absl::Status NVPTXCompiler::RunCudnnFusionCompilerPass( +absl::Status NVPTXCompiler::RunCudnnCompilerPasses( HloModule* module, se::StreamExecutor* stream_exec, BinaryMap* dnn_compiled_graphs) { tsl::profiler::ScopedAnnotation annotation([&] { return absl::StrFormat("XlaCompileCudnnFusion:#module=%s,program_id=%d#", module->name(), module->unique_id()); }); - CuDnnFusionCompiler cudnn_compiler(*stream_exec, *dnn_compiled_graphs); - return cudnn_compiler.Run(module).status(); + CuDnnFusionCompiler fusion_compiler(*stream_exec, *dnn_compiled_graphs); + TF_RETURN_IF_ERROR(fusion_compiler.Run(module).status()); + CuDnnCustomCallCompiler call_compiler(*stream_exec, *dnn_compiled_graphs); + return call_compiler.Run(module).status(); } namespace { @@ -531,6 +534,8 @@ HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() const { return &CanShareBufferHint; } +constexpr const uint8_t kPtxPrefix[] = {'P', 'T', 'X', ':', ' '}; + absl::StatusOr NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config, llvm::Module* llvm_module, @@ -568,6 +573,22 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config, RecordLlvmPassesAndLlvmToPtxDuration(end_usecs - start_usecs); } + TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method, + ChooseLinkingMethod(module_config.debug_options())); + + if (linking_method == se::PtxLinkingMethod::kNvJitLink && relocatable) { + VLOG(2) << "Deferring the PTX to CUBIN compilation of the relocatable " + "module to the linking step."; + std::vector binary; + if (!ptx.empty()) { + binary.reserve(sizeof(kPtxPrefix) + ptx.size() + 1); + binary.insert(binary.end(), kPtxPrefix, kPtxPrefix + sizeof(kPtxPrefix)); + binary.insert(binary.end(), ptx.begin(), ptx.end()); + binary.emplace_back('\0'); + } + return BackendCompileResult{std::move(ptx), std::move(binary)}; + } + absl::StatusOr> maybe_cubin = CompileGpuAsmOrGetCachedResult( ptx, std::get(gpu_version), module_config, @@ -588,6 +609,9 @@ std::vector GetSupportedCompilationMethods() { if (se::IsLibNvPtxCompilerSupported()) { methods.emplace_back(PtxCompilationMethod::kNvPtxCompiler); } + if (se::IsLibNvJitLinkSupported()) { + methods.emplace_back(PtxCompilationMethod::kNvJitLink); + } methods.emplace_back(PtxCompilationMethod::kPtxas); return methods; } @@ -608,11 +632,26 @@ absl::StatusOr ChooseCompilationMethod( } }; + if (!debug_options.xla_gpu_enable_libnvjitlink()) { + VLOG(3) << "Discarding NvJitLink since it is disabled."; + remove_compilation_method(PtxCompilationMethod::kNvJitLink); + } if (!debug_options.xla_gpu_enable_libnvptxcompiler()) { VLOG(3) << "Discarding NvPtxCompiler since it is disabled."; remove_compilation_method(PtxCompilationMethod::kNvPtxCompiler); } + VLOG(2) << "Supported and enabled compilation methods: " + << absl::StrJoin(compilation_methods, ", "); + + if (relocatable && absl::c_linear_search(compilation_methods, + PtxCompilationMethod::kNvJitLink)) { + // NvJitLink can't produce relocatable CUBINs. + VLOG(3) << "Discarding NvJitLink since it can't produce the requested " + "relocatable CUBIN."; + remove_compilation_method(PtxCompilationMethod::kNvJitLink); + } + VLOG(2) << "Considered compilation methods: " << absl::StrJoin(compilation_methods, ", "); @@ -655,6 +694,16 @@ static absl::StatusOr> AssembleOptionsAndCompile( absl::StatusOr> maybe_cubin = [&] { switch (compilation_method) { + case PtxCompilationMethod::kNvJitLink: + return se::CompileAndLinkUsingLibNvJitLink( + cc.major, cc.minor, + {se::NvJitLinkInput{ + se::NvJitLinkInput::Type::kPtx, + absl::Span{ + reinterpret_cast(ptx.c_str()), + ptx.size() + 1 /* We need the null terminator. */}}}, + ptxas_config, cancel_if_reg_spill); + case PtxCompilationMethod::kNvPtxCompiler: return se::CompileGpuAsmUsingLibNvPtxCompiler( cc.major, cc.minor, ptx.c_str(), ptxas_config, cancel_if_reg_spill); @@ -815,6 +864,12 @@ absl::StatusOr NVPTXCompiler::ChooseLinkingMethod( std::string& preferred_cuda_dir = ptxas_config.preferred_cuda_dir; using LinkingMethod = se::PtxLinkingMethod; + + if (stream_executor::IsLibNvJitLinkSupported() && + debug_options.xla_gpu_enable_libnvjitlink()) { + return se::PtxLinkingMethod::kNvJitLink; + } + TF_ASSIGN_OR_RETURN(auto asm_compiler_version, GetAsmCompilerVersion(debug_options, preferred_cuda_dir)); @@ -859,28 +914,60 @@ absl::StatusOr NVPTXCompiler::CanUseLinkModules( } absl::StatusOr> NVPTXCompiler::LinkModules( - se::GpuComputeCapability cc, se::StreamExecutor* stream_exec, - std::vector> modules, + se::GpuComputeCapability compute_capability, + se::StreamExecutor* stream_exec, std::vector> modules, const DebugOptions& debug_options) { if (modules.empty()) return std::vector{}; + auto cc = + std::get(compute_capability); + TF_ASSIGN_OR_RETURN(se::PtxLinkingMethod linking_method, ChooseLinkingMethod(debug_options)); VLOG(1) << "Linking " << modules.size() << " modules with linking method: " << linking_method; - std::vector images; - images.reserve(modules.size()); + if (linking_method == se::PtxLinkingMethod::kNvJitLink) { + const auto module_contains_ptx = + [](const std::vector& module) -> bool { + return module.size() >= sizeof(kPtxPrefix) && + std::equal(std::begin(kPtxPrefix), std::end(kPtxPrefix), + std::begin(module)); + }; + + std::vector nvjitlink_inputs; + nvjitlink_inputs.reserve(modules.size()); + for (std::vector& module : modules) { + if (module_contains_ptx(module)) { + nvjitlink_inputs.push_back( + {se::NvJitLinkInput::Type::kPtx, + absl::Span(module).subspan(sizeof(kPtxPrefix))}); + } else { + nvjitlink_inputs.push_back({se::NvJitLinkInput::Type::kCubin, module}); + } + } + + se::GpuAsmOpts ptxas_config = PtxOptsFromDebugOptions(debug_options); + return stream_executor::CompileAndLinkUsingLibNvJitLink( + cc.major, cc.minor, nvjitlink_inputs, ptxas_config, + /*cancel_if_reg_spill=*/false); + } + + std::vector cubin_images; + cubin_images.reserve(modules.size()); for (std::vector& module : modules) { - images.push_back({"", std::move(module)}); + { + std::string profile = absl::StrCat("sm_", cc.major, cc.minor); + cubin_images.push_back({std::move(profile), std::move(module)}); + } } + auto context = se::gpu::ExtractGpuExecutor(stream_exec)->gpu_context(); if (linking_method == se::PtxLinkingMethod::kNvLink) { - return LinkUsingNvlink(std::get(cc), - debug_options.xla_gpu_cuda_data_dir(), context, - images); + return LinkUsingNvlink(cc, debug_options.xla_gpu_cuda_data_dir(), context, + cubin_images); } - return LinkGpuAsm(std::get(cc), context, images); + return LinkGpuAsm(cc, context, cubin_images); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h index 25fa268226107d..6d84deb4398176 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.h +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h @@ -32,7 +32,7 @@ limitations under the License. #include "xla/autotune_results.pb.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/distributed/key_value_store_interface.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/gpu_compiler.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_dataflow_analysis.h" @@ -84,9 +84,9 @@ class NVPTXCompiler : public GpuCompiler { absl::Status AddCustomKernelReplacementPasses( HloPassPipeline* pipeline, const DebugOptions& debug_options) override; - absl::Status RunCudnnFusionCompilerPass( - HloModule* module, se::StreamExecutor* stream_exec, - BinaryMap* dnn_compiled_graphs) override; + absl::Status RunCudnnCompilerPasses(HloModule* module, + se::StreamExecutor* stream_exec, + BinaryMap* dnn_compiled_graphs) override; HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override; @@ -100,7 +100,8 @@ class NVPTXCompiler : public GpuCompiler { private: absl::StatusOr> LinkModules( - se::GpuComputeCapability cc, se::StreamExecutor* stream_exec, + se::GpuComputeCapability gpu_compute_capability, + se::StreamExecutor* stream_exec, std::vector> modules, const DebugOptions& debug_options) override; diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc index 642a0cc9eca438..f43066672b2bef 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler_test.cc @@ -34,9 +34,9 @@ limitations under the License. #include "xla/service/logical_buffer.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc b/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc index 21d2590e763176..16fdf7c8fbe6ff 100644 --- a/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc +++ b/third_party/xla/xla/service/gpu/parallel_loop_emitter.cc @@ -166,16 +166,6 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, // "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x" // // %nctaid.x is currently specified as 2147483647. - if (launch_dimensions_.thread_counts_per_block().y > 1) { - // When blockDim.y > 1, then we are in the small row case. Each - // blockDim.x do exatly to one row and blockDim.y map to some - // consecutive row. This prevents too small block size that isn't - // efficient. - CHECK(launch_config_.row_vectorized); - CHECK_EQ(shape_.dimensions().back(), - launch_dimensions_.thread_counts_per_block().x * - launch_config_.unroll_factor); - } CHECK_EQ(launch_dimensions_.thread_counts_per_block().z, 1); CHECK_EQ(launch_dimensions_.block_counts().y, 1); CHECK_EQ(launch_dimensions_.block_counts().z, 1); @@ -189,14 +179,6 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, llvm::Value* linear_index_base = linear_base_and_thread_idx.linear_base; - llvm::Value* row_index = - launch_config_.row_vectorized - ? b_->CreateMul(linear_base_and_thread_idx.thread_idx, - llvm::ConstantInt::get(index_type, - launch_config_.unroll_factor), - "row_index", /*HasNUW=*/true, /*HasNSW=*/true) - : nullptr; - std::vector multidim(shape_.rank(), nullptr); for (int i = 0; i < launch_config_.unroll_factor; ++i) { // The add operation is needed even if the offset is 0, since when the @@ -207,17 +189,6 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, b_->CreateAdd(linear_index_base, llvm::ConstantInt::get(index_type, i), absl::StrCat("linear_index", i), /*HasNUW=*/true, /*HasNSW=*/true); - if (launch_config_.row_vectorized) { - // This lets us avoid emitting the division for the last dimension of the - // index. The check for i > 0 is here for historical reasons, it might not - // do anything. - multidim.back() = - i == 0 ? row_index - : b_->CreateAdd( - row_index, llvm::ConstantInt::get(index_type, i), - absl::StrCat("row_index_plus", i), /*HasNUW=*/true, - /*HasNSW=*/true); - } array_indices.emplace_back(linear_index, multidim, shape_, b_); } diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc index 928f4bf4cc573a..7ced52cbd17ca0 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc +++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc @@ -22,10 +22,10 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/copy_insertion.h" #include "xla/service/cpu_gpu_shape_verifier.h" -#include "xla/service/gpu/alias_passthrough_params.h" -#include "xla/service/gpu/copy_fusion.h" -#include "xla/service/gpu/gpu_sanitize_constant_names.h" -#include "xla/service/gpu/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/alias_passthrough_params.h" +#include "xla/service/gpu/transforms/copy_fusion.h" +#include "xla/service/gpu/transforms/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/sanitize_constant_names.h" #include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_pass_pipeline.h" @@ -78,14 +78,14 @@ HloPassPipeline PrepareHloModuleForIrEmittingPipeline( } // We are using a sub-pipeline here, so that the verifier only runs after both - // GpuHorizontalLoopFusion and HloDCE. + // HorizontalLoopFusion and HloDCE. auto& sub_pipeline = pipeline.AddPass("horizontal-loop-fusion-for-copy"); // To fuse the copy. sub_pipeline.AddPass(); - sub_pipeline.AddPass("copy_"); + sub_pipeline.AddPass("copy_"); sub_pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(); return pipeline; } diff --git a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc index b8496b36e52f95..03adc0b2cdaea5 100644 --- a/third_party/xla/xla/service/gpu/ptx_compilation_test.cc +++ b/third_party/xla/xla/service/gpu/ptx_compilation_test.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/nvptx_compiler.h" #include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/cuda/nvjitlink_support.h" #include "xla/stream_executor/cuda/ptx_compilation_method.h" #include "xla/stream_executor/cuda/ptx_compiler_support.h" #include "xla/stream_executor/cuda/ptx_linking_method.h" @@ -101,12 +102,20 @@ ENTRY e { "num_ctas":1}}} })"; +constexpr std::string_view kResultsInNoPtxHlo = R"( + ENTRY e { + a = f32[5,5] parameter(0) + ROOT _ = f32[5,5] custom-call(a, a), custom_call_target="__cublas$gemm", + backend_config="{ \"gemm_backend_config\": {\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"epilogue\":\"DEFAULT\"}}" + })"; + std::string_view GetHlo(std::string_view name) { static const absl::flat_hash_map* const kHloMap = new absl::flat_hash_map( {{"simple", kSimpleHlo}, {"parallel_compilation", kParallelCompilationHlo}, - {"requires_sm90a", kSM90AHlo}}); + {"requires_sm90a", kSM90AHlo}, + {"results_in_no_ptx", kResultsInNoPtxHlo}}); return kHloMap->at(name); } @@ -155,6 +164,29 @@ class NVPTXCompilationTests // Compiled without libnvptxcompiler support GTEST_SKIP() << "libnvptxcompiler is not supported in this build."; } + + if (!stream_executor::IsLibNvJitLinkSupported() && + (compilation_method == PtxCompilationMethod::kNvJitLink || + linking_method == PtxLinkingMethod::kNvJitLink)) { + // Compiled without libnvjitlink support + GTEST_SKIP() << "libnvjitlink is not supported in this build."; + } + + if (compilation_method == PtxCompilationMethod::kNvJitLink && + linking_method != PtxLinkingMethod::kNvJitLink) { + // When compilation method is NvJitLink, linking method must be NvJitLink + // as well. + GTEST_SKIP() << "Compilation method NvJitLink is only supported if the " + "linking method is NvJitLink as well."; + } + + if (compilation_method == PtxCompilationMethod::kPtxas && + linking_method == PtxLinkingMethod::kNvJitLink) { + // We could support this combination, but it would require some + // refactoring of the flags. + GTEST_SKIP() << "Compilation method Ptxas is not supported with linking " + "method NvJitLink."; + } } void SetDebugOptionsFromPtxSettings(DebugOptions* debug_options, @@ -163,6 +195,10 @@ class NVPTXCompilationTests debug_options->set_xla_gpu_enable_libnvptxcompiler( compilation_method == PtxCompilationMethod::kNvPtxCompiler); + debug_options->set_xla_gpu_enable_libnvjitlink( + compilation_method == PtxCompilationMethod::kNvJitLink || + linking_method == PtxLinkingMethod::kNvJitLink); + debug_options->set_xla_gpu_enable_llvm_module_compilation_parallelism( linking_method != PtxLinkingMethod::kNone); debug_options->set_xla_gpu_force_compilation_parallelism(12); @@ -260,15 +296,20 @@ TEST_P(NVPTXCompilationTests, CompareBinaryOutput) { absl::Span reference_binary = static_cast(reference.value().get())->binary(); - if (executable_binary != reference_binary) { - std::string test_name = - GenerateParametrizedTestname(name, compilation_method, linking_method); - DumpArtifactIfEnabled(absl::StrCat(test_name, "_executable.bin"), - executable_binary); - DumpArtifactIfEnabled(absl::StrCat(test_name, "_reference.bin"), - reference_binary); + if (executable_binary == reference_binary) { + // If the binaries are exactly the same, we can short circuit and don't need + // to parse them. + SUCCEED(); + return; } + std::string test_name = + GenerateParametrizedTestname(name, compilation_method, linking_method); + DumpArtifactIfEnabled(absl::StrCat(test_name, "_executable.bin"), + executable_binary); + DumpArtifactIfEnabled(absl::StrCat(test_name, "_reference.bin"), + reference_binary); + auto get_text_sections = [&](absl::Span binary) -> absl::StatusOr> { auto buffer = llvm::MemoryBuffer::getMemBuffer( @@ -313,12 +354,15 @@ TEST_P(NVPTXCompilationTests, CompareBinaryOutput) { INSTANTIATE_TEST_SUITE_P( NVPTXCompilationTest, NVPTXCompilationTests, - ::testing::Combine( - ::testing::Values("simple", "parallel_compilation", "requires_sm90a"), - ::testing::Values(PtxCompilationMethod::kNvPtxCompiler, - PtxCompilationMethod::kPtxas), - ::testing::Values(PtxLinkingMethod::kNone, PtxLinkingMethod::kNvLink, - PtxLinkingMethod::kDriver)), + ::testing::Combine(::testing::Values("simple", "parallel_compilation", + "requires_sm90a", "results_in_no_ptx"), + ::testing::Values(PtxCompilationMethod::kNvPtxCompiler, + PtxCompilationMethod::kPtxas, + PtxCompilationMethod::kNvJitLink), + ::testing::Values(PtxLinkingMethod::kNone, + PtxLinkingMethod::kNvLink, + PtxLinkingMethod::kDriver, + PtxLinkingMethod::kNvJitLink)), [](const ::testing::TestParamInfo>& info) { return GenerateParametrizedTestname(std::get<0>(info.param), diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 3eac5b872b1493..8d2d9d1c373993 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -81,7 +81,6 @@ cc_library( "//xla/service:executable", "//xla/service:global_device_id", "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_fused_mha_runner", "//xla/service/gpu:launch_dimensions", "//xla/service/gpu:matmul_utils", "//xla/service/gpu:stream_executor_util", @@ -89,7 +88,6 @@ cc_library( "//xla/stream_executor", "//xla/stream_executor:command_buffer", "//xla/stream_executor:dnn", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:lazy_op_runner", "//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor/gpu:gpu_stream_header", @@ -124,7 +122,6 @@ cc_library( ":copy_thunk", ":cudnn_thunk", ":custom_call_thunk", - ":fused_mha_thunk", ":gemm_thunk", ":gpublas_lt_matmul_thunk", ":kernel_thunk", @@ -166,11 +163,12 @@ xla_test( "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/stream_executor/gpu:gpu_test_kernels", + "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/functional:function_ref", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", @@ -299,6 +297,7 @@ cc_library( "@com_google_absl//absl/crc:crc32c", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", @@ -385,10 +384,10 @@ xla_test( "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -427,7 +426,6 @@ cc_library( name = "command_buffer_thunk", srcs = ["command_buffer_thunk.cc"], hdrs = ["command_buffer_thunk.h"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":annotation", ":command_buffer_cmd", @@ -454,7 +452,7 @@ cc_library( xla_test( name = "command_buffer_thunk_test", - srcs = if_gpu_is_configured(["command_buffer_thunk_test.cc"]), + srcs = ["command_buffer_thunk_test.cc"], backend_tags = { "gpu_a100": if_google(["config-cuda-only"]), "gpu_v100": if_google(["config-cuda-only"]), @@ -488,10 +486,12 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_memory_allocator", "//xla/stream_executor/gpu:gpu_test_kernels", + "//xla/stream_executor/gpu:gpu_test_kernels_fatbin", "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -663,28 +663,6 @@ cc_library( ], ) -cc_library( - name = "fused_mha_thunk", - srcs = ["fused_mha_thunk.cc"], - hdrs = ["fused_mha_thunk.h"], - deps = [ - ":thunk", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service:buffer_assignment", - "//xla/service/gpu:buffer_allocations", - "//xla/service/gpu:gpu_fused_mha_runner", - "//xla/stream_executor", - "//xla/stream_executor:lazy_op_runner", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status", - "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "gemm_thunk", srcs = ["gemm_thunk.cc"], @@ -759,7 +737,6 @@ cc_library( "//xla/service/gpu:stream_executor_util", "//xla/service/gpu/kernels:custom_kernel", "//xla/stream_executor", - "//xla/stream_executor:kernel_factory", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:inlined_vector", diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc index fcc38cd2d65514..aceb2cdbb94666 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.cc @@ -64,7 +64,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/lazy_op_runner.h" #include "xla/stream_executor/stream.h" @@ -714,7 +713,7 @@ absl::Status CustomKernelLaunchCmd::Initialize( TF_ASSIGN_OR_RETURN( std::unique_ptr kernel, - se::KernelFactory::Create(params.executor, custom_kernel_.kernel_spec())); + params.executor->LoadKernel(custom_kernel_.kernel_spec())); absl::MutexLock lock(&mutex_); kernels_.emplace(params.executor, std::move(kernel)); @@ -1168,314 +1167,6 @@ CommandBufferCmd::BufferUsageVector GemmCmd::buffers() { {workspace_, MemoryAccess::kWrite}}; } -//===----------------------------------------------------------------------===// -// FusedMHACmd -//===----------------------------------------------------------------------===// - -FusedMHACmd::FusedMHACmd( - ExecutionStreamId execution_stream_id, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) - : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHACmd, - execution_stream_id), - config_(std::move(config)), - lhs_bmm1_buffer_(lhs_bmm1), - rhs_bmm1_buffer_(rhs_bmm1), - rhs_bmm2_buffer_(rhs_bmm2), - output_buffer_(output), - scratch_buffer_(scratch), - bias_buffer_(bias), - activation_buffer_(activation), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k) {} - -FusedMultiHeadedAttentionRunner& FusedMHACmd::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mutex_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHACmd::Initialize(const Thunk::InitializeParams& params, - StateManager& state) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.command_buffer_trace_stream).AsFusedMHARunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig()); - return lazy_runner - ->GetOrCreateRunner(config, params.command_buffer_trace_stream) - .status(); -} - -absl::Status FusedMHACmd::Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(execute_params.command_buffer_trace_stream) - .AsFusedMHARunner(); - CHECK(lazy_runner) << "FusedMHA lazy runner cache should have been populated"; - - const auto& buffer_allocations = *execute_params.buffer_allocations; - se::DeviceMemoryBase lhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm2_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_); - se::DeviceMemoryBase output_buffer = - buffer_allocations.GetDeviceAddress(output_buffer_); - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional activation_buffer = - AssignBufferIfNotNull(buffer_allocations, activation_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - - ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "FusedMHACmd with execution_scope_id: " - << execution_scope_id.value(); - VLOG(5) << " lhs_bmm1_buffer: " << lhs_bmm1_buffer_.ToString(); - VLOG(5) << " rhs_bmm1_buffer: " << rhs_bmm1_buffer_.ToString(); - VLOG(5) << " rhs_bmm2_buffer: " << rhs_bmm2_buffer_.ToString(); - VLOG(5) << " output_buffer: " << output_buffer_.ToString(); - VLOG(5) << " scratch_buffer: " << scratch_buffer_.ToString(); - VLOG(5) << " bias_buffer: " << bias_buffer_.ToString(); - VLOG(5) << " activation_buffer: " << activation_buffer_.ToString(); - VLOG(5) << " seqlen_q_buffer: " << seqlen_q_buffer_.ToString(); - VLOG(5) << " seqlen_k_buffer: " << seqlen_k_buffer_.ToString(); - - RunFusedMHAOptions opts; - opts.runner_cache = - &GetOrCreateRunner(execute_params.command_buffer_trace_stream); - return AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, - bias_buffer, activation_buffer, seqlen_q_buffer, - seqlen_k_buffer, stream, opts); - }); -} - -FusedMHACmd::BufferUsageVector FusedMHACmd::buffers() { - BufferUsageVector buffer_usage; - buffer_usage.reserve(9); - buffer_usage.push_back({lhs_bmm1_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({rhs_bmm1_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({rhs_bmm2_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({output_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite}); - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); - } - if (activation_buffer_.allocation() != nullptr) { - buffer_usage.push_back({activation_buffer_, MemoryAccess::kRead}); - } - if (seqlen_q_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead}); - } - if (seqlen_k_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead}); - } - return buffer_usage; -} - -//===----------------------------------------------------------------------===// -// FusedMHABackwardCmd -//===----------------------------------------------------------------------===// - -FusedMHABackwardCmd::FusedMHABackwardCmd( - ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output, - BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k) - : TracedCommandBufferCmd(CommandBufferCmdType::kFusedMHABackwardCmd, - execution_stream_id), - config_(std::move(config)), - bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), - bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), - bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs), - bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs), - d_output_buffer_(d_output), - scratch_buffer_(scratch), - d_bmm1_lhs_buffer_(d_bmm1_lhs), - d_bmm1_rhs_buffer_(d_bmm1_rhs), - d_bmm2_rhs_buffer_(d_bmm2_rhs), - d_s_buffer_(d_s), - d_bias_buffer_(d_bias), - fwd_output_buffer_(fwd_output), - bias_buffer_(bias), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k) {} - -FusedMultiHeadedAttentionBackwardRunner& FusedMHABackwardCmd::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mutex_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, - std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHABackwardCmd::Initialize( - const Thunk::InitializeParams& params, StateManager& state) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.command_buffer_trace_stream) - .AsFusedMHABackwardRunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig()); - return lazy_runner - ->GetOrCreateRunner(config, params.command_buffer_trace_stream) - .status(); -} - -absl::Status FusedMHABackwardCmd::Record( - const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, se::CommandBuffer* command_buffer) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(execute_params.command_buffer_trace_stream) - .AsFusedMHABackwardRunner(); - CHECK(lazy_runner) - << "FusedMHABackward lazy runner cache should have been populated"; - - const auto& buffer_allocations = *execute_params.buffer_allocations; - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_); - - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase d_output_buffer = - buffer_allocations.GetDeviceAddress(d_output_buffer_); - - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - se::DeviceMemoryBase d_bmm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_); - - se::DeviceMemoryBase d_bmm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_); - - se::DeviceMemoryBase d_bmm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_); - - std::optional d_s_buffer = - AssignBufferIfNotNull(buffer_allocations, d_s_buffer_); - std::optional d_bias_buffer = - AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_); - std::optional fwd_output_buffer = - AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - - ExecutionScopeId execution_scope_id = GetExecutionScope(record_params); - VLOG(5) << "FusedMHABackwardCmd with execution_scope_id: " - << execution_scope_id.value(); - VLOG(5) << "bmm1_grad_gemm1_rhs_buffer" - << bmm1_grad_gemm1_rhs_buffer_.ToString(); - VLOG(5) << "bmm1_grad_gemm2_rhs_buffer" - << bmm1_grad_gemm2_rhs_buffer_.ToString(); - VLOG(5) << "bmm2_grad_gemm1_lhs_buffer" - << bmm2_grad_gemm1_lhs_buffer_.ToString(); - VLOG(5) << "bmm2_grad_gemm2_rhs_buffer" - << bmm2_grad_gemm2_rhs_buffer_.ToString(); - VLOG(5) << "d_output_buffer" << d_output_buffer_.ToString(); - VLOG(5) << "scratch_buffer" << scratch_buffer_.ToString(); - VLOG(5) << "d_bmm1_lhs_buffer" << d_bmm1_lhs_buffer_.ToString(); - VLOG(5) << "d_bmm1_rhs_buffer" << d_bmm1_rhs_buffer_.ToString(); - VLOG(5) << "d_bmm2_rhs_buffer" << d_bmm2_rhs_buffer_.ToString(); - VLOG(5) << "d_s_buffer" << d_s_buffer_.ToString(); - VLOG(5) << "d_bias_buffer" << d_bias_buffer_.ToString(); - VLOG(5) << "fwd_output_buffer" << fwd_output_buffer_.ToString(); - VLOG(5) << "bias_buffer" << bias_buffer_.ToString(); - VLOG(5) << "seqlen_q_buffer" << seqlen_q_buffer_.ToString(); - VLOG(5) << "seqlen_k_buffer" << seqlen_k_buffer_.ToString(); - - RunFusedMHABackwardOptions opts; - opts.runner_cache = - &GetOrCreateRunner(execute_params.command_buffer_trace_stream); - return AddTracedCommandBuffer( - execute_params, record_params, command_buffer, [&](se::Stream* stream) { - return RunGpuFMHABackward( - config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, - d_output_buffer, scratch_buffer, d_bmm1_lhs_buffer, - d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, d_s_buffer, d_bias_buffer, - fwd_output_buffer, bias_buffer, seqlen_q_buffer, seqlen_k_buffer, - stream, opts); - }); -} - -FusedMHABackwardCmd::BufferUsageVector FusedMHABackwardCmd::buffers() { - BufferUsageVector buffer_usage; - buffer_usage.reserve(15); - - buffer_usage.push_back({bmm1_grad_gemm1_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm1_grad_gemm2_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm2_grad_gemm1_lhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({bmm2_grad_gemm2_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_output_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({scratch_buffer_, MemoryAccess::kWrite}); - buffer_usage.push_back({d_bmm1_lhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_bmm1_rhs_buffer_, MemoryAccess::kRead}); - buffer_usage.push_back({d_bmm2_rhs_buffer_, MemoryAccess::kRead}); - - if (d_s_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_s_buffer_, MemoryAccess::kRead}); - }; - if (d_bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({d_bias_buffer_, MemoryAccess::kRead}); - }; - if (fwd_output_buffer_.allocation() != nullptr) { - buffer_usage.push_back({fwd_output_buffer_, MemoryAccess::kRead}); - }; - if (bias_buffer_.allocation() != nullptr) { - buffer_usage.push_back({bias_buffer_, MemoryAccess::kRead}); - }; - if (seqlen_q_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_q_buffer_, MemoryAccess::kRead}); - }; - if (seqlen_k_buffer_.allocation() != nullptr) { - buffer_usage.push_back({seqlen_k_buffer_, MemoryAccess::kRead}); - }; - - return buffer_usage; -} - //===----------------------------------------------------------------------===// // CublasLtCmd //===----------------------------------------------------------------------===// @@ -1836,8 +1527,9 @@ absl::Status CustomCallCmd::RecordXlaFfiCall( execute_params.command_buffer_trace_stream, [&](se::Stream* stream) { ffi::CallOptions options = { execute_params.buffer_allocations->device_ordinal(), - execute_params.stream, - execute_params.buffer_allocations->memory_allocator(), + ffi::CallOptions::GpuOptions{ + execute_params.stream, + execute_params.buffer_allocations->memory_allocator()}, /*called_computation=*/nullptr, // TODO(b/342285364) execute_params.ffi_execution_context}; return ffi::Call(handler_, call_frame, options); @@ -1920,30 +1612,16 @@ absl::Status CollectiveCmd::BarrierIfAsync( absl::Status CollectiveCmd::Prepare( const Thunk::PrepareParams& params, Thunk::ResourceRequests& resource_requests) { - const Thunk::CollectiveExecuteParams* collectives = params.collective_params; - TF_ASSIGN_OR_RETURN( - std::vector participants, - GetParticipatingDevices(collectives->global_device_id, - *collectives->device_assn, + NcclCliqueKey clique_key, + GetNcclCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, nccl_stream_id(), + GetAsyncStreamKind())); + TF_ASSIGN_OR_RETURN( + size_t num_local_participants, + GetNumLocalParticipants(*params.collective_params, config().replica_groups, config().group_mode)); - - std::vector local_devices; - if (collectives->global_device_id_map) { - local_devices.reserve(collectives->global_device_id_map->size()); - for (const auto& entry : *collectives->global_device_id_map) { - local_devices.push_back(entry.second); - } - } - - size_t num_local_participants = GetNumLocalParticipants( - participants, - collectives->global_device_id_map ? &local_devices : nullptr); - - return resource_requests.AddClique( - NcclCliqueKey(std::move(participants), nccl_stream_id(), - GetAsyncStreamKind()), - num_local_participants); + return resource_requests.AddClique(clique_key, num_local_participants); } absl::Status CollectiveCmd::AddTracedCommandBuffer( diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h index b7a077e81a9e4f..27e8fea0d86366 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd.h @@ -40,7 +40,6 @@ limitations under the License. #include "xla/service/buffer_assignment.h" #include "xla/service/collective_ops_utils.h" #include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" #include "xla/service/gpu/kernels/custom_kernel.h" #include "xla/service/gpu/launch_dimensions.h" #include "xla/service/gpu/matmul_utils.h" @@ -81,8 +80,6 @@ namespace xla::gpu { V(kReduceScatter, "ReduceScatterCmd") \ V(kAllGatherCmd, "AllGatherCmd") \ V(kCollectiveBroadcastCmd, "CollectiveBroadcastCmd") \ - V(kFusedMHACmd, "FusedMHACmd") \ - V(kFusedMHABackwardCmd, "FusedMHABackwardCmd") \ V(kUnknownCmd, "UnknownCmd") \ // clang-format on @@ -782,112 +779,6 @@ class GemmCmd : public TracedCommandBufferCmd { const bool deterministic_; }; -//===----------------------------------------------------------------------===// -// FusedMHACmd -//===----------------------------------------------------------------------===// - -class FusedMHACmd : public TracedCommandBufferCmd { - public: - FusedMHACmd(ExecutionStreamId execution_stream_id, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, - BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - BufferUsageVector buffers() override; - - bool IsNestedCommandBuffer() const final { return true; } - - private: - FusedMultiHeadedAttentionRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - const GpufMHAConfig config_; - BufferAllocation::Slice lhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm2_buffer_; - BufferAllocation::Slice output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice activation_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - // FusedMHA config - absl::Mutex mutex_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mutex_); -}; - -//===----------------------------------------------------------------------===// -// FusedMHABackwardCmd -//===----------------------------------------------------------------------===// - -class FusedMHABackwardCmd : public TracedCommandBufferCmd { - public: - FusedMHABackwardCmd( - ExecutionStreamId execution_stream_id, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice d_bias, BufferAllocation::Slice fwd_output, - BufferAllocation::Slice bias, BufferAllocation::Slice seqlen_q, - BufferAllocation::Slice seqlen_k); - - absl::Status Initialize(const Thunk::InitializeParams& params, - StateManager& state) override; - - absl::Status Record(const Thunk::ExecuteParams& execute_params, - const RecordParams& record_params, - se::CommandBuffer* command_buffer) override; - - BufferUsageVector buffers() override; - - bool IsNestedCommandBuffer() const final { return true; } - - private: - FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - const GpufMHABackwardConfig config_; - BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_; - BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice d_output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice d_bmm1_lhs_buffer_; - BufferAllocation::Slice d_bmm1_rhs_buffer_; - BufferAllocation::Slice d_bmm2_rhs_buffer_; - BufferAllocation::Slice d_s_buffer_; - BufferAllocation::Slice d_bias_buffer_; - BufferAllocation::Slice fwd_output_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - // FusedMHA config - absl::Mutex mutex_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mutex_); -}; - //===----------------------------------------------------------------------===// // CublasLtCmd //===----------------------------------------------------------------------===// diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc index 54e01fab8e1109..230d050856fcc2 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_emitter.cc @@ -29,7 +29,6 @@ limitations under the License. #include "xla/service/gpu/runtime/copy_thunk.h" #include "xla/service/gpu/runtime/cudnn_thunk.h" #include "xla/service/gpu/runtime/custom_call_thunk.h" -#include "xla/service/gpu/runtime/fused_mha_thunk.h" #include "xla/service/gpu/runtime/gemm_thunk.h" #include "xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h" #include "xla/service/gpu/runtime/kernel_thunk.h" @@ -143,27 +142,6 @@ static absl::StatusOr Convert(const CublasLtMatmulThunk& thunk) { thunk.workspace().value()); } -static absl::StatusOr Convert(const FusedMHAThunk& thunk) { - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), thunk.lhs_bmm1_buffer(), - thunk.rhs_bmm1_buffer(), thunk.rhs_bmm2_buffer(), thunk.output_buffer(), - thunk.scratch_buffer(), BufferAllocation::Slice(), thunk.bias_buffer(), - thunk.activation_buffer(), thunk.seqlen_q_buffer(), - thunk.seqlen_k_buffer()); -} - -static absl::StatusOr Convert(const FusedMHABackwardThunk& thunk) { - return std::make_unique( - thunk.execution_stream_id(), thunk.config(), - thunk.bmm1_grad_gemm1_rhs_buffer(), thunk.bmm1_grad_gemm2_rhs_buffer(), - thunk.bmm2_grad_gemm1_lhs_buffer(), thunk.bmm2_grad_gemm2_rhs_buffer(), - thunk.d_output_buffer(), thunk.scratch_buffer(), - thunk.d_bmm1_lhs_buffer(), thunk.d_bmm1_rhs_buffer(), - thunk.d_bmm2_rhs_buffer(), thunk.d_s_buffer(), thunk.d_bias_buffer(), - thunk.fwd_output_buffer(), thunk.bias_buffer(), thunk.seqlen_q_buffer(), - thunk.seqlen_k_buffer()); -} - static absl::StatusOr Convert( const ConditionalThunk& thunk, CommandBufferCmdSequence::SynchronizationMode synchronization_mode) { @@ -276,10 +254,6 @@ static absl::Status AppendCommands( return append(Convert(thunk)); case Thunk::Kind::kCustomKernel: return append(Convert(thunk)); - case Thunk::Kind::kFusedMHA: - return append(Convert(thunk)); - case Thunk::Kind::kFusedMHABackward: - return append(Convert(thunk)); case Thunk::Kind::kKernel: return append(Convert(thunk)); case Thunk::Kind::kGemm: diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc index 22d586775e5ab7..90b6e0666c8adf 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_cmd_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/functional/function_ref.h" #include "absl/status/status.h" #include "absl/strings/ascii.h" +#include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" @@ -30,13 +31,13 @@ limitations under the License. #include "xla/service/service_executable_run_options.h" #include "xla/stream_executor/command_buffer.h" #include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" // IWYU pragma: keep -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -235,7 +236,7 @@ TEST(CommandBufferCmdTest, MemcpyCmd) { TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(stream.get(), *command_buffer)); + TF_ASSERT_OK(command_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 0); @@ -306,7 +307,7 @@ TEST(CommandBufferCmdTest, BarrierCmd) { TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(stream.get(), *command_buffer)); + TF_ASSERT_OK(command_buffer->Submit(stream.get())); // Copy data back to host, correct executor order should populate all buffers // with expected value. @@ -352,20 +353,15 @@ TEST(CommandBufferCmdTest, LaunchCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(s0, "add", args, args_access, + commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); // Initialize command sequence and load device kernels. - Thunk::ExecutableSource source = { -#if defined(GOOGLE_CUDA) - /*text=*/se::gpu::internal::kAddI32Kernel, - /*binary=*/{} -#elif defined(TENSORFLOW_USE_ROCM) - /*text=*/{}, - /*binary=*/se::gpu::internal::kAddI32KernelModule -#endif - }; + TF_ASSERT_OK_AND_ASSIGN(std::vector fatbin, + se::gpu::GetGpuTestKernelsFatbin()); + Thunk::ExecutableSource source = {/*text=*/{}, + /*binary=*/fatbin}; CommandBufferCmd::StateManager state; TF_ASSERT_OK(commands.Initialize({executor, source}, state)); @@ -384,7 +380,7 @@ TEST(CommandBufferCmdTest, LaunchCmd) { TF_ASSERT_OK(commands.Record(params, record_params, command_buffer.get())); // Execute command buffer and verify that it copied the memory. - TF_ASSERT_OK(executor->Submit(stream.get(), *command_buffer)); + TF_ASSERT_OK(command_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 0); diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc index 42d14071fcf4e1..c7a7117a86c6b4 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk.cc @@ -256,7 +256,7 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) { {"num_executions", cmd_buffer->num_executions}}); }); - return executor->Submit(params.stream, *cmd_buffer->command_buffer); + return cmd_buffer->command_buffer->Submit(params.stream); } absl::StatusOr> diff --git a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc index 9146213d72fe89..bce9d1927d05ea 100644 --- a/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/command_buffer_thunk_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include #include #include +#include #include // NOLINT #include #include #include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "absl/types/span.h" #include "xla/service/buffer_assignment.h" #include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/launch_dimensions.h" @@ -41,6 +43,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/device_memory_allocator.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" #include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" @@ -48,9 +51,9 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" // IWYU pragma: keep #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" #include "tsl/profiler/lib/profiler_lock.h" @@ -64,27 +67,29 @@ namespace xla::gpu { using MemoryAccess = CommandBufferCmd::MemoryAccess; using KernelArgsPacking = se::MultiKernelLoaderSpec::KernelArgsPacking; -static se::StreamExecutor* GpuExecutor() { +namespace { +se::StreamExecutor* GpuExecutor() { auto name = absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()); auto* platform = se::PlatformManager::PlatformWithName(name).value(); return platform->ExecutorForDevice(0).value(); } -static Thunk::ExecutableSource ExecutableSource() { - Thunk::ExecutableSource source = { -#if defined(GOOGLE_CUDA) - /*text=*/se::gpu::internal::kAddI32Kernel, - /*binary=*/{} -#elif defined(TENSORFLOW_USE_ROCM) - /*text=*/{}, - /*binary=*/se::gpu::internal::kAddI32KernelModule -#endif - }; - return source; +struct OwningExecutableSource { + std::string text; + std::vector binary; + + explicit operator Thunk::ExecutableSource() const { return {text, binary}; } +}; + +absl::StatusOr ExecutableSource() { + TF_ASSIGN_OR_RETURN(std::vector fatbin, + se::gpu::GetGpuTestKernelsFatbin()); + return OwningExecutableSource{/*text=*/{}, + /*binary=*/fatbin}; } -static KernelArgsPacking CreateDefaultArgsPacking() { +KernelArgsPacking CreateDefaultArgsPacking() { using Packed = absl::StatusOr>; return [=](const se::Kernel& kernel, const se::KernelArgs& args) -> Packed { @@ -96,7 +101,7 @@ static KernelArgsPacking CreateDefaultArgsPacking() { } // Some of the tests rely on CUDA 12.3+ features. -static bool IsAtLeastCuda12300() { +bool IsAtLeastCuda12300() { #if defined(TENSORFLOW_USE_ROCM) return false; #endif @@ -107,8 +112,9 @@ static bool IsAtLeastCuda12300() { } // Give a short aliases to execution threads. -static constexpr auto s0 = ExecutionStreamId(0); -static constexpr auto s1 = ExecutionStreamId(1); +constexpr auto s0 = ExecutionStreamId(0); +constexpr auto s1 = ExecutionStreamId(1); +} // namespace TEST(CommandBufferThunkTest, MemcpyCmd) { se::StreamExecutor* executor = GpuExecutor(); @@ -428,7 +434,7 @@ TEST(CommandBufferThunkTest, LaunchCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(s0, "add", args, args_access, + commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -442,9 +448,10 @@ TEST(CommandBufferThunkTest, LaunchCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -498,7 +505,7 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { spec.AddInProcessSymbol(se::gpu::internal::GetAddI32Kernel(), "add"); auto custom_kernel = - CustomKernel("add", std::move(spec), se::BlockDim(), + CustomKernel("AddI32", std::move(spec), se::BlockDim(), se::ThreadDim(4, 1, 1), /*shared_memory_bytes=*/0); int64_t length = 4; @@ -524,7 +531,7 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(s0, "add", args, args_access, + commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -538,9 +545,10 @@ TEST(CommandBufferThunkTest, CustomAddKernelLaunchCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -880,10 +888,10 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { // Prepare commands sequence for constructing command buffer. CommandBufferCmdSequence commands; - commands.Emplace(s0, "add", args, args_access, + commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); - commands.Emplace(s0, "add", args_1, args_access, + commands.Emplace(s0, "AddI32", args_1, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -897,9 +905,10 @@ TEST(CommandBufferThunkTest, MultipleLaunchCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -994,7 +1003,7 @@ TEST(CommandBufferThunkTest, IfCmd) { // Prepare commands sequence for `then` branch. CommandBufferCmdSequence then_commands; - then_commands.Emplace(s0, "add", args, args_access, + then_commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -1012,9 +1021,10 @@ TEST(CommandBufferThunkTest, IfCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1084,14 +1094,14 @@ TEST(CommandBufferThunkTest, IfElseCmd) { { // Then: b = a + a auto args = {slice_a, slice_a, slice_b}; - then_commands.Emplace(s0, "add", args, args_access, + then_commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); } { // Else: b = b + b auto args = {slice_b, slice_b, slice_b}; - else_commands.Emplace(s0, "add", args, args_access, + else_commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); } @@ -1111,9 +1121,10 @@ TEST(CommandBufferThunkTest, IfElseCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1174,14 +1185,14 @@ TEST(CommandBufferThunkTest, CaseCmd) { { // Case 0: b = a + a auto args = {slice_a, slice_a, slice_b}; - branches[0].Emplace(s0, "add", args, args_access, + branches[0].Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); } { // Case 1: b = b + b auto args = {slice_b, slice_b, slice_b}; - branches[1].Emplace(s0, "add", args, args_access, + branches[1].Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); } @@ -1200,9 +1211,10 @@ TEST(CommandBufferThunkTest, CaseCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); @@ -1260,7 +1272,7 @@ TEST(CommandBufferThunkTest, ForCmd) { // Prepare commands sequence for loop `body`. CommandBufferCmdSequence body_commands; - body_commands.Emplace(s0, "add", args, args_access, + body_commands.Emplace(s0, "AddI32", args, args_access, LaunchDimensions(1, 4), /*shmem_bytes=*/0); @@ -1279,9 +1291,10 @@ TEST(CommandBufferThunkTest, ForCmd) { Thunk::ExecuteParams params = Thunk::ExecuteParams::Create( run_options, allocations, stream.get(), stream.get(), nullptr, nullptr); - Thunk::ExecutableSource source = ExecutableSource(); + TF_ASSERT_OK_AND_ASSIGN(OwningExecutableSource source, ExecutableSource()); TF_ASSERT_OK( - thunk.Initialize({executor, source, &allocations, stream.get()})); + thunk.Initialize({executor, static_cast(source), + &allocations, stream.get()})); // Execute command buffer thunk and verify that it added the value 10 times. TF_ASSERT_OK(thunk.ExecuteOnStream(params)); diff --git a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc index e994facfbd9817..f77653e17d72c0 100644 --- a/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/custom_call_thunk.cc @@ -198,9 +198,9 @@ absl::Status CustomCallThunk::ExecuteFfiHandler( builder.AddAttributes(attrs.Build()); CallFrame call_frame = builder.Build(); - CallOptions options = {device_ordinal, stream, - allocator, called_computation_, - execution_context, execution_state_.get()}; + CallOptions options = { + device_ordinal, CallOptions::GpuOptions{stream, allocator}, + called_computation_, execution_context, execution_state_.get()}; return Call(handler, call_frame, options, stage); } diff --git a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc index 9f42ad2efc69f7..75c700b75b4e34 100644 --- a/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/dynamic_slice_thunk_test.cc @@ -47,8 +47,8 @@ limitations under the License. #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" // IWYU pragma: keep -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc index 604bce14592727..eda2bc6e2c3462 100644 --- a/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc +++ b/third_party/xla/xla/service/gpu/runtime/for_all_thunks.cc @@ -74,8 +74,6 @@ void ForAllThunks(absl::FunctionRef fn, case Thunk::kCustomKernel: case Thunk::kCuDnn: case Thunk::kFft: - case Thunk::kFusedMHA: - case Thunk::kFusedMHABackward: case Thunk::kGemm: case Thunk::kInfeed: case Thunk::kKernel: diff --git a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc deleted file mode 100644 index ee13689fbbb578..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.cc +++ /dev/null @@ -1,230 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/runtime/fused_mha_thunk.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/synchronization/mutex.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/buffer_allocations.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/stream_executor/device_memory.h" -#include "xla/stream_executor/lazy_op_runner.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/util.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -FusedMHAThunk::FusedMHAThunk( - ThunkInfo thunk_info, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1, BufferAllocation::Slice rhs_bmm1, - BufferAllocation::Slice rhs_bmm2, BufferAllocation::Slice output, - BufferAllocation::Slice scratch, BufferAllocation::Slice mask, - BufferAllocation::Slice bias, BufferAllocation::Slice activation, - BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) - : Thunk(Kind::kFusedMHA, thunk_info), - lhs_bmm1_buffer_(lhs_bmm1), - rhs_bmm1_buffer_(rhs_bmm1), - rhs_bmm2_buffer_(rhs_bmm2), - output_buffer_(output), - scratch_buffer_(scratch), - bias_buffer_(bias), - activation_buffer_(activation), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k), - config_(std::move(config)) {} - -FusedMultiHeadedAttentionRunner& FusedMHAThunk::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mu_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -std::optional AssignBufferIfNotNull( - const BufferAllocations& buffer_allocations, - BufferAllocation::Slice& slice) { - return slice.allocation() != nullptr - ? std::optional{buffer_allocations - .GetDeviceAddress(slice)} - : std::nullopt; -} - -absl::Status FusedMHAThunk::Initialize(const InitializeParams& params) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.stream).AsFusedMHARunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHAOpConfig()); - return lazy_runner->GetOrCreateRunner(config, params.stream).status(); -} - -absl::Status FusedMHAThunk::ExecuteOnStream(const ExecuteParams& params) { - const auto& buffer_allocations = *params.buffer_allocations; - se::DeviceMemoryBase lhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(lhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm1_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm1_buffer_); - se::DeviceMemoryBase rhs_bmm2_buffer = - buffer_allocations.GetDeviceAddress(rhs_bmm2_buffer_); - se::DeviceMemoryBase output_buffer = - buffer_allocations.GetDeviceAddress(output_buffer_); - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional activation_buffer = - AssignBufferIfNotNull(buffer_allocations, activation_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - RunFusedMHAOptions opts; - opts.runner_cache = &GetOrCreateRunner(params.stream); - TF_RETURN_IF_ERROR(RunGpuFMHA(config_, lhs_bmm1_buffer, rhs_bmm1_buffer, - rhs_bmm2_buffer, output_buffer, scratch_buffer, - bias_buffer, activation_buffer, seqlen_q_buffer, - seqlen_k_buffer, params.stream, opts)); - - if (!params.stream->ok()) { - return Internal("FusedMHAThunk::ExecuteOnStream failed."); - } - return absl::OkStatus(); -} -FusedMHABackwardThunk::FusedMHABackwardThunk( - ThunkInfo thunk_info, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs, - BufferAllocation::Slice bmm1_grad_gemm2_rhs, - BufferAllocation::Slice bmm2_grad_gemm1_lhs, - BufferAllocation::Slice bmm2_grad_gemm2_rhs, - BufferAllocation::Slice d_output, BufferAllocation::Slice scratch, - BufferAllocation::Slice d_bmm1_lhs, BufferAllocation::Slice d_bmm1_rhs, - BufferAllocation::Slice d_bmm2_rhs, BufferAllocation::Slice d_s, - BufferAllocation::Slice mask, BufferAllocation::Slice d_bias, - BufferAllocation::Slice fwd_output, BufferAllocation::Slice bias, - BufferAllocation::Slice seqlen_q, BufferAllocation::Slice seqlen_k) - : Thunk(Kind::kFusedMHABackward, thunk_info), - bmm1_grad_gemm1_rhs_buffer_(bmm1_grad_gemm1_rhs), - bmm1_grad_gemm2_rhs_buffer_(bmm1_grad_gemm2_rhs), - bmm2_grad_gemm1_lhs_buffer_(bmm2_grad_gemm1_lhs), - bmm2_grad_gemm2_rhs_buffer_(bmm2_grad_gemm2_rhs), - d_output_buffer_(d_output), - scratch_buffer_(scratch), - d_bmm1_lhs_buffer_(d_bmm1_lhs), - d_bmm1_rhs_buffer_(d_bmm1_rhs), - d_bmm2_rhs_buffer_(d_bmm2_rhs), - d_s_buffer_(d_s), - d_bias_buffer_(d_bias), - fwd_output_buffer_(fwd_output), - bias_buffer_(bias), - seqlen_q_buffer_(seqlen_q), - seqlen_k_buffer_(seqlen_k), - config_(std::move(config)) {} - -FusedMultiHeadedAttentionBackwardRunner& -FusedMHABackwardThunk::GetOrCreateRunner( - const stream_executor::Stream* stream) { - absl::MutexLock lock(&mu_); - auto it = runner_cache_.find(stream); - if (it == runner_cache_.end()) { - it = runner_cache_ - .insert({stream, - std::make_unique( - config_)}) - .first; - } - return *it->second; -} - -absl::Status FusedMHABackwardThunk::Initialize(const InitializeParams& params) { - se::dnn::LazyOpRunner* lazy_runner = - GetOrCreateRunner(params.stream).AsFusedMHABackwardRunner(); - TF_ASSIGN_OR_RETURN(auto config, config_.AsDnnFusedMHABackwardOpConfig()); - return lazy_runner->GetOrCreateRunner(config, params.stream).status(); -} - -absl::Status FusedMHABackwardThunk::ExecuteOnStream( - const ExecuteParams& params) { - const auto& buffer_allocations = *params.buffer_allocations; - se::DeviceMemoryBase bmm1_grad_gemm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm1_rhs_buffer_); - - se::DeviceMemoryBase bmm1_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm1_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm1_lhs_buffer_); - - se::DeviceMemoryBase bmm2_grad_gemm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(bmm2_grad_gemm2_rhs_buffer_); - - se::DeviceMemoryBase d_output_buffer = - buffer_allocations.GetDeviceAddress(d_output_buffer_); - - se::DeviceMemoryBase scratch_buffer = - buffer_allocations.GetDeviceAddress(scratch_buffer_); - - se::DeviceMemoryBase d_bmm1_lhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_lhs_buffer_); - - se::DeviceMemoryBase d_bmm1_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm1_rhs_buffer_); - - se::DeviceMemoryBase d_bmm2_rhs_buffer = - buffer_allocations.GetDeviceAddress(d_bmm2_rhs_buffer_); - - std::optional d_s_buffer = - AssignBufferIfNotNull(buffer_allocations, d_s_buffer_); - std::optional d_bias_buffer = - AssignBufferIfNotNull(buffer_allocations, d_bias_buffer_); - std::optional fwd_output_buffer = - AssignBufferIfNotNull(buffer_allocations, fwd_output_buffer_); - std::optional bias_buffer = - AssignBufferIfNotNull(buffer_allocations, bias_buffer_); - std::optional seqlen_q_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_q_buffer_); - std::optional seqlen_k_buffer = - AssignBufferIfNotNull(buffer_allocations, seqlen_k_buffer_); - RunFusedMHABackwardOptions opts; - - opts.runner_cache = &GetOrCreateRunner(params.stream); - - TF_RETURN_IF_ERROR(RunGpuFMHABackward( - config_, bmm1_grad_gemm1_rhs_buffer, bmm1_grad_gemm2_rhs_buffer, - bmm2_grad_gemm1_lhs_buffer, bmm2_grad_gemm2_rhs_buffer, d_output_buffer, - scratch_buffer, d_bmm1_lhs_buffer, d_bmm1_rhs_buffer, d_bmm2_rhs_buffer, - d_s_buffer, d_bias_buffer, fwd_output_buffer, bias_buffer, - seqlen_q_buffer, seqlen_k_buffer, params.stream, opts)); - if (!params.stream->ok()) { - return Internal("FusedMHABackwardThunk::ExecuteOnStream failed."); - } - return absl::OkStatus(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h b/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h deleted file mode 100644 index 99a8327269499e..00000000000000 --- a/third_party/xla/xla/service/gpu/runtime/fused_mha_thunk.h +++ /dev/null @@ -1,184 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ -#define XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ - -#include - -#include "absl/base/thread_annotations.h" -#include "absl/container/flat_hash_map.h" -#include "absl/status/status.h" -#include "absl/synchronization/mutex.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/gpu_fused_mha_runner.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/stream_executor/stream_executor.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -// This class stores everything that StreamExecutor needs to launch a DNN -// fMHA. It is generated by IrEmitter. -// -// This is thread-compatible. -class FusedMHAThunk : public Thunk { - public: - // Constructs a thunk for launching a DNN FMHA. - FusedMHAThunk(ThunkInfo thunk_info, GpufMHAConfig config, - BufferAllocation::Slice lhs_bmm1_slice, - BufferAllocation::Slice rhs_bmm1_slice, - BufferAllocation::Slice rhs_bmm2_slice, - BufferAllocation::Slice output_slice, - BufferAllocation::Slice scratch_slice, - BufferAllocation::Slice mask_slice, /* may be null */ - BufferAllocation::Slice bias_slice /* may be null */, - BufferAllocation::Slice activation_slice /* may be null */, - BufferAllocation::Slice seqlen_q_slice /* may be null */, - BufferAllocation::Slice seqlen_k_slice /* may be null */); - - FusedMHAThunk(const FusedMHAThunk&) = delete; - FusedMHAThunk& operator=(const FusedMHAThunk&) = delete; - - BufferAllocation::Slice lhs_bmm1_buffer() const { return lhs_bmm1_buffer_; } - BufferAllocation::Slice rhs_bmm1_buffer() const { return rhs_bmm1_buffer_; } - BufferAllocation::Slice rhs_bmm2_buffer() const { return rhs_bmm2_buffer_; } - BufferAllocation::Slice output_buffer() const { return output_buffer_; } - BufferAllocation::Slice scratch_buffer() const { return scratch_buffer_; } - BufferAllocation::Slice bias_buffer() const { return bias_buffer_; } - BufferAllocation::Slice activation_buffer() const { - return activation_buffer_; - } - BufferAllocation::Slice seqlen_q_buffer() const { return seqlen_q_buffer_; } - BufferAllocation::Slice seqlen_k_buffer() const { return seqlen_k_buffer_; } - - GpufMHAConfig config() const { return config_; } - absl::Status Initialize(const InitializeParams& params) override; - absl::Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - BufferAllocation::Slice lhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm1_buffer_; - BufferAllocation::Slice rhs_bmm2_buffer_; - BufferAllocation::Slice output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice activation_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - FusedMultiHeadedAttentionRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - // FusedMHA config - const GpufMHAConfig config_; - absl::Mutex mu_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mu_); -}; - -class FusedMHABackwardThunk : public Thunk { - public: - // Constructs a thunk for launching a DNN FMHA backward. - FusedMHABackwardThunk(ThunkInfo thunk_info, GpufMHABackwardConfig config, - BufferAllocation::Slice bmm1_grad_gemm1_rhs_slice, - BufferAllocation::Slice bmm1_grad_gemm2_rhs_slice, - BufferAllocation::Slice bmm2_grad_gemm1_lhs_slice, - BufferAllocation::Slice bmm2_grad_gemm2_rhs_slice, - BufferAllocation::Slice d_output_slice, - BufferAllocation::Slice scratch_slice, - BufferAllocation::Slice d_bmm1_lhs_slice, - BufferAllocation::Slice d_bmm1_rhs_slice, - BufferAllocation::Slice d_bmm2_rhs_slice, - BufferAllocation::Slice d_s_slice, - BufferAllocation::Slice mask_slice, - BufferAllocation::Slice d_bias_slice, - BufferAllocation::Slice fwd_output_slice, - BufferAllocation::Slice bias_slice, - BufferAllocation::Slice seqlen_q_slice, - BufferAllocation::Slice seqlen_k_slice); - - FusedMHABackwardThunk(const FusedMHABackwardThunk&) = delete; - FusedMHABackwardThunk& operator=(const FusedMHABackwardThunk&) = delete; - - BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer() const { - return bmm1_grad_gemm1_rhs_buffer_; - } - BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer() const { - return bmm1_grad_gemm2_rhs_buffer_; - } - BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer() const { - return bmm2_grad_gemm1_lhs_buffer_; - } - BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer() const { - return bmm2_grad_gemm2_rhs_buffer_; - } - BufferAllocation::Slice d_output_buffer() const { return d_output_buffer_; } - BufferAllocation::Slice scratch_buffer() const { return scratch_buffer_; } - BufferAllocation::Slice d_bmm1_lhs_buffer() const { - return d_bmm1_lhs_buffer_; - } - BufferAllocation::Slice d_bmm1_rhs_buffer() const { - return d_bmm1_rhs_buffer_; - } - BufferAllocation::Slice d_bmm2_rhs_buffer() const { - return d_bmm2_rhs_buffer_; - } - BufferAllocation::Slice d_s_buffer() const { return d_s_buffer_; } - BufferAllocation::Slice d_bias_buffer() const { return d_bias_buffer_; } - BufferAllocation::Slice fwd_output_buffer() const { - return fwd_output_buffer_; - } - BufferAllocation::Slice bias_buffer() const { return bias_buffer_; } - BufferAllocation::Slice seqlen_q_buffer() const { return seqlen_q_buffer_; } - BufferAllocation::Slice seqlen_k_buffer() const { return seqlen_k_buffer_; } - - GpufMHABackwardConfig config() const { return config_; } - - absl::Status Initialize(const InitializeParams& params) override; - absl::Status ExecuteOnStream(const ExecuteParams& params) override; - - private: - BufferAllocation::Slice bmm1_grad_gemm1_rhs_buffer_; - BufferAllocation::Slice bmm1_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm1_lhs_buffer_; - BufferAllocation::Slice bmm2_grad_gemm2_rhs_buffer_; - BufferAllocation::Slice d_output_buffer_; - BufferAllocation::Slice scratch_buffer_; - BufferAllocation::Slice d_bmm1_lhs_buffer_; - BufferAllocation::Slice d_bmm1_rhs_buffer_; - BufferAllocation::Slice d_bmm2_rhs_buffer_; - BufferAllocation::Slice d_s_buffer_; - BufferAllocation::Slice d_bias_buffer_; - BufferAllocation::Slice fwd_output_buffer_; - BufferAllocation::Slice bias_buffer_; - BufferAllocation::Slice seqlen_q_buffer_; - BufferAllocation::Slice seqlen_k_buffer_; - - FusedMultiHeadedAttentionBackwardRunner& GetOrCreateRunner( - const stream_executor::Stream* stream); - - // FusedMHA backward config - const GpufMHABackwardConfig config_; - absl::Mutex mu_; - absl::flat_hash_map> - runner_cache_ ABSL_GUARDED_BY(mu_); -}; -} // namespace gpu -} // namespace xla -#endif // XLA_SERVICE_GPU_RUNTIME_FUSED_MHA_THUNK_H_ diff --git a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc index 3ea5a010658af4..a26de45ddaa853 100644 --- a/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/kernel_thunk.cc @@ -37,7 +37,6 @@ limitations under the License. #include "xla/service/gpu/stream_executor_util.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" @@ -187,9 +186,9 @@ absl::Status CustomKernelThunk::Initialize(const InitializeParams& params) { auto it = kernel_cache_.find(params.executor); if (kernel_cache_.end() == it) { - TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - se::KernelFactory::Create( - params.executor, custom_kernel_.kernel_spec())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr kernel, + params.executor->LoadKernel(custom_kernel_.kernel_spec())); kernel_cache_.emplace(params.executor, std::move(kernel)); } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc index 783bd2ddaddee4..f790956d8525af 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_api.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_api.cc @@ -451,6 +451,7 @@ absl::StatusOr> DefaultNcclApi::CommSplit( TF_RETURN_IF_ERROR(GroupEnd()); std::vector split_comms; + split_comms.reserve(split_comms_handles.size()); for (size_t i = 0; i < split_comms_handles.size(); ++i) { split_comms.emplace_back(Cast(split_comms_handles[i]), NcclCommDeleter{this}); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc index a7b068c4a9a0b4..9bbc6f4019eab1 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.cc @@ -27,6 +27,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/service/global_device_id.h" @@ -36,12 +37,14 @@ namespace xla::gpu { // NcclCliqueKey //===----------------------------------------------------------------------===// -NcclCliqueKey::NcclCliqueKey(std::vector devices, - NcclStreamId stream_id, - AsyncStreamKind stream_kind) +NcclCliqueKey::NcclCliqueKey( + std::vector devices, NcclStreamId stream_id, + AsyncStreamKind stream_kind, + std::vector> participant_groups) : devices_(std::move(devices)), stream_id_(stream_id), - stream_kind_(stream_kind) {} + stream_kind_(stream_kind), + participant_groups_(std::move(participant_groups)) {} absl::Span NcclCliqueKey::devices() const { return devices_; @@ -64,12 +67,23 @@ bool NcclCliqueKey::IsSubsetOf(const NcclCliqueKey& other) const { } std::string NcclCliqueKey::ToString() const { - return absl::StrFormat("devices=[%s]; stream=%d", - GlobalDeviceIdsToString(devices_), stream_id_.value()); + std::string group_string = ""; + if (!participant_groups_.empty()) { + std::vector values; + values.reserve(participant_groups_.size()); + for (const auto& group : participant_groups_) { + values.push_back("[" + GlobalDeviceIdsToString(group) + "]"); + } + group_string = absl::StrFormat("; groups=[%s]", absl::StrJoin(values, ",")); + } + return absl::StrFormat("devices=[%s]; stream=%d%s", + GlobalDeviceIdsToString(devices_), stream_id_.value(), + group_string); } bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) { - return a.devices_ == b.devices_ && a.stream_id_ == b.stream_id_; + return a.devices_ == b.devices_ && a.stream_id_ == b.stream_id_ && + a.participant_groups_ == b.participant_groups_; } bool operator<(const NcclCliqueKey& a, const NcclCliqueKey& b) { diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h index 56c9b81f81e2ba..0946ce62ef7275 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key.h @@ -82,7 +82,8 @@ class NcclCliqueKey { explicit NcclCliqueKey( std::vector devices, NcclStreamId stream_id = NcclStreamId(0), - AsyncStreamKind stream_kind = AsyncStreamKind::kCollective); + AsyncStreamKind stream_kind = AsyncStreamKind::kCollective, + std::vector> participant_groups = {}); absl::Span devices() const; @@ -113,11 +114,23 @@ class NcclCliqueKey { std::vector devices_; NcclStreamId stream_id_; AsyncStreamKind stream_kind_; + // The full list of groups across all devices which this clique is a part of. + // When enable_nccl_comm_splitting is enabled, this is used to distinguish + // which cliques can be reused from the cache or must be split in order to + // prevent a deadlock situation. + // For example, imagine we have a communicator with devices = [0,1] and groups + // = [0, 1] Later on, we may want to create communicators [0, 1] and [2, 3] by + // splitting [0, 1, 2, 3] If ranks 0 and 1 reuse the exisiting [0, 1] clique + // but ranks 2 and 3 initiate a split, there will be a deadlock since ranks 2, + // 3 and will be waiting forever for 0, 1 to join the split. Having the + // particating groups as part of the cache key will prevent such situations + std::vector> participant_groups_; }; template H AbslHashValue(H h, const NcclCliqueKey& k) { - return H::combine(std::move(h), k.devices_, k.stream_id_); + return H::combine(std::move(h), k.devices_, k.stream_id_, + k.participant_groups_); } bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc index 4346f544db20bc..c72c5115252865 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_clique_key_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include "absl/container/btree_map.h" #include "xla/service/global_device_id.h" @@ -53,6 +54,26 @@ TEST(NcclCliqueKeyTest, Compare) { EXPECT_GT(key1, key0); } +TEST(NcclCliqueKeyTest, CompareWithParticipantGroups) { + GlobalDeviceId id0 = GlobalDeviceId(0); + GlobalDeviceId id1 = GlobalDeviceId(1); + GlobalDeviceId id2 = GlobalDeviceId(2); + GlobalDeviceId id3 = GlobalDeviceId(3); + + // The keys are not equal because the replica groups are different. + NcclCliqueKey key0({id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id0, id1}}); + NcclCliqueKey key1( + {id0, id1}, NcclStreamId(0), AsyncStreamKind::kCollective, + std::vector>{{id0, id1}, {id2, id3}}); + EXPECT_FALSE(key0 == key1); + + // With no replica groups, the keys are equal + NcclCliqueKey key0_nogroups({id0, id1}, NcclStreamId(0)); + NcclCliqueKey key1_nogroups({id0, id1}, NcclStreamId(0)); + EXPECT_EQ(key0_nogroups, key1_nogroups); +} + TEST(NcclCliqueKeyTest, BtreeIterationOrder) { GlobalDeviceId id0 = GlobalDeviceId(0); GlobalDeviceId id1 = GlobalDeviceId(1); diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc index 7582c18c292e72..93b113b3a25627 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.cc @@ -217,7 +217,7 @@ NcclCollectiveThunk::NcclCollectiveThunk(Kind kind, ThunkInfo thunk_info, nccl_api_(nccl_api), async_events_(is_sync ? nullptr : new AsyncEvents()) {} -static absl::StatusOr GetNcclCliqueKey( +absl::StatusOr GetNcclCliqueKey( const Thunk::CollectiveExecuteParams& params, const std::vector& replica_groups, CollectiveOpGroupMode group_mode, NcclStreamId stream_id, @@ -229,6 +229,18 @@ static absl::StatusOr GetNcclCliqueKey( GetParticipatingDevices(global_device_id, *params.device_assn, replica_groups, group_mode)); + // If splitting is enabled, particpating groups must match in order for a + // clique to be reused from the cache. We can ignore the particpating groups + // otherwise. + static const int64_t enable_nccl_comm_splitting = + xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_comm_splitting(); + std::vector> participant_groups; + if (enable_nccl_comm_splitting) { + TF_ASSIGN_OR_RETURN(participant_groups, + GetParticipatingDevicesGroups( + *params.device_assn, replica_groups, group_mode)); + } + if (IsGlobalNcclConfig() && (participants.size() != params.device_assn->replica_count())) { return InvalidArgument( @@ -240,7 +252,7 @@ static absl::StatusOr GetNcclCliqueKey( return NcclCliqueKey(std::move(participants), enable_per_stream_comms ? stream_id : kNoStreamId, - stream_kind); + stream_kind, std::move(participant_groups)); } absl::StatusOr GetNcclComm( @@ -373,33 +385,16 @@ absl::StatusOr NcclCollectiveThunk::AsyncEvents::GetEvent( absl::Status NcclCollectiveThunk::Prepare(const PrepareParams& params, ResourceRequests& resource_requests) { - const CollectiveExecuteParams* collectives = params.collective_params; - TF_ASSIGN_OR_RETURN( - std::vector participants, - GetParticipatingDevices(collectives->global_device_id, - *collectives->device_assn, + NcclCliqueKey clique_key, + GetNcclCliqueKey(*params.collective_params, config().replica_groups, + config().group_mode, nccl_stream_id(), + GetAsyncStreamKind())); + TF_ASSIGN_OR_RETURN( + size_t num_local_participants, + GetNumLocalParticipants(*params.collective_params, config().replica_groups, config().group_mode)); - - std::vector local_devices; - if (collectives->global_device_id_map) { - local_devices.reserve(collectives->global_device_id_map->size()); - for (const auto& entry : *collectives->global_device_id_map) { - local_devices.push_back(entry.second); - } - } - - size_t num_local_participants = GetNumLocalParticipants( - participants, - collectives->global_device_id_map ? &local_devices : nullptr); - AsyncStreamKind stream_kind = GetAsyncStreamKind(); - static const bool enable_per_stream_comms = - xla::GetDebugOptionsFromFlags().xla_gpu_enable_nccl_per_stream_comms(); - return resource_requests.AddClique( - NcclCliqueKey(std::move(participants), - enable_per_stream_comms ? nccl_stream_id() : kNoStreamId, - stream_kind), - num_local_participants); + return resource_requests.AddClique(clique_key, num_local_participants); } absl::Status NcclCollectiveThunk::Initialize(const InitializeParams& params) { @@ -537,13 +532,26 @@ absl::Status IsValidOperand(Shape shape, Thunk::Kind reduction_op) { return absl::OkStatus(); } -size_t GetNumLocalParticipants( - const std::vector& participants, - const std::vector* local_devices) { - if (local_devices == nullptr) return participants.size(); +absl::StatusOr GetNumLocalParticipants( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode) { + TF_ASSIGN_OR_RETURN( + std::vector participants, + GetParticipatingDevices(params.global_device_id, *params.device_assn, + replica_groups, group_mode)); + if (!params.global_device_id_map) { + return participants.size(); + } + + std::vector local_devices; + local_devices.reserve(params.global_device_id_map->size()); + for (const auto& entry : *params.global_device_id_map) { + local_devices.push_back(entry.second); + } return absl::c_count_if(participants, [&](const GlobalDeviceId& device_id) { - return absl::c_linear_search(*local_devices, device_id); + return absl::c_linear_search(local_devices, device_id); }); } diff --git a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h index ccaffb35c308a8..2a549cdd81f520 100644 --- a/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/nccl_collective_thunk.h @@ -283,9 +283,16 @@ absl::Status AddOpDescription(absl::Status status, OpT op, //===----------------------------------------------------------------------===// -size_t GetNumLocalParticipants( - const std::vector& participants, - const std::vector* local_devices); // may be null +absl::StatusOr GetNcclCliqueKey( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode, NcclStreamId stream_id, + AsyncStreamKind stream_kind); + +absl::StatusOr GetNumLocalParticipants( + const Thunk::CollectiveExecuteParams& params, + const std::vector& replica_groups, + CollectiveOpGroupMode group_mode); // Returns a nccl comm handle and a flag indicating if // it's a local communicator. diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.cc b/third_party/xla/xla/service/gpu/runtime/thunk.cc index fc3c0cff8741c5..6f3081a90eb234 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime/thunk.cc @@ -286,8 +286,6 @@ Thunk::ExecuteParams::ExecuteParams( CASE(kSequential); CASE(kTriangularSolve); CASE(kWhile); - CASE(kFusedMHA); - CASE(kFusedMHABackward); CASE(kWaitForStreams); CASE(kCuDnn); } diff --git a/third_party/xla/xla/service/gpu/runtime/thunk.h b/third_party/xla/xla/service/gpu/runtime/thunk.h index 346664976a2d9c..cd26323ee70fea 100644 --- a/third_party/xla/xla/service/gpu/runtime/thunk.h +++ b/third_party/xla/xla/service/gpu/runtime/thunk.h @@ -165,8 +165,6 @@ class Thunk { kSendDone, kTriangularSolve, kWhile, - kFusedMHA, - kFusedMHABackward, kWaitForStreams, kCuDnn }; diff --git a/third_party/xla/xla/service/gpu/runtime_intrinsics.cc b/third_party/xla/xla/service/gpu/runtime_intrinsics.cc index 879ca6faf7c671..33bbac0f90f373 100644 --- a/third_party/xla/xla/service/gpu/runtime_intrinsics.cc +++ b/third_party/xla/xla/service/gpu/runtime_intrinsics.cc @@ -28,10 +28,11 @@ limitations under the License. #include "xla/service/custom_call_target_registry.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_finder.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -51,11 +52,8 @@ absl::Status AssertOnGpu(void* stream_handle, void* buffer, TF_ASSIGN_OR_RETURN( se::Platform * platform, se::PlatformManager::PlatformWithName(GetGpuPlatformName())); - se::StreamExecutorConfig config; - config.gpu_stream = stream_handle; - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - platform->GetExecutor(config)); - se::Stream* stream = executor->FindAllocatedStream(stream_handle); + TF_ASSIGN_OR_RETURN(se::Stream * stream, + stream_executor::FindStream(platform, stream_handle)); if (!stream) { return Internal("Stream not found for: %p", stream_handle); } diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc index 54b29378adec0c..ea5607516d6e3d 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter.cc @@ -40,7 +40,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_query.h" #include "xla/layout.h" #include "xla/literal_util.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/triton_fusion_analysis.h" diff --git a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc index 51013b4411bd1f..8c17196090f3ca 100644 --- a/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/split_k_gemm_rewriter_test.cc @@ -37,9 +37,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/stream_executor_util.cc b/third_party/xla/xla/service/gpu/stream_executor_util.cc index cde9b554bd504d..cd64405d0a8ca8 100644 --- a/third_party/xla/xla/service/gpu/stream_executor_util.cc +++ b/third_party/xla/xla/service/gpu/stream_executor_util.cc @@ -52,7 +52,6 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" @@ -377,7 +376,7 @@ absl::StatusOr> CreateKernel( } TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - se::KernelFactory::Create(stream_exec, loader_spec)); + stream_exec->LoadKernel(loader_spec)); se::KernelMetadata m; m.set_shared_memory_bytes(shared_mem_bytes); @@ -437,7 +436,7 @@ static void InitializeTypedBuffer(se::Stream* stream, // Use a large prime number to fragment the accesses. constexpr int host_buffer_size = 10069; - static std::vector* host_buffer = [] { + static std::vector* host_buffer = [&] { auto* ret = new std::vector(host_buffer_size); // Default-seeded random numbers. std::mt19937 gen; diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 4036bb0dc8e5d2..98d41316fd348b 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -72,7 +72,10 @@ xla_test( srcs = if_gpu_is_configured(["dynamic_slice_fusion_test.cc"]), backends = ["gpu"], local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), - tags = ["notsan"], # TODO(b/345034145): Fix tsan error. + tags = [ + "notsan", # TODO(b/345034145): Fix tsan error. + "no_rocm", # TODO(rocm): sync 24-08-20 + ], deps = if_gpu_is_configured( #keep sorted [ @@ -162,44 +165,6 @@ xla_test( ], ) -xla_cc_test( - name = "gpu_reduce_scatter_creator_test", - srcs = ["gpu_reduce_scatter_creator_test.cc"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu:gpu_reduce_scatter_creator", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings:string_view", - "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "gpu_all_gather_optimizer_test", - srcs = ["gpu_all_gather_optimizer_test.cc"], - deps = [ - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/service/gpu:gpu_all_gather_optimizer", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", - ], -) - xla_test( name = "gpu_spmd_e2e_compile_test", size = "small", @@ -212,69 +177,10 @@ xla_test( "//xla/hlo/utils:hlo_query", "//xla/service:executable", "//xla/service:hlo_module_config", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_test( - name = "gemm_rewrite_test", - srcs = ["gemm_rewrite_test.cc"], - backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = [ - ":gpu_codegen_test", - "//xla:error_spec", - "//xla:test", - "//xla:xla_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service:executable", - "//xla/service:hlo_module_config", - "//xla/service:hlo_pass", - "//xla/service:pattern_matcher", - "//xla/service:pattern_matcher_gmock", - "//xla/service/gpu:gemm_rewriter", - "//xla/service/gpu:gpu_executable", - "//xla/stream_executor:device_description", - "//xla/stream_executor:device_memory_allocator", - "//xla/stream_executor:stream_executor_memory_allocator", - "//xla/tests:filecheck", - "//xla/tests:verified_hlo_module", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test_main", - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", - ]) + if_rocm_is_configured([ - "@local_config_rocm//rocm:rocm_headers", - ]), -) - -xla_test( - name = "gemm_broadcast_folding_rewrite_test", - srcs = ["gemm_broadcast_folding_rewrite_test.cc"], - backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), - deps = [ - ":gpu_codegen_test", - "//xla:error_spec", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:gemm_broadcast_folding_rewriter", - "//xla/service/gpu:gemm_rewriter", "@local_tsl//tsl/platform:statusor", - "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", ], ) @@ -297,51 +203,6 @@ xla_test( ], ) -xla_cc_test( - name = "reduction_degenerate_dim_remover_test", - srcs = [ - "reduction_degenerate_dim_remover_test.cc", - ], - deps = [ - "//xla/service/gpu:reduction_degenerate_dim_remover", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_test( - name = "reduction_layout_normalizer_test", - srcs = [ - "reduction_layout_normalizer_test.cc", - ], - backends = ["gpu"], - deps = [ - "//xla:error_spec", - "//xla/service/gpu:reduction_layout_normalizer", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - -xla_cc_test( - name = "tree_reduction_rewriter_test", - srcs = [ - "tree_reduction_rewriter_test.cc", - ], - deps = [ - "//xla/service/gpu:tree_reduction_rewriter", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - xla_test( name = "swap_conv_operands_test", srcs = [ @@ -375,20 +236,6 @@ xla_test( ], ) -xla_cc_test( - name = "reduction_dimension_grouper_test", - srcs = [ - "reduction_dimension_grouper_test.cc", - ], - deps = [ - "//xla/service/gpu:reduction_dimension_grouper", - "//xla/tests:hlo_test_base", - "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) - xla_test( name = "parallel_reduction_test", srcs = [ @@ -552,6 +399,7 @@ xla_test( "gpu_p100", "gpu_amd_any", ] + if_oss(["gpu_any"]), + tags = ["no_rocm"], # TODO(rocm): weekly sync 24-08-20 deps = [ ":gpu_codegen_test", "//xla:error_spec", @@ -640,7 +488,7 @@ xla_test( "//xla/hlo/ir:hlo", "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu:gpu_fusible", - "//xla/service/gpu:instruction_fusion", + "//xla/service/gpu/transforms:instruction_fusion", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:test_main", ], @@ -655,10 +503,10 @@ xla_test( "//xla:shape_util", "//xla/service:hlo_cost_analysis", "//xla/service:hlo_pass_pipeline", - "//xla/service/gpu:fusion_merger", "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:instruction_fusion", - "//xla/service/gpu:multi_output_fusion", + "//xla/service/gpu/transforms:fusion_merger", + "//xla/service/gpu/transforms:instruction_fusion", + "//xla/service/gpu/transforms:multi_output_fusion", "//xla/stream_executor:device_description", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:test_main", @@ -780,11 +628,9 @@ lit_test_suite( "copy.hlo", "dot_bf16.hlo", "dynamic_update_slice_inplace.hlo", - "element_wise_row_vectorization.hlo", "fused_scatter.hlo", "fused_slice.hlo", "kernel_reuse.hlo", - "launch_dimensions.hlo", "pad_to_static.hlo", "reduce_atomic_min.hlo", "reduce_column_layout_change.hlo", @@ -828,6 +674,7 @@ lit_test_suite( "//xla/tools/hlo_opt:gpu_specs/v100.txtpb", ], default_tags = tf_cuda_tests_tags(), + hermetic_cuda_data_dir = "%S/../../../../../cuda_nvcc", tools = [ "//xla/tools:hlo-opt", "@llvm-project//llvm:FileCheck", @@ -850,8 +697,8 @@ lit_test_suite( # name = "xla-opt", # srcs = ["xla-opt.cc"], # deps = [ -# "//xla/service/gpu/fusions/triton:prevent_mmav3_loop_unrolling", -# "//xla/service/gpu/fusions/triton:sparse_extensions", +# "//xla/service/gpu/fusions/transforms:passes", +# "//xla/service/gpu/fusions/triton:passes", # "@llvm-project//mlir:AllExtensions", # "@llvm-project//mlir:MlirOptLib", # "@triton//:AllPassesAndDialects", @@ -946,7 +793,7 @@ xla_test( "//xla:literal", "//xla:literal_util", "//xla/tests:xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ["@local_tsl//tsl/platform:test_main"], # b/317293391 ), @@ -964,7 +811,7 @@ xla_test( "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_sort_rewriter", + "//xla/service/gpu/transforms:sort_rewriter", "//xla/tests:hlo_test_base", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", @@ -1031,8 +878,8 @@ cc_library( deps = [ "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", ], ) diff --git a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo b/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo deleted file mode 100644 index 3e75fceb48f530..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/element_wise_row_vectorization.hlo +++ /dev/null @@ -1,292 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/v100.txtpb --split-input-file | FileCheck %s -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK-LLVM %s -// We check that the row loads are vectorized. - -HloModule SimpleAddRowBroadcasting, is_scheduled=true - -%fused_computation.0 (param_0: f32[672], param_1: f32[512,14,14,672]) -> f32[512,14,14,672]{ - %param_0 = f32[672]{0} parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={3} - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - ROOT %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[672]{0} parameter(0) - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - - ROOT %fusion.0 = f32[512,14,14,672]{3,2,1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.0 -} - -// CHECK-LABEL: fusion_0 -// CHECK: .reqntid 168, 1, 1 -// CHECK-NOT: ld.global.nc.f -// CHECK-NOT: ld.global.nc.b - -// ----- - -HloModule SimpleAddSmallRowBroadcasting, is_scheduled=true - -%fused_computation.0 (param_0: f32[48], param_1: f32[512,14,14,48]) -> f32[512,14,14,48]{ - %param_0 = f32[48]{0} parameter(0) - %broadcast = f32[512,14,14,48]{3,2,1,0} broadcast(%param_0), dimensions={3} - %param_1 = f32[512,14,14,48]{3,2,1,0} parameter(1) - ROOT %add = f32[512,14,14,48]{3,2,1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[48]{0} parameter(0) - %param_1 = f32[512,14,14,48]{3,2,1,0} parameter(1) - - ROOT %fusion.0_small = f32[512,14,14,48]{3,2,1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.0 -} - -// CHECK-LABEL: fusion_0_small -// CHECK: .reqntid 12, 11, 1 -// CHECK-NOT: ld.global.nc.f -// CHECK-NOT: ld.global.nc.b - -// ----- - -// This test an BatchNorm fused kernel found in EfficientNet. -HloModule EfficientNetSwish, is_scheduled=true - -%fused_computation.1 (param_0.89: f32[672], param_1: f32[672], param_2: f32[672], param_3: f32[672], param_4: f16[512,14,14,672], param_5: f32[672], param_6: f16[512,14,14,672], param_7: f32[672]) -> f16[512,14,14,672] { - %param_2 = f32[672]{0} parameter(2) - %constant_157 = f32[] constant(1), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %broadcast.186 = f32[672]{0} broadcast(f32[] %constant_157), dimensions={}, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %param_5 = f32[672]{0} parameter(5) - %constant_56 = f32[] constant(9.96492327e-06) - %broadcast.185 = f32[672]{0} broadcast(f32[] %constant_56), dimensions={} - %multiply.155 = f32[672]{0} multiply(f32[672]{0} %param_5, f32[672]{0} %broadcast.185), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %param_3 = f32[672]{0} parameter(3) - %multiply.154 = f32[672]{0} multiply(f32[672]{0} %param_3, f32[672]{0} %broadcast.185), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %multiply.153 = f32[672]{0} multiply(f32[672]{0} %multiply.154, f32[672]{0} %multiply.154), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %subtract.15 = f32[672]{0} subtract(f32[672]{0} %multiply.155, f32[672]{0} %multiply.153), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %constant_155 = f32[] constant(0.001), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %broadcast.184 = f32[672]{0} broadcast(f32[] %constant_155), dimensions={} - %add.14 = f32[672]{0} add(f32[672]{0} %subtract.15, f32[672]{0} %broadcast.184), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %rsqrt.23 = f32[672]{0} rsqrt(f32[672]{0} %add.14), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %multiply.152 = f32[672]{0} multiply(f32[672]{0} %rsqrt.23, f32[672]{0} %rsqrt.23), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %divide.14 = f32[672]{0} divide(f32[672]{0} %broadcast.186, f32[672]{0} %multiply.152), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %rsqrt.7 = f32[672]{0} rsqrt(f32[672]{0} %divide.14), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.29 = f32[672]{0} multiply(f32[672]{0} %param_2, f32[672]{0} %rsqrt.7), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.28 = f32[672]{0} multiply(f32[672]{0} %multiply.29, f32[672]{0} %broadcast.185), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %broadcast.47 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %multiply.28), dimensions={3} - %param_6 = f16[512,14,14,672]{3,2,1,0} parameter(6) - %constant_194 = f16[] constant(1), metadata={op_type="AddV2" op_name="add"} - %broadcast.256 = f16[512,14,14,672]{3,2,1,0} broadcast(f16[] %constant_194), dimensions={} - %param_4 = f16[512,14,14,672]{3,2,1,0} parameter(4) - %convert.66 = f32[512,14,14,672]{3,2,1,0} convert(f16[512,14,14,672]{3,2,1,0} %param_4), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %broadcast.254 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %multiply.154), dimensions={3}, metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %subtract.82 = f32[512,14,14,672]{3,2,1,0} subtract(f32[512,14,14,672]{3,2,1,0} %convert.66, f32[512,14,14,672]{3,2,1,0} %broadcast.254), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %broadcast.251 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %rsqrt.23), dimensions={3} - %multiply.219 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %subtract.82, f32[512,14,14,672]{3,2,1,0} %broadcast.251), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %broadcast.250 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_2), dimensions={3}, metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %multiply.218 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %multiply.219, f32[512,14,14,672]{3,2,1,0} %broadcast.250), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %param_7 = f32[672]{0} parameter(7) - %broadcast.249 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_7), dimensions={3}, metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %add.79 = f32[512,14,14,672]{3,2,1,0} add(f32[512,14,14,672]{3,2,1,0} %multiply.218, f32[512,14,14,672]{3,2,1,0} %broadcast.249), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %convert.65 = f16[512,14,14,672]{3,2,1,0} convert(f32[512,14,14,672]{3,2,1,0} %add.79), metadata={op_type="FusedBatchNormV3" op_name="foo/batch_normalization/FusedBatchNormV3"} - %negate.12 = f16[512,14,14,672]{3,2,1,0} negate(f16[512,14,14,672]{3,2,1,0} %convert.65) - %exponential.10 = f16[512,14,14,672]{3,2,1,0} exponential(f16[512,14,14,672]{3,2,1,0} %negate.12) - %add.78 = f16[512,14,14,672]{3,2,1,0} add(f16[512,14,14,672]{3,2,1,0} %broadcast.256, f16[512,14,14,672]{3,2,1,0} %exponential.10) - %divide.20 = f16[512,14,14,672]{3,2,1,0} divide(f16[512,14,14,672]{3,2,1,0} %broadcast.256, f16[512,14,14,672]{3,2,1,0} %add.78), metadata={op_type="Sigmoid" op_name="foo/activation/Sigmoid"} - %subtract.77 = f16[512,14,14,672]{3,2,1,0} subtract(f16[512,14,14,672]{3,2,1,0} %broadcast.256, f16[512,14,14,672]{3,2,1,0} %divide.20), metadata={op_type="Sub" op_name="sub"} - %multiply.211 = f16[512,14,14,672]{3,2,1,0} multiply(f16[512,14,14,672]{3,2,1,0} %convert.65, f16[512,14,14,672]{3,2,1,0} %subtract.77), metadata={op_type="Mul" op_name="mul"} - %add.75 = f16[512,14,14,672]{3,2,1,0} add(f16[512,14,14,672]{3,2,1,0} %broadcast.256, f16[512,14,14,672]{3,2,1,0} %multiply.211), metadata={op_type="AddV2" op_name="add"} - %multiply.210 = f16[512,14,14,672]{3,2,1,0} multiply(f16[512,14,14,672]{3,2,1,0} %divide.20, f16[512,14,14,672]{3,2,1,0} %add.75), metadata={op_type="Mul" op_name="mul_1"} - %multiply.209 = f16[512,14,14,672]{3,2,1,0} multiply(f16[512,14,14,672]{3,2,1,0} %param_6, f16[512,14,14,672]{3,2,1,0} %multiply.210), metadata={op_type="Mul" op_name="mul_2"} - %convert.8 = f32[512,14,14,672]{3,2,1,0} convert(f16[512,14,14,672]{3,2,1,0} %multiply.209), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %constant_48 = f32[] constant(100352), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %broadcast.46 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[] %constant_48), dimensions={}, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.27 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %convert.8, f32[512,14,14,672]{3,2,1,0} %broadcast.46), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %param_1 = f32[672]{0} parameter(1) - %broadcast.45 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_1), dimensions={3}, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %subtract.10 = f32[512,14,14,672]{3,2,1,0} subtract(f32[512,14,14,672]{3,2,1,0} %multiply.27, f32[512,14,14,672]{3,2,1,0} %broadcast.45), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %param_0.89 = f32[672]{0} parameter(0) - %broadcast.44 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %param_0.89), dimensions={3}, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.26 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %broadcast.44, f32[512,14,14,672]{3,2,1,0} %subtract.82), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %broadcast.42 = f32[512,14,14,672]{3,2,1,0} broadcast(f32[672]{0} %divide.14), dimensions={3} - %divide.6 = f32[512,14,14,672]{3,2,1,0} divide(f32[512,14,14,672]{3,2,1,0} %multiply.26, f32[512,14,14,672]{3,2,1,0} %broadcast.42), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %subtract.9 = f32[512,14,14,672]{3,2,1,0} subtract(f32[512,14,14,672]{3,2,1,0} %subtract.10, f32[512,14,14,672]{3,2,1,0} %divide.6), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - %multiply.25 = f32[512,14,14,672]{3,2,1,0} multiply(f32[512,14,14,672]{3,2,1,0} %broadcast.47, f32[512,14,14,672]{3,2,1,0} %subtract.9), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} - ROOT %convert.7 = f16[512,14,14,672]{3,2,1,0} convert(f32[512,14,14,672]{3,2,1,0} %multiply.25), metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} -} - -ENTRY main { - %param_0 = f32[672]{0} parameter(0) - %param_1 = f32[672]{0} parameter(1) - %param_2 = f32[672]{0} parameter(2) - %param_3 = f32[672]{0} parameter(3) - %param_4 = f16[512,14,14,672]{3,2,1,0} parameter(4) - %param_5 = f32[672]{0} parameter(5) - %param_6 = f16[512,14,14,672]{3,2,1,0} parameter(6) - %param_7 = f32[672]{0} parameter(7) - - ROOT %fusion.1 = f16[512,14,14,672]{3,2,1,0} fusion(f32[672]{0} %param_0, f32[672]{0} %param_1, f32[672]{0} %param_2, f32[672]{0} %param_3, f16[512,14,14,672]{3,2,1,0} %param_4, f32[672]{0} %param_5, f16[512,14,14,672]{3,2,1,0} %param_6, f32[672]{0} %param_7), kind=kLoop, calls=%fused_computation.1, metadata={op_type="FusedBatchNormGradV3" op_name="gradient_tape/foo/batch_normalization/FusedBatchNormGradV3"} -} - -// CHECK-LABEL: fusion_1 -// CHECK: .reqntid 168, 1, 1 -// CHECK-NOT: ld.global.nc.f -// CHECK-NOT: ld.global.nc.b - -// ----- - -HloModule TransposeOutput, is_scheduled=true - -%fused_computation.2 (param_0: f32[672], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { - %param_0 = f32[672]{0} parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={3} - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) - ROOT %copy = f32[512,14,14,672]{0,2,3,1} copy(%add) -} - -ENTRY main { - %param_0 = f32[672]{0} parameter(0) - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - - ROOT %fusion.2 = f32[512,14,14,672]{0,2,3,1} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.2 -} -// Check that we didn't do anything. The block size didn't change. -// CHECK-LABEL: fusion_2 -// CHECK: .reqntid 128, 1, 1 -// CHECK: ld.global.nc.f - -// ----- - -HloModule TransposeInput, is_scheduled=true - -%fused_computation.3 (param_0: f32[672], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { - %param_0 = f32[672]{0} parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={3} - %param_1 = f32[512,14,14,672]{0,2,3,1} parameter(1) - %copy = f32[512,14,14,672]{3,2,1,0} copy(%param_1) - ROOT %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %copy) -} - -ENTRY main { - %param_0 = f32[672]{0} parameter(0) - %param_1 = f32[512,14,14,672]{0,2,3,1} parameter(1) - - ROOT %fusion.3 = f32[512,14,14,672]{3,2,1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.3 -} -// Check that we didn't do anything. The block size didn't change. -// CHECK-LABEL: fusion_3 -// CHECK: .reqntid 128, 1, 1 -// CHECK: ld.global.nc.f - -// ----- - -HloModule ScalarBroadcasting, is_scheduled=true - -%fused_computation.5 (param_0: f32[], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { - %param_0 = f32[] parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={} - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - ROOT %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[] parameter(0) - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - - ROOT %fusion.5 = f32[512,14,14,672] fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.5 -} - -// CHECK-LABEL: fusion_5 -// CHECK: .reqntid 128, 1, 1 -// CHECK: ld.global.nc.f - -// ----- - -HloModule NotSupportedBroadcasting, is_scheduled=true - -%fused_computation.6 (param_0: f32[14,672], param_1: f32[512,14,14,672]) -> f32[512,14,14,672] { - %param_0 = f32[14,672]{1,0} parameter(0) - %broadcast = f32[512,14,14,672]{3,2,1,0} broadcast(%param_0), dimensions={2,3} - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - ROOT %add = f32[512,14,14,672]{3,2,1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[14,672]{1,0} parameter(0) - %param_1 = f32[512,14,14,672]{3,2,1,0} parameter(1) - - ROOT %fusion.6 = f32[512,14,14,672] fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.6 -} - -// Check that we didn't do anything. The block size didn't change. -// CHECK-LABEL: fusion_6 -// CHECK: .reqntid 128, 1, 1 -// CHECK: ld.global.nc.f - -// ----- -HloModule Module, is_scheduled=true - -%fused_computation.7 { - %constant_2 = f32[] constant(0) - %broadcast.1 = f32[32,7,7,352]{2,1,3,0} broadcast(f32[] %constant_2), dimensions={} - %param_1.2 = f32[32,7,7,320]{2,1,3,0} parameter(1) - %param_2.1 = f32[32,7,7,224]{2,1,3,0} parameter(2) - %param_3.1 = f32[32,7,7,128]{2,1,3,0} parameter(3) - %tmp_8.1 = f32[32,7,7,1024]{2,1,3,0} concatenate(f32[32,7,7,352]{2,1,3,0} %broadcast.1, f32[32,7,7,320]{2,1,3,0} %param_1.2, f32[32,7,7,224]{2,1,3,0} %param_2.1, f32[32,7,7,128]{2,1,3,0} %param_3.1), dimensions={3} - %param_0.1 = f32[32,7,7,1024]{2,1,3,0} parameter(0) - ROOT %tmp_10.1 = f32[32,7,7,1024]{2,1,3,0} add(f32[32,7,7,1024]{2,1,3,0} %tmp_8.1, f32[32,7,7,1024]{2,1,3,0} %param_0.1) -} - -ENTRY %computation { - %tmp_0 = u8[32,224,224,3]{3,2,1,0} parameter(0) - %tmp_9 = f32[32,7,7,1024]{2,1,3,0} constant({...}) - %tmp_5 = f32[32,7,7,320]{2,1,3,0} constant({...}) - %tmp_6 = f32[32,7,7,224]{2,1,3,0} constant({...}) - %tmp_7 = f32[32,7,7,128]{2,1,3,0} constant({...}) - ROOT %fusion.7 = f32[32,7,7,1024]{2,1,3,0} fusion(f32[32,7,7,1024]{2,1,3,0} %tmp_9, f32[32,7,7,320]{2,1,3,0} %tmp_5, f32[32,7,7,224]{2,1,3,0} %tmp_6, f32[32,7,7,128]{2,1,3,0} %tmp_7), kind=kLoop, calls=%fused_computation.7 -} - - -// This graph triggered a bug where the new indexing was generated -// CHECK-LLVM-LABEL: @fusion_7 -// CHECK-LLVM-NOT: row_index - -// ----- -HloModule RowToLong, is_scheduled=true - -%fused_computation.1 { - %p0 = f32[2025]{0} parameter(0) - ROOT %r = f32[3025,2025]{1,0} broadcast(%p0), dimensions={1} -} - -ENTRY main { - %param_0 = f32[2025]{0} parameter(0) - ROOT %fusion.8 = f32[3025,2025]{1,0} fusion(%param_0), kind=kLoop, calls=%fused_computation.1 - -} -// Check that we didn't emit the simpler row broadcasting. -// CHECK-LLVM-LABEL: @fusion_8 -// CHECK-LLVM-NOT: row_index - -// ----- - -HloModule module, is_scheduled=true - -%fused_computation.1 { - %p0 = f16[5000,64,64,32] parameter(0) - %p1 = f16[] parameter(1) - ROOT %pad1 = f16[5000,65,65,32] pad(%p0, %p1), padding=0_0x0_1x0_1x0_0 -} - -ENTRY computation { - p0 = f16[5000,64,64,32] parameter(0) - zero = f16[] constant(0) - - ROOT %fusion.9 = f16[5000,65,65,32] fusion(p0, zero), kind=kLoop, calls=%fused_computation.1 -} - -// Our codegen can't emit a vectorized load here, but it can emit a vectorized -// store. -// CHECK-LABEL: .visible .entry fusion_9 -// CHECK-COUNT-4: ld.global.nc.u16 -// CHECK: st.global.v4.b16 diff --git a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc index 2e5db538d8a0a5..b4124f2673d958 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_cub_sort_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/gpu_sort_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" #include "xla/tests/hlo_test_base.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -35,7 +35,7 @@ namespace { bool HloWasRewrittenToUseCubSort(const HloModule& module) { for (const auto& pass_metadata : module.metadata().proto().pass_metadata()) { - if (pass_metadata.pass_name() == "gpu-sort-rewriter") { + if (pass_metadata.pass_name() == "sort-rewriter") { return pass_metadata.module_changed(); } } @@ -50,13 +50,13 @@ class CubSortKeysTest : public HloTestBase, public: void SetUp() override { HloTestBase::SetUp(); - GpuSortRewriter::SetSortSizeThresholdForTestingOnly(33000); + SortRewriter::SetSortSizeThresholdForTestingOnly(33000); } }; TEST_P(CubSortKeysTest, CompareToReference) { int batch_size = std::get<2>(GetParam()); - int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size; + int segment_size = SortRewriter::SortSizeThreshold() / batch_size; const char* kHloTpl = R"( HloModule TestSortKeys @@ -103,7 +103,7 @@ ENTRY m { })"; int batch_size = std::get<2>(GetParam()); - int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size; + int segment_size = SortRewriter::SortSizeThreshold() / batch_size; std::string hlo_str = absl::Substitute( kHloTpl, primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())), @@ -138,13 +138,13 @@ class CubSortPairsTest public: void SetUp() override { HloTestBase::SetUp(); - GpuSortRewriter::SetSortSizeThresholdForTestingOnly(33000); + SortRewriter::SetSortSizeThresholdForTestingOnly(33000); } }; TEST_P(CubSortPairsTest, CompareToReference) { int batch_size = std::get<3>(GetParam()); - int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size; + int segment_size = SortRewriter::SortSizeThreshold() / batch_size; const char* kHloTpl = R"( HloModule TestSortPairs @@ -216,7 +216,7 @@ ENTRY m { })"; int batch_size = std::get<3>(GetParam()); - int segment_size = GpuSortRewriter::SortSizeThreshold() / batch_size; + int segment_size = SortRewriter::SortSizeThreshold() / batch_size; std::string hlo_str = absl::Substitute( kHloTpl, primitive_util::LowercasePrimitiveTypeName(std::get<0>(GetParam())), diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc index 639cf511875f43..aed017cbefb2fa 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fused_mha_test.cc @@ -134,7 +134,7 @@ class MultiHeadedAttentionTest : public GpuCodegenTest { EXPECT_TRUE( LiteralTestUtil::Near(expected_result, actual_result, mha_error_spec_)); - // Run FusedMHA/FuseMHABackward thunk through command buffer + // Run through command buffer DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); debug_options.set_xla_gpu_graph_min_graph_size(1); @@ -393,8 +393,8 @@ class FlashAttentionBMMScaleCausalMaskSoftmaxBMM void TestImpl_Flash_Attention_BMM1_CausalMask_Softmax_BMM2() { if (skip_reason_) GTEST_SKIP() << *skip_reason_; if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 4)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4."; + se::dnn::VersionInfo(9, 0, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0."; } XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -414,8 +414,8 @@ class FlashAttentionBMMScaleCausalMaskSoftmaxBMM void TestImpl_Flash_Attention_Training_BMM1_CausalMask_Softmax_BMM2() { if (skip_reason_) GTEST_SKIP() << *skip_reason_; if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 4)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4."; + se::dnn::VersionInfo(9, 0, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0."; } XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -709,8 +709,8 @@ class FlashAttentionBMMScaleBiasSoftmaxBMM : public MultiHeadedAttentionTest { void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2() { if (skip_reason_) GTEST_SKIP() << *skip_reason_; if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 4)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4."; + se::dnn::VersionInfo(9, 0, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0."; } XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -731,8 +731,8 @@ class FlashAttentionBMMScaleBiasSoftmaxBMM : public MultiHeadedAttentionTest { void TestImpl_Flash_Attention_Training_BMM1_Bias_Softmax_BMM2() { if (skip_reason_) GTEST_SKIP() << *skip_reason_; if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 4)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4."; + se::dnn::VersionInfo(9, 0, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0."; } XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -756,9 +756,9 @@ class FlashAttentionBMMScaleBiasSoftmaxBMM : public MultiHeadedAttentionTest { void TestImpl_Flash_Attention_BMM1_Bias_Softmax_BMM2_Cross_Attention() { if (skip_reason_) GTEST_SKIP() << *skip_reason_; if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 4)) { + se::dnn::VersionInfo(9, 0, 0)) { GTEST_SKIP() << "Flash Attention cross attention requires " - "cuDNN >= 8.9.4."; + "cuDNN >= 9.0.0."; } XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -780,10 +780,10 @@ class FlashAttentionBMMScaleBiasSoftmaxBMM : public MultiHeadedAttentionTest { if (skip_reason_) GTEST_SKIP() << *skip_reason_; auto cc = GetCudaComputeCapability(); if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 6) || + se::dnn::VersionInfo(9, 0, 0) || !cc.IsAtLeastHopper() || cc.minor != 0) { GTEST_SKIP() - << "Flash Attention dbias requires cuDNN >= 8.9.6 and Hopper arch."; + << "Flash Attention dbias requires cuDNN >= 9.0.0 and Hopper arch."; } XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -900,8 +900,8 @@ class FlashAttentionBMMScaleSoftmaxBMM : public MultiHeadedAttentionTest { void TestImpl_Flash_Attention_Training_BMM1_Softmax_BMM2() { if (skip_reason_) GTEST_SKIP() << *skip_reason_; if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 4)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.4."; + se::dnn::VersionInfo(9, 0, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0."; } XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -925,10 +925,10 @@ class FlashAttentionBMMScaleSoftmaxBMM : public MultiHeadedAttentionTest { if (skip_reason_) GTEST_SKIP() << *skip_reason_; auto cc = GetCudaComputeCapability(); if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 4) || + se::dnn::VersionInfo(9, 0, 0) || !cc.IsAtLeastHopper() || cc.minor != 0) { GTEST_SKIP() << "Flash Attention deterministic kernels requires cuDNN >= " - "8.9.4 and Hopper arch."; + "9.0.0 and Hopper arch."; } XlaBuilder builder(TestName()); auto lhs_bmm1_literal = @@ -1085,8 +1085,8 @@ class FlashAttentionBMMScalePaddingMaskSoftmaxBMM void TestImpl_Flash_Attention_Training_BMM1_PaddingMask_Softmax_BMM2() { if (skip_reason_) GTEST_SKIP() << *skip_reason_; if (GetDnnVersionInfoOrDefault(backend().default_stream_executor()) < - se::dnn::VersionInfo(8, 9, 3)) { - GTEST_SKIP() << "Flash Attention requires cuDNN >= 8.9.3."; + se::dnn::VersionInfo(9, 0, 0)) { + GTEST_SKIP() << "Flash Attention requires cuDNN >= 9.0.0."; } XlaBuilder builder(TestName()); // pass padding mask as bias diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc index 3e573eb569bb62..cc20ef8b8484e8 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_pipeline_test.cc @@ -18,11 +18,11 @@ limitations under the License. #include #include "absl/strings/string_view.h" -#include "xla/service/gpu/fusion_merger.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/instruction_fusion.h" -#include "xla/service/gpu/multi_output_fusion.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/transforms/fusion_merger.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" +#include "xla/service/gpu/transforms/multi_output_fusion.h" #include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/shape.h" @@ -51,8 +51,7 @@ class GpuFusionPipelineTest : public GpuCodegenTest { device_info); pipeline.AddPass(/*may_duplicate=*/true, device_info); pipeline.AddPass(device_info, ShapeSizeBytesFunction()); - pipeline.AddPass(device_info, - ShapeSizeBytesFunction()); + pipeline.AddPass(device_info, ShapeSizeBytesFunction()); RunAndFilecheckHloRewrite(hlo, std::move(pipeline), expected); } @@ -65,15 +64,17 @@ HloModule module ENTRY computation { p = f32[5000,6000]{1,0} parameter(0) e = f32[5000,6000]{1,0} sqrt(p) - c = f32[6000,5000] transpose(p), dimensions={1,0} + b = f32[1,5000,6000] reshape(p) + c = f32[1,6000,5000] transpose(b), dimensions={0,2,1} r = f32[300,20,5000] reshape(c) ROOT out = (f32[5000,6000], f32[300,20,5000]) tuple(e,r) } )"; CheckGpuFusionPipeline(hlo, R"( -// CHECK: %fused_computation (param_0.1: f32[5000,6000]) -> (f32[300,20,5000], f32[5000,6000]) { +// CHECK: %fused_computation ({{[^:]+}}: f32[5000,6000]) -> (f32[300,20,5000], f32[5000,6000]) { // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[5000,6000]{1,0} parameter(0) -// CHECK-NEXT: [[c_1_1:%[^ ]+]] = f32[6000,5000]{1,0} transpose([[param_0_1_0]]), dimensions={1,0} +// CHECK-NEXT: [[bc:%[^ ]+]] = f32[1,5000,6000]{2,1,0} reshape([[param_0_1_0]]) +// CHECK-NEXT: [[c_1_1:%[^ ]+]] = f32[1,6000,5000]{2,1,0} transpose([[bc]]), dimensions={0,2,1} // CHECK-NEXT: [[r_1_2:%[^ ]+]] = f32[300,20,5000]{2,1,0} reshape([[c_1_1]]) // CHECK-NEXT: [[e_1_3:%[^ ]+]] = f32[5000,6000]{1,0} sqrt([[param_0_1_0]]) // CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[300,20,5000]{2,1,0}, f32[5000,6000]{1,0}) tuple([[r_1_2]], [[e_1_3]]) diff --git a/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc index 849cf1dcaf5bba..43c6a509239ccd 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_fusion_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/service/gpu/gpu_fusible.h" -#include "xla/service/gpu/instruction_fusion.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -181,18 +181,18 @@ TEST_F(TransposeFusionTest, ElementaryLogical) { HloModule module ENTRY main { - p = f32[16,32]{1,0} parameter(0) - s = sqrt(p) - ROOT c = f32[32,16]{1,0} transpose(s), dimensions={1,0} + p = f32[1,16,32]{2,1,0} parameter(0) + s = f32[1,16,32]{2,1,0} sqrt(p) + ROOT c = f32[1,32,16]{2,1,0} transpose(s), dimensions={0,2,1} } )"; CheckGpuFusion(hlo, R"( -// CHECK: %fused_computation (param_0.1: f32[16,32]) -> f32[32,16] { -// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) -// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) -// CHECK-NEXT: ROOT [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0} -// CHECK: ROOT [[fusion_3:%[^ ]+]] = f32[32,16]{1,0} fusion([[p_4:%[^ ]+]]), kind=kInput, calls=[[fused_computation_5:%[^ ]+]] +// CHECK: %fused_computation ({{[^:]+}}: f32[1,16,32]) -> f32[1,32,16] { +// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[1,16,32]{2,1,0} parameter(0) +// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[1,16,32]{2,1,0} sqrt([[param_0_1_0]]) +// CHECK-NEXT: ROOT [[c_1_2:%[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[s_1_1]]), dimensions={0,2,1} +// CHECK: ROOT [[fusion_3:%[^ ]+]] = f32[1,32,16]{2,1,0} fusion([[p_4:%[^ ]+]]), kind=kInput, calls=[[fused_computation_5:%[^ ]+]] )"); } diff --git a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index e86f2c09b06cea..45e7af622d9814 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -34,7 +34,11 @@ class GpuKernelTilingTest : public GpuCodegenTest { // Most tests in this file want to skip layout assignment, but a few need it // enabled. HloModuleConfig ConfigWithLayoutAssignment() { - return GetModuleConfigForTest(); + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_mlir_emitter_level(3); + config.set_debug_options(debug_options); + return config; } HloModuleConfig ConfigWithoutLayoutAssignment() { @@ -42,6 +46,7 @@ class GpuKernelTilingTest : public GpuCodegenTest { auto debug_options = HloTestBase::GetDebugOptionsForTest(); // Disable layout_assignment to use the preassigned layouts. debug_options.add_xla_disable_hlo_passes("layout-assignment"); + debug_options.set_xla_gpu_mlir_emitter_level(3); config.set_debug_options(debug_options); return config; } @@ -635,6 +640,8 @@ TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) { } )"; auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); + auto &debug_options = hlo_module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); auto expected_ir = is_built_with_rocm_ ? R"( ; CHECK: %llvm.amdgcn.kernel.input_reduce_fusion.lds.t = type { [4 x [2 x float]] } ; CHECK: @llvm.amdgcn.kernel.input_reduce_fusion.lds = internal addrspace(3) global %llvm.amdgcn.kernel.input_reduce_fusion.lds.t poison diff --git a/third_party/xla/xla/service/gpu/tests/gpu_sparse_dot_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_sparse_dot_test.cc index 5ea691967fa6b8..6133a5b38f4bc5 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_sparse_dot_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_sparse_dot_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc index e247abe6872847..cc4c36507fec94 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_spmd_e2e_compile_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/service/executable.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo b/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo deleted file mode 100644 index 3d05dcf9892ad5..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/launch_dimensions.hlo +++ /dev/null @@ -1,338 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s -// This tests that we do not increase the grid launch size when -// few_waves is enabled. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @wrapped_b -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-DAG: ![[ctaid_range]] = !{i32 0, i32 2} -// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 1024} - - -HloModule Test, is_scheduled=true - -fused_computation { - param_0 = f32[100,20]{1,0} parameter(0) - ROOT b.1 = f32[100,20]{1,0} round-nearest-even(f32[100,20]{1,0} param_0) -} - -ENTRY main { - a = f32[100, 20]{1,0} parameter(0) - ROOT wrapped_b = f32[100,20]{1,0} fusion(f32[100,20]{1,0} a), kind=kLoop, calls=fused_computation -} - -// ----- - -// This tests that we cap grid launch code when few_waves is enabled. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @wrapped_b -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule Test, is_scheduled=true - -fused_computation { - param_0 = f32[10000,10000]{1,0} parameter(0) - ROOT b.1 = f32[10000,10000]{1,0} round-nearest-even(f32[10000,10000]{1,0} param_0) -} - -ENTRY main { - a = f32[10000, 10000]{1,0} parameter(0) - ROOT wrapped_b = f32[10000,10000]{1,0} fusion(f32[10000,10000]{1,0} a), kind=kLoop, calls=fused_computation -} - -// ----- - -// This tests that we cap grid launch code when few_waves is enabled -// and scalar broadcast are present. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion_3 -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule ScalarBroadcast, is_scheduled=true - -%fused_computation.3 (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { - %param_0 = f32[] parameter(0) - %broadcast = f32[10000, 10000]{1,0} broadcast(%param_0), dimensions={} - %param_1 = f32[10000, 10000]{1,0} parameter(1) - ROOT %add = f32[10000, 10000]{1,0} add(%broadcast, %param_1) -} - -ENTRY main { - %param_0 = f32[] parameter(0) - %param_1 = f32[10000, 10000]{1,0} parameter(1) - - ROOT %fusion.3 = f32[10000, 10000]{1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation.3 -} - -// ----- - -// This tests that we enable few_waves in a simple fusion. It is the baseline -// for the tests below. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule SimpleFusion, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { - %param_0 = f32[10000,10000] parameter(0) - %param_1 = f32[10000, 10000]{1,0} parameter(1) - ROOT %add = f32[10000, 10000]{1,0} add(%param_0, %param_1) -} - -ENTRY main { - %param_0 = f32[10000, 10000]{1,0} parameter(0) - %param_1 = f32[10000, 10000]{1,0} parameter(1) - - ROOT %fusion = f32[10000, 10000]{1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we keep few_waves enabled for large constants. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule LargeConstant, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { - %param_0 = f32[10000,10000] parameter(0) - %c0 = f32[10000,10000] constant(0) - ROOT %add = f32[10000, 10000]{1,0} add(%param_0, %c0) -} - -ENTRY main { - %param_0 = f32[10000, 10000] parameter(0) - - ROOT %fusion = f32[10000, 10000]{1,0} fusion(%param_0), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we disable few_waves if a non-elementwise op is present. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 195313} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 97657} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 256} - -HloModule NonElementwise, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[10000, 10000]) -> f32[10000, 10000] { - %param_0 = f32[10000,10000] parameter(0) - %reverse = f32[10000,10000]{1,0} reverse(%param_0), dimensions={0,1} - %param_1 = f32[10000, 10000]{1,0} parameter(1) - ROOT %add = f32[10000, 10000]{1,0} add(%reverse, %param_1) -} - -ENTRY main { - %param_0 = f32[10000, 10000]{1,0} parameter(0) - %param_1 = f32[10000, 10000]{1,0} parameter(1) - - ROOT %fusion = f32[10000, 10000]{1,0} fusion(%param_0, %param_1), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we disable few_waves if -// - a tensor broadcast is present -// - at least four big inputs are present -// - the fusion is not row-vectorizable -// It serves as a baseline for the tests below. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 7813} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 3907} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 256} - -HloModule NoFewWaves, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[2000, 2000]) -> f32[2000, 2000] { - %param_0 = f32[2000] parameter(0) - %broadcast = f32[2000, 2000]{1,0} broadcast(%param_0), dimensions={0} - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{0,1} parameter(2) - %param_3 = f32[2000, 2000]{0,1} parameter(3) - %param_4 = f32[2000, 2000]{0,1} parameter(4) - - %sum.0 = f32[2000, 2000] add(%param_1, %param_2) - %sum.1 = f32[2000, 2000] add(%sum.0, %param_3) - %sum.2 = f32[2000, 2000] add(%sum.1, %param_4) - ROOT %add = f32[2000, 2000]{1,0} add(%sum.2, %broadcast) -} - -ENTRY main { - %param_0 = f32[2000]{0} parameter(0) - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{0,1} parameter(2) - %param_3 = f32[2000, 2000]{0,1} parameter(3) - %param_4 = f32[2000, 2000]{0,1} parameter(4) - - ROOT %fusion = f32[2000, 2000]{1,0} fusion(%param_0, %param_1, %param_2, %param_3, %param_4), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we enable few_waves if -// - a tensor broadcast is present -// - THREE big inputs are present -// - the fusion IS row-vectorizable -// In this case, the block count is changed from 7813 to 2000. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 2000} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 500} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 2000} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 500} - -HloModule RowVectorizable, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[2000, 2000]) -> f32[2000, 2000] { - %param_0 = f32[2000] parameter(0) - %broadcast = f32[2000, 2000]{1,0} broadcast(%param_0), dimensions={1} - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{1,0} parameter(2) - %param_3 = f32[2000, 2000]{1,0} parameter(3) - - %sum.0 = f32[2000, 2000] add(%param_1, %param_2) - %sum.1 = f32[2000, 2000] add(%sum.0, %param_3) - ROOT %add = f32[2000, 2000]{1,0} add(%sum.1, %broadcast) -} - -ENTRY main { - %param_0 = f32[2000]{0} parameter(0) - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{1,0} parameter(2) - %param_3 = f32[2000, 2000]{1,0} parameter(3) - - ROOT %fusion = f32[2000, 2000]{1,0} fusion(%param_0, %param_1, %param_2, %param_3), kind=kLoop, calls=%fused_computation -} - -// ----- - -// This tests that we enable few_waves if -// - a SCALAR broadcast is present -// - four big inputs are present -// - the fusion is not row-vectorizable -// In this case, the block count is changed from 7813 to 1008. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-PTX-DAG: ![[tid_range]] = !{i32 0, i32 128} -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-GCN-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule ScalarBroadcastFourInputs, is_scheduled=true - -%fused_computation (param_0: f32[], param_1: f32[2000, 2000]) -> f32[2000, 2000] { - %param_0 = f32[] parameter(0) - %broadcast = f32[2000, 2000]{1,0} broadcast(%param_0), dimensions={} - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{1,0} parameter(2) - %param_3 = f32[2000, 2000]{1,0} parameter(3) - %param_4 = f32[2000, 2000]{1,0} parameter(4) - - %sum.0 = f32[2000, 2000] add(%param_1, %param_2) - %sum.1 = f32[2000, 2000] add(%sum.0, %param_3) - %sum.2 = f32[2000, 2000] add(%sum.1, %param_4) - ROOT %add = f32[2000, 2000]{1,0} add(%sum.2, %broadcast) -} - -ENTRY main { - %param_0 = f32[] parameter(0) - %param_1 = f32[2000, 2000]{1,0} parameter(1) - %param_2 = f32[2000, 2000]{1,0} parameter(2) - %param_3 = f32[2000, 2000]{1,0} parameter(3) - %param_4 = f32[2000, 2000]{1,0} parameter(4) - - ROOT %fusion = f32[2000, 2000]{1,0} fusion(%param_0, %param_1, %param_2, %param_3, %param_4), kind=kLoop, calls=%fused_computation -} - -// ----- -// This tests the GELU kernel. The original kernel that -// motivated few_waves implementation. - -// CHECK-LABEL: define{{( amdgpu_kernel)?}} void @fusion -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workgroup.id.x(), !range ![[ctaid_range:[0-9]+]] -// CHECK-PTX-DAG: call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range ![[tid_range:[0-9]+]] -// CHECK-GCN-DAG: call i32 @llvm.amdgcn.workitem.id.x(), !range ![[tid_range:[0-9]+]] -// CHECK-PTX-DAG: ![[ctaid_range]] = !{i32 0, i32 1008} -// CHECK-GCN-DAG: ![[ctaid_range]] = !{i32 0, i32 1760} -// CHECK-DAG: ![[tid_range]] = !{i32 0, i32 128} - -HloModule Test, is_scheduled=true - -%fused_computation (param_0: f16[6,512,4096]) -> f16[6,512,4096] { - %param_0 = f16[6,512,4096]{2,1,0} parameter(0) - %power.tmp.1 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %param_0) - %power.0 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %power.tmp.1, f16[6,512,4096]{2,1,0} %param_0) - %constant_4 = f16[] constant(0.044708), metadata={op_type="Mul" op_name="mul"} - %broadcast.3 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_4), dimensions={}, metadata={op_type="Mul" op_name="mul"} - %multiply.3 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %power.0, f16[6,512,4096]{2,1,0} %broadcast.3), metadata={op_type="Mul" op_name="mul"} - %add.1 = f16[6,512,4096]{2,1,0} add(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %multiply.3), metadata={op_type="AddV2" op_name="add"} - %constant_2 = f16[] constant(0.79785), metadata={op_type="Mul" op_name="mul_1"} - %broadcast.2 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_2), dimensions={}, metadata={op_type="Mul" op_name="mul_1"} - %multiply.2 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %add.1, f16[6,512,4096]{2,1,0} %broadcast.2), metadata={op_type="Mul" op_name="mul_1"} - %tanh.0 = f16[6,512,4096]{2,1,0} tanh(f16[6,512,4096]{2,1,0} %multiply.2), metadata={op_type="Tanh" op_name="Tanh"} - %constant_1 = f16[] constant(1), metadata={op_type="AddV2" op_name="add_1"} - %broadcast.1 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_1), dimensions={}, metadata={op_type="AddV2" op_name="add_1"} - %add.0 = f16[6,512,4096]{2,1,0} add(f16[6,512,4096]{2,1,0} %tanh.0, f16[6,512,4096]{2,1,0} %broadcast.1), metadata={op_type="AddV2" op_name="add_1"} - %constant_0 = f16[] constant(0.5), metadata={op_type="Mul" op_name="mul_2"} - %broadcast.0 = f16[6,512,4096]{2,1,0} broadcast(f16[] %constant_0), dimensions={}, metadata={op_type="Mul" op_name="mul_2"} - %multiply.1 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %add.0, f16[6,512,4096]{2,1,0} %broadcast.0), metadata={op_type="Mul" op_name="mul_2"} - ROOT %multiply.0 = f16[6,512,4096]{2,1,0} multiply(f16[6,512,4096]{2,1,0} %param_0, f16[6,512,4096]{2,1,0} %multiply.1), metadata={op_type="Mul" op_name="mul_3"} -} - -ENTRY %cluster_0__XlaCompiledKernel_true__XlaNumConstantArgs_0__XlaNumResourceArgs_0_.24 (arg0.1: f16[6,512,4096]) -> f16[6,512,4096] { - %arg0.1 = f16[6,512,4096]{2,1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} - ROOT %fusion = f16[6,512,4096]{2,1,0} fusion(f16[6,512,4096]{2,1,0} %arg0.1), kind=kLoop, calls=%fused_computation, metadata={op_type="Mul" op_name="mul_3"} -} diff --git a/third_party/xla/xla/service/gpu/tests/simple_optimization_test.cc b/third_party/xla/xla/service/gpu/tests/simple_optimization_test.cc index a18d58d6df333c..2ece976d3e5e39 100644 --- a/third_party/xla/xla/service/gpu/tests/simple_optimization_test.cc +++ b/third_party/xla/xla/service/gpu/tests/simple_optimization_test.cc @@ -15,7 +15,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir b/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir similarity index 91% rename from third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir rename to third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir index 6457691c211c3e..d8abd2d4504a24 100644 --- a/third_party/xla/xla/service/gpu/tests/sparse_add_layout.mlir +++ b/third_party/xla/xla/service/gpu/tests/sparse_add_encoding.mlir @@ -1,6 +1,6 @@ // RUN: xla-opt %s \ // RUN: -convert-triton-to-tritongpu='target=cuda:80' \ -// RUN: -add-sparse-encoding -canonicalize \ +// RUN: -sparse-add-encoding -canonicalize \ // RUN: | FileCheck %s // Note: 'canonicalize' folds redundant (back-and-forth) convert_layout ops. @@ -35,10 +35,10 @@ module { // CHECK-NEXT: %[[CVT:.*]] = triton_gpu.convert_layout %[[D]] // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED4x4]]> // CHECK-SAME: -> tensor<64x64xf32, #[[BLOCKED1x1]]> - // CHECK-NEXT: tt.print "" {hex = false} : %[[CVT]] + // CHECK-NEXT: tt.print "" {hex = false, isSigned = array} : %[[CVT]] // CHECK-SAME: : tensor<64x64xf32, #[[BLOCKED1x1]]> // A use with side effects so we don't DCE the whole function. - tt.print "" { hex = false } : %d : tensor<64x64xf32> + tt.print "" { hex = false, isSigned = array} : %d : tensor<64x64xf32> // CHECK-NEXT: tt.return tt.return diff --git a/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir b/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir new file mode 100644 index 00000000000000..ad611620a0bd35 --- /dev/null +++ b/third_party/xla/xla/service/gpu/tests/sparse_convert_triton_to_tritongpu.mlir @@ -0,0 +1,25 @@ +// RUN: xla-opt %s -convert-triton-to-tritongpu='target=cuda:80' | FileCheck %s + +module attributes {} { + tt.func @gemm_fusion_dot_1_impl() { + %c0_i32 = arith.constant 0 : i32 + %c32_i32 = arith.constant 32 : i32 + %acc = arith.constant dense<0.000000e+00> : tensor<32x32xf32> + %a = arith.constant dense<0.000000e+00> : tensor<32x16xbf16> + // CHECK: %[[A:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x16xbf16, {{.+}}> -> tensor<32x16xbf16> + %b = arith.constant dense<0.000000e+00> : tensor<32x32xbf16> + // CHECK: %[[B:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xbf16, {{.+}}> -> tensor<32x32xbf16> + %meta = arith.constant dense<0> : tensor<32x2xi16> + // CHECK: %[[META:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x2xi16, {{.+}}> -> tensor<32x2xi16> + %35:1 = scf.for %arg4 = %c0_i32 to %c32_i32 step %c32_i32 iter_args(%arg8 = %acc) -> (tensor<32x32xf32>) : i32 { + // CHECK: %[[ACC:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xf32, {{.+}}> -> tensor<32x32xf32> + // CHECK-NEXT: %[[D:.*]] = triton_gpu.sparse_dot %[[A]], %[[B]], %[[ACC]], %[[META]] + // CHECK-SAME: : tensor<32x16xbf16> meta tensor<32x2xi16> + // CHECK-SAME: * tensor<32x32xbf16> -> tensor<32x32xf32> + %74 = triton_gpu.sparse_dot %a, %b, %arg8, %meta : tensor<32x16xbf16> meta tensor<32x2xi16> * tensor<32x32xbf16> -> tensor<32x32xf32> + // CHECK: %[[ACC:.+]] = triton_gpu.convert_layout {{.+}} : tensor<32x32xf32> -> tensor<32x32xf32, {{.+}}> + scf.yield %74 : tensor<32x32xf32> + } + tt.return + } +} \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/tests/sparse_ttg_reduce_data_duplication.mlir b/third_party/xla/xla/service/gpu/tests/sparse_remove_layout_conversion.mlir similarity index 90% rename from third_party/xla/xla/service/gpu/tests/sparse_ttg_reduce_data_duplication.mlir rename to third_party/xla/xla/service/gpu/tests/sparse_remove_layout_conversion.mlir index 5604a1ac5baf46..7db3874eef4047 100644 --- a/third_party/xla/xla/service/gpu/tests/sparse_ttg_reduce_data_duplication.mlir +++ b/third_party/xla/xla/service/gpu/tests/sparse_remove_layout_conversion.mlir @@ -1,4 +1,4 @@ -// RUN: triton-opt %s -split-input-file -tritongpu-reduce-data-duplication | FileCheck %s +// RUN: xla-opt %s --sparse-remove-layout-conversion | FileCheck %s #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> #mma = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}> diff --git a/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc b/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc index b8a263df629a69..b999d1c05fa079 100644 --- a/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc +++ b/third_party/xla/xla/service/gpu/tests/transpose_emitter_test.cc @@ -61,7 +61,7 @@ TEST_F(TransposeEmitterTest, SimpleLogicalTranspose) { )"; CompileAndVerifyIr(kHloString, MakePlatformSpecificLlvm(expected_ir), /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); + /*run_optimization_passes=*/true); EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); } @@ -90,7 +90,8 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} + bc = f32[1,16,32]{2,1,0} bitcast(%s.1) + %t.1 = f32[1,32,16]{2,1,0} transpose(bc), dimensions={0,2,1} b = f32[32,16,1]{2,1,0} bitcast(%t.1) ROOT o = f32[32,16,1]{2,1,0} sqrt(b) } @@ -116,8 +117,10 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %t.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} - %t1.1 = f32[32,16]{1,0} transpose(%param_0.1), dimensions={1,0} + %bc.1 = f32[1,16,32]{2,1,0} bitcast(%s.1) + %bc.2 = f32[1,16,32]{2,1,0} bitcast(%param_0.1) + %t.1 = f32[1,32,16]{2,1,0} transpose(%bc.1), dimensions={0,2,1} + %t1.1 = f32[1,32,16]{2,1,0} transpose(%bc.2), dimensions={0,2,1} %r.1 = f32[32,16,1]{2,1,0} reshape(%t.1) %r1.1 = f32[32,16,1]{2,1,0} reshape(%t1.1) ROOT %tuple = (f32[32,16,1]{2,1,0}, f32[32,16,1]{2,1,0}) tuple(%r.1, %r1.1) @@ -170,14 +173,16 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %c.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} - %c1.1 = f32[32,16]{1,0} transpose(%param_0.1), dimensions={1,0} - ROOT %tuple = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(%c.1, %c1.1) + %bc.1 = f32[1,16,32]{2,1,0} bitcast(%s.1) + %bc.2 = f32[1,16,32]{2,1,0} bitcast(%param_0.1) + %c.1 = f32[1,32,16]{2,1,0} transpose(%bc.1), dimensions={0,2,1} + %c1.1 = f32[1,32,16]{2,1,0} transpose(%bc.2), dimensions={0,2,1} + ROOT %tuple = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) tuple(%c.1, %c1.1) } ENTRY main { %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation + ROOT %fusion = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) fusion(%p), kind=kInput, calls=%fused_computation } )"; @@ -251,14 +256,15 @@ HloModule m %fused_computation { %param_0.1 = f32[16,32]{1,0} parameter(0) %s.1 = f32[16,32]{1,0} sqrt(%param_0.1) - %c.1 = f32[32,16]{1,0} transpose(%s.1), dimensions={1,0} + %bc.1 = f32[1,16,32]{2,1,0} bitcast(%s.1) + %c.1 = f32[1,32,16]{2,1,0} transpose(%bc.1), dimensions={0,2,1} %c1.1 = f32[16,32]{1,0} exponential(%param_0.1) - ROOT %tuple = (f32[32,16]{1,0}, f32[16,32]{1,0}) tuple(%c.1, %c1.1) + ROOT %tuple = (f32[1,32,16]{2,1,0}, f32[16,32]{1,0}) tuple(%c.1, %c1.1) } ENTRY entry { %p = f32[16,32]{1,0} parameter(0) - ROOT %fusion = (f32[32,16]{1,0}, f32[16,32]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation + ROOT %fusion = (f32[1,32,16]{2,1,0}, f32[16,32]{1,0}) fusion(%p), kind=kInput, calls=%fused_computation } )"; diff --git a/third_party/xla/xla/service/gpu/tests/xla-opt.cc b/third_party/xla/xla/service/gpu/tests/xla-opt.cc index 30bd45f242a94c..f27b6f82366230 100644 --- a/third_party/xla/xla/service/gpu/tests/xla-opt.cc +++ b/third_party/xla/xla/service/gpu/tests/xla-opt.cc @@ -15,16 +15,16 @@ limitations under the License. #include "mlir/InitAllExtensions.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" -#include "xla/service/gpu/fusions/triton/prevent_mmav3_loop_unrolling.h" -#include "xla/service/gpu/fusions/triton/sparse_extensions.h" +#include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/fusions/triton/passes.h" #include "third_party/triton/bin/RegisterTritonDialects.h" int main(int argc, char **argv) { mlir::DialectRegistry registry; mlir::registerAllExtensions(registry); registerTritonDialects(registry); // This registers all passes as well. - xla::gpu::RegisterSparsePasses(); - xla::gpu::RegisterPreventMmaV3LoopUnrollingPass(); + xla::gpu::registerTritonFusionTransformsPasses(); + xla::gpu::registerGpuFusionTransformsPasses(); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "xla-opt modular optimizer driver\n", registry)); diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD new file mode 100644 index 00000000000000..9bf90d9d58ddb7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -0,0 +1,3044 @@ +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_is_configured", +) +load( + "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", + "if_cuda_is_configured", +) +load("//xla:xla.bzl", "xla_cc_test") +load( + "//xla/stream_executor:build_defs.bzl", + "if_gpu_is_configured", +) +load("//xla/tests:build_defs.bzl", "xla_test") +load("//xla/tsl:tsl.bzl", "if_google", "if_oss") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//xla/service/gpu:__subpackages__"], + licenses = ["notice"], +) + +cc_library( + name = "algebraic_simplifier", + srcs = [ + "algebraic_simplifier.cc", + ], + hdrs = [ + "algebraic_simplifier.h", + ], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:algebraic_simplifier", + "//xla/service:hlo_pass", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "algebraic_simplifier_test", + srcs = ["algebraic_simplifier_test.cc"], + deps = [ + ":algebraic_simplifier", + "//xla/hlo/ir:hlo", + "//xla/service:algebraic_simplifier", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + +# End-to-end tested via //third_party/tensorflow/compiler/xla/service/gpu:dot_algorithm_support_test +cc_library( + name = "algorithm_checker", + srcs = ["algorithm_checker.cc"], + hdrs = ["algorithm_checker.h"], + deps = [ + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:algorithm_util", + "//xla/service:hlo_pass", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +cc_library( + name = "alias_passthrough_params", + srcs = ["alias_passthrough_params.cc"], + hdrs = ["alias_passthrough_params.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "alias_passthrough_params_test", + srcs = ["alias_passthrough_params_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":alias_passthrough_params", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "all_gather_optimizer", + srcs = ["all_gather_optimizer.cc"], + hdrs = ["all_gather_optimizer.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "all_gather_optimizer_test", + srcs = ["all_gather_optimizer_test.cc"], + deps = [ + ":all_gather_optimizer", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "all_reduce_blueconnect", + srcs = ["all_reduce_blueconnect.cc"], + hdrs = ["all_reduce_blueconnect.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_ops_utils", + "//xla/service:computation_placer_hdr", + "//xla/service:global_device_id", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "all_reduce_blueconnect_test", + srcs = ["all_reduce_blueconnect_test.cc"], + deps = [ + ":all_reduce_blueconnect", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:computation_placer_hdr", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "all_reduce_splitter", + srcs = ["all_reduce_splitter.cc"], + hdrs = ["all_reduce_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_opt_utils", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "all_reduce_splitter_test", + srcs = ["all_reduce_splitter_test.cc"], + deps = [ + ":all_reduce_splitter", + ":reduce_scatter_creator", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass_pipeline", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "async_collective_annotator", + srcs = ["async_collective_annotator.cc"], + hdrs = ["async_collective_annotator.h"], + deps = [ + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "async_collective_annotator_test", + srcs = ["async_collective_annotator_test.cc"], + deps = [ + ":async_collective_annotator", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:test_macros_header", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "async_wrapper", + srcs = ["async_wrapper.cc"], + hdrs = ["async_wrapper.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "async_wrapper_test", + srcs = ["async_wrapper_test.cc"], + deps = [ + ":async_wrapper", + "//xla:literal", + "//xla:literal_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:hlo_proto_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:literal_test_util", + "//xla/tests:verified_hlo_module", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "collective_permute_cycle_decomposer", + srcs = ["collective_permute_cycle_decomposer.cc"], + hdrs = ["collective_permute_cycle_decomposer.h"], + deps = [ + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_parser", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "collective_permute_cycle_decomposer_test", + srcs = ["collective_permute_cycle_decomposer_test.cc"], + deps = [ + ":collective_permute_cycle_decomposer", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "collective_permute_valid_iteration_annotator", + srcs = ["collective_permute_valid_iteration_annotator.cc"], + hdrs = ["collective_permute_valid_iteration_annotator.h"], + deps = [ + "//xla:literal_util", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service:while_loop_analysis", + ], +) + +xla_cc_test( + name = "collective_permute_valid_iteration_annotator_test", + srcs = ["collective_permute_valid_iteration_annotator_test.cc"], + deps = [ + ":collective_permute_valid_iteration_annotator", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_pass_pipeline", + "//xla/service:while_loop_trip_count_annotator", + "//xla/tests:hlo_test_base", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "command_buffer_scheduling", + srcs = ["command_buffer_scheduling.cc"], + hdrs = ["command_buffer_scheduling.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/ffi:ffi_api", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:variant_visitor", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "command_buffer_scheduling_test", + srcs = ["command_buffer_scheduling_test.cc"], + deps = [ + ":command_buffer_scheduling", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/stream_executor:device_description", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "conv_padding_legalization", + srcs = ["conv_padding_legalization.cc"], + hdrs = ["conv_padding_legalization.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:shape_inference", + "//xla/service/gpu:cublas_cudnn", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "conv_padding_legalization_test", + srcs = ["conv_padding_legalization_test.cc"], + deps = [ + ":conv_padding_legalization", + "//xla:shape_util", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:cublas_cudnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "conv_rewriter", + srcs = ["conv_rewriter.cc"], + hdrs = ["conv_rewriter.h"], + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "conv_rewriter_test", + srcs = ["conv_rewriter_test.cc"], + deps = [ + ":conv_rewriter", + "//xla:array4d", + "//xla:literal_util", + "//xla:protobuf_util", + "//xla:shape_util", + "//xla:test", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:shape_inference", + "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "convert_async_collectives_to_sync", + srcs = ["convert_async_collectives_to_sync.cc"], + hdrs = ["convert_async_collectives_to_sync.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:convert_async_collectives_to_sync", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "convert_async_collectives_to_sync_test", + srcs = ["convert_async_collectives_to_sync_test.cc"], + deps = [ + ":convert_async_collectives_to_sync", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "copy_fusion", + srcs = ["copy_fusion.cc"], + hdrs = ["copy_fusion.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:reduction_utils", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "copy_fusion_test", + srcs = ["copy_fusion_test.cc"], + deps = [ + ":copy_fusion", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "cublas_pad_for_gemms", + srcs = ["cublas_pad_for_gemms.cc"], + hdrs = ["cublas_pad_for_gemms.h"], + deps = [ + ":gemm_fusion", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "cublas_pad_for_gemms_test", + srcs = ["cublas_pad_for_gemms_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":cublas_pad_for_gemms", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "cudnn_custom_call_converter", + srcs = ["cudnn_custom_call_converter.cc"], + hdrs = ["cudnn_custom_call_converter.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:ir_emission_utils", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "cudnn_custom_call_converter_test", + srcs = ["cudnn_custom_call_converter_test.cc"], + deps = [ + ":cudnn_custom_call_converter", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "cudnn_fused_conv_rewriter", + srcs = ["cudnn_fused_conv_rewriter.cc"], + hdrs = ["cudnn_fused_conv_rewriter.h"], + deps = [ + "//xla:comparison_util", + "//xla:debug_options_flags", + "//xla:literal", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "cudnn_fused_conv_rewriter_test", + srcs = ["cudnn_fused_conv_rewriter_test.cc"], + backend_tags = { + "gpu_a100": [ + "noasan", + "nomsan", + "no_rocm", + ], + }, + backends = [ + "gpu_a100", + "gpu_amd_any", + ] + if_oss(["gpu_any"]), + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + shard_count = 10, + deps = [ + ":conv_rewriter", + ":cudnn_fused_conv_rewriter", + "//xla:comparison_util", + "//xla:error_spec", + "//xla/hlo/ir:hlo", + "//xla/service:algebraic_simplifier", + "//xla/service:convert_mover", + "//xla/service:hlo_constant_folding", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:hlo_pass_pipeline", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:reshape_mover", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudnn_header", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), +) + +cc_library( + name = "cudnn_fused_mha_rewriter", + srcs = ["cudnn_fused_mha_rewriter.cc"], + hdrs = ["cudnn_fused_mha_rewriter.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]), +) + +xla_test( + name = "cudnn_fused_mha_rewriter_test", + srcs = ["cudnn_fused_mha_rewriter_test.cc"], + backend_tags = {"gpu": [ + "requires-gpu-nvidia", + "no_rocm", + ]}, + backends = [ + "gpu", + ], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + ":cudnn_fused_mha_rewriter", + ":cudnn_fused_mha_transpose_fusion", + "//xla:error_spec", + "//xla:test_helpers", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:algebraic_simplifier", + "//xla/service:computation_layout", + "//xla/service:hlo_cse", + "//xla/service:hlo_dce", + "//xla/service:hlo_module_config", + "//xla/service:hlo_parser", + "//xla/service:hlo_verifier", + "//xla/service:layout_normalization", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:reshape_decomposer", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudnn_header", + ]), +) + +# Tested via cudnn_fused_mha_rewriter_test. +cc_library( + name = "cudnn_fused_mha_transpose_fusion", + srcs = ["cudnn_fused_mha_transpose_fusion.cc"], + hdrs = ["cudnn_fused_mha_transpose_fusion.h"], + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:matmul_utils", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +# Tested via //third_party/tensorflow/compiler/xla/service/gpu/fusions:cudnn_test +cc_library( + name = "cudnn_fusion_compiler", + srcs = if_cuda_is_configured(["cudnn_fusion_compiler.cc"]), + hdrs = if_cuda_is_configured(["cudnn_fusion_compiler.h"]), + deps = if_cuda_is_configured([ + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cudnn_support_utils", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:kernel_reuse_cache", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/service/gpu:triton_fusion_analysis", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_config_cuda//cuda:cudnn_header", + "//xla:shape_util", + "//xla:comparison_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/stream_executor:dnn", + "//xla/stream_executor:stream_executor_h", + "//xla/service:dump", + "//xla/stream_executor/cuda:cudnn_frontend_helpers", + "//xla/stream_executor/cuda:cudnn_plugin", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ]), +) + +cc_library( + name = "cudnn_norm_rewriter", + srcs = ["cudnn_norm_rewriter.cc"], + hdrs = ["cudnn_norm_rewriter.h"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + "//xla:shape_util", + "//xla:types", + "//xla:util", + "//xla:window_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/protobuf:dnn_proto_cc", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudnn_header", + ]) + if_google([ + "@com_google_protobuf//:wrappers_cc_proto", + ]), +) + +xla_test( + name = "cudnn_norm_rewriter_test", + srcs = ["cudnn_norm_rewriter_test.cc"], + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + deps = [ + ":cudnn_norm_rewriter", + "//xla:error_spec", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/tests:filecheck", + "//xla/tsl/lib/core:status_test_util", + "@com_google_googletest//:gtest_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_cuda//cuda:cudnn_header", + ]), +) + +cc_library( + name = "cudnn_pad_for_convolutions", + srcs = ["cudnn_pad_for_convolutions.cc"], + hdrs = ["cudnn_pad_for_convolutions.h"], + deps = [ + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:cudnn_support_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "cudnn_pad_for_convolutions_test", + srcs = ["cudnn_pad_for_convolutions_test.cc"], + deps = [ + ":cudnn_pad_for_convolutions", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:cublas_cudnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "cudnn_simplify_padding", + srcs = ["cudnn_simplify_padding.cc"], + hdrs = ["cudnn_simplify_padding.h"], + deps = [ + "//xla:literal", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "cudnn_simplify_padding_test", + srcs = ["cudnn_simplify_padding_test.cc"], + deps = [ + ":cudnn_pad_for_convolutions", + ":cudnn_simplify_padding", + ":cudnn_vectorize_convolutions", + "//xla:literal", + "//xla:util", + "//xla/service:algebraic_simplifier", + "//xla/service:call_inliner", + "//xla/service:hlo_pass", + "//xla/service:hlo_pass_pipeline", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service:reshape_mover", + "//xla/service:tuple_simplifier", + "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/functional:function_ref", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "cudnn_vectorize_convolutions", + srcs = ["cudnn_vectorize_convolutions.cc"], + hdrs = ["cudnn_vectorize_convolutions.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/client:xla_builder", + "//xla/client:xla_computation", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:cudnn_support_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "cudnn_vectorize_convolutions_test", + srcs = ["cudnn_vectorize_convolutions_test.cc"], + deps = [ + ":cudnn_vectorize_convolutions", + "//xla:util", + "//xla/service:call_inliner", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +# TODO(b/358278858): Currently lacking test coverage. +cc_library( + name = "cudnn_custom_call_compiler", + srcs = if_cuda_is_configured(["cudnn_custom_call_compiler.cc"]), + hdrs = if_cuda_is_configured(["cudnn_custom_call_compiler.h"]), + deps = if_cuda_is_configured([ + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_config_cuda//cuda:cudnn_header", + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu/runtime:cudnn_thunk", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor:dnn", + "//xla/stream_executor:stream_executor_h", + "//xla/stream_executor/cuda:cudnn_frontend_helpers", + "//xla/stream_executor/cuda:cudnn_plugin", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ]), +) + +cc_library( + name = "custom_kernel_fusion_rewriter", + srcs = ["custom_kernel_fusion_rewriter.cc"], + hdrs = ["custom_kernel_fusion_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu/kernels:custom_fusion_library", + "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "custom_kernel_fusion_rewriter_test", + srcs = ["custom_kernel_fusion_rewriter_test.cc"], + deps = [ + ":custom_kernel_fusion_rewriter", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/kernels:custom_kernel_fusion_pattern", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "dot_dimension_sorter", + srcs = ["dot_dimension_sorter.cc"], + hdrs = ["dot_dimension_sorter.h"], + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_test( + name = "dot_dimension_sorter_test", + srcs = ["dot_dimension_sorter_test.cc"], + backends = ["gpu"], + deps = [ + ":dot_dimension_sorter", + "//xla:error_spec", + "//xla/hlo/ir:hlo", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "dot_operand_converter", + srcs = ["dot_operand_converter.cc"], + hdrs = ["dot_operand_converter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:op_expander_pass", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_test( + name = "dot_operand_converter_test", + srcs = if_gpu_is_configured(["dot_operand_converter_test.cc"]), + backends = [ + "gpu_a100", + "gpu_p100", + "gpu_v100", + "gpu_amd_any", + ], + deps = if_gpu_is_configured( + [ + ":dot_operand_converter", + "@com_google_googletest//:gtest", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_matchers", + "//xla/service:pattern_matcher", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", + ], + ) + [ + # b/317293391 + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "dot_sparsity_rewriter", + srcs = ["dot_sparsity_rewriter.cc"], + hdrs = ["dot_sparsity_rewriter.h"], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "dot_sparsity_rewriter_test", + srcs = ["dot_sparsity_rewriter_test.cc"], + deps = [ + ":dot_sparsity_rewriter", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "double_buffer_loop_unrolling", + srcs = ["double_buffer_loop_unrolling.cc"], + hdrs = ["double_buffer_loop_unrolling.h"], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_instruction_utils", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_ops_utils", + "//xla/service:flatten_call_graph", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "double_buffer_loop_unrolling_test", + srcs = ["double_buffer_loop_unrolling_test.cc"], + deps = [ + ":double_buffer_loop_unrolling", + "//xla:test", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:tuple_simplifier", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/container:flat_hash_set", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "dynamic_slice_fusion_rewriter", + srcs = ["dynamic_slice_fusion_rewriter.cc"], + hdrs = ["dynamic_slice_fusion_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/ffi:ffi_api", + "//xla/hlo/ir:hlo", + "//xla/service:custom_call_target_registry", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:gpu_constants", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/kernels:custom_fusion_library", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "dynamic_slice_fusion_rewriter_test", + srcs = ["dynamic_slice_fusion_rewriter_test.cc"], + tags = [ + "gpu", + "no_rocm", + ], + deps = [ + ":dynamic_slice_fusion_rewriter", + "//xla:shape_util", + "//xla/client:xla_builder", + "//xla/client/lib:constants", + "//xla/ffi", + "//xla/ffi:ffi_api", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_value", + "//xla/service:custom_call_target_registry", + "//xla/service:executable", + "//xla/service:hlo_memory_scheduler", + "//xla/service:hlo_module_config", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor", + "//xla/stream_executor/gpu:gpu_types_header", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "fusion_merger", + srcs = ["fusion_merger.cc"], + hdrs = ["fusion_merger.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_graph_dumper", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + ], +) + +xla_cc_test( + name = "fusion_merger_test", + srcs = ["fusion_merger_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":fusion_merger", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "fusion_wrapper", + srcs = ["fusion_wrapper.cc"], + hdrs = ["fusion_wrapper.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:gpu_fusible", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "fusion_wrapper_test", + srcs = ["fusion_wrapper_test.cc"], + deps = [ + ":fusion_wrapper", + "//xla/tests:hlo_test_base", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "gemm_broadcast_folding_rewriter", + srcs = ["gemm_broadcast_folding_rewriter.cc"], + hdrs = ["gemm_broadcast_folding_rewriter.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "gemm_broadcast_folding_rewriter_test", + srcs = ["gemm_broadcast_folding_rewriter_test.cc"], + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = [ + ":gemm_broadcast_folding_rewriter", + ":gemm_rewriter", + "//xla:error_spec", + "//xla/hlo/ir:hlo", + "//xla/service/gpu/tests:gpu_codegen_test", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "gemm_fusion", + srcs = ["gemm_fusion.cc"], + hdrs = ["gemm_fusion.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_padding_requirements", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/service/gpu:triton_tiling_propagation", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:tensor_float_32_utils", + ], +) + +xla_cc_test( + name = "gemm_fusion_test", + srcs = ["gemm_fusion_test.cc"], + deps = [ + ":gemm_fusion", + "//xla:autotuning_proto_cc", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:cublas_padding_requirements", + "//xla/service/gpu:triton_fusion_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "gemm_rewriter", + srcs = ["gemm_rewriter.cc"], + hdrs = ["gemm_rewriter.h"], + deps = [ + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:status_macros", + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/evaluator:hlo_evaluator", + "//xla/hlo/ir:hlo", + "//xla/service:algorithm_util", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu:matmul_utils", + "//xla/stream_executor:blas", + "//xla/stream_executor:device_description", + "//xla/stream_executor/gpu:gpu_blas_lt", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:ml_dtypes", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/protobuf:dnn_proto_cc", + ], +) + +xla_test( + name = "gemm_rewriter_test", + srcs = ["gemm_rewriter_test.cc"], + backends = ["gpu"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ + "TENSORFLOW_USE_ROCM=1", + ]), + deps = [ + ":gemm_rewriter", + "//xla:error_spec", + "//xla:test", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_executable", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory_allocator", + "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tests:filecheck", + "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", + ]), +) + +cc_library( + name = "gemv_rewriter", + srcs = ["gemv_rewriter.cc"], + hdrs = ["gemv_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "gemv_rewriter_test", + srcs = ["gemv_rewriter_test.cc"], + deps = [ + ":gemv_rewriter", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +# TODO(b/358278858): Currently lacking test coverage. +cc_library( + name = "gpusolver_rewriter", + srcs = if_gpu_is_configured(["gpusolver_rewriter.cc"]), + hdrs = if_gpu_is_configured(["gpusolver_rewriter.h"]), + deps = if_gpu_is_configured([ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "//xla:comparison_util", + "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:cusolver_context", + "//xla/service/gpu:ir_emission_utils", + "//xla/stream_executor", + "//xla/stream_executor:blas", + "//xla/stream_executor:device_memory_allocator", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ]), +) + +cc_library( + name = "horizontal_input_fusion", + srcs = ["horizontal_input_fusion.cc"], + hdrs = ["horizontal_input_fusion.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service/gpu:gpu_fusible", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "horizontal_input_fusion_test", + srcs = ["horizontal_input_fusion_test.cc"], + backends = ["gpu"], + deps = [ + ":horizontal_input_fusion", + "//xla:error_spec", + "//xla:literal_util", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/tests:gpu_codegen_test", + "//xla/stream_executor:device_description", + "//xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "horizontal_loop_fusion", + srcs = ["horizontal_loop_fusion.cc"], + hdrs = ["horizontal_loop_fusion.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:sub_byte_normalization", + "//xla/service/gpu:gpu_fusible", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "horizontal_loop_fusion_test", + srcs = ["horizontal_loop_fusion_test.cc"], + backends = ["gpu"], + deps = [ + ":horizontal_loop_fusion", + ":instruction_fusion", + "//xla:error_spec", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_dce", + "//xla/service:hlo_parser", + "//xla/service:hlo_pass", + "//xla/service:hlo_pass_pipeline", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + ], +) + +cc_library( + name = "instruction_fusion", + srcs = ["instruction_fusion.cc"], + hdrs = ["instruction_fusion.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:fusion_node_indexing_evaluation", + "//xla/service:fusion_queue", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:gpu_fusible", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "instruction_fusion_test", + srcs = ["instruction_fusion_test.cc"], + tags = [ + "nomsan", + "not_run:arm", + ], + deps = [ + ":instruction_fusion", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/tests:hlo_test_base", + "//xla/tests:test_utils", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "layout_assignment", + srcs = ["layout_assignment.cc"], + hdrs = ["layout_assignment.h"], + deps = [ + "//xla:shape_layout", + "//xla:shape_util", + "//xla:util", + "//xla:window_util", + "//xla:xla_data_proto_cc", + "//xla:xla_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:computation_layout", + "//xla/service:host_memory_offload_annotations_hdr", + "//xla/service:layout_assignment", + "//xla/service:logical_buffer", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu:matmul_utils", + "//xla/service/gpu:reduction_utils", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor", + "//xla/stream_executor:dnn", + "//xla/tsl/util:env_var", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "layout_assignment_test", + srcs = ["layout_assignment_test.cc"], + deps = [ + ":layout_assignment", + "//xla:shape_layout", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:computation_layout", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:stream_executor_util", + "//xla/stream_executor:device_description", + "//xla/stream_executor:dnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "move_copy_to_users", + srcs = ["move_copy_to_users.cc"], + hdrs = ["move_copy_to_users.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "move_copy_to_users_test", + srcs = ["move_copy_to_users_test.cc"], + deps = [ + ":move_copy_to_users", + "//xla/service:layout_assignment", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "multi_output_fusion", + srcs = ["multi_output_fusion.cc"], + hdrs = ["multi_output_fusion.h"], + deps = [ + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/ir:hlo_dfs_reachability", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_graph_dumper", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "multi_output_fusion_test", + srcs = ["multi_output_fusion_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":multi_output_fusion", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + ], +) + +cc_library( + name = "pipelined_p2p_rewriter", + srcs = ["pipelined_p2p_rewriter.cc"], + hdrs = ["pipelined_p2p_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "pipelined_p2p_rewriter_test", + srcs = ["pipelined_p2p_rewriter_test.cc"], + deps = [ + ":pipelined_p2p_rewriter", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "priority_fusion", + srcs = ["priority_fusion.cc"], + hdrs = ["priority_fusion.h"], + deps = [ + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:dump", + "//xla/service:fusion_queue", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_graph_dumper", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:fusion_process_dump_proto_cc", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/service/gpu/model:fusion_analysis_cache", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/service/gpu/model:gpu_indexing_performance_model", + "//xla/service/gpu/model:gpu_performance_model", + "//xla/service/gpu/model:gpu_performance_model_base", + "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/service/gpu/model:tiled_hlo_computation", + "//xla/service/gpu/model:triton_emitter_constraints", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/meta:type_traits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "priority_fusion_test", + srcs = ["priority_fusion_test.cc"], + local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), + tags = ["no_pip"], + deps = [ + ":priority_fusion", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_cost_analysis", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu:hlo_fusion_analysis", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "reduce_scatter_creator", + srcs = ["reduce_scatter_creator.cc"], + hdrs = ["reduce_scatter_creator.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:collective_opt_utils", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "reduce_scatter_creator_test", + srcs = ["reduce_scatter_creator_test.cc"], + deps = [ + ":reduce_scatter_creator", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_module_config", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "reduction_degenerate_dim_remover", + srcs = ["reduction_degenerate_dim_remover.cc"], + hdrs = ["reduction_degenerate_dim_remover.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduction_degenerate_dim_remover_test", + srcs = [ + "reduction_degenerate_dim_remover_test.cc", + ], + deps = [ + ":reduction_degenerate_dim_remover", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "reduction_dimension_grouper", + srcs = ["reduction_dimension_grouper.cc"], + hdrs = ["reduction_dimension_grouper.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduction_dimension_grouper_test", + srcs = [ + "reduction_dimension_grouper_test.cc", + ], + deps = [ + ":reduction_dimension_grouper", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "reduction_layout_normalizer", + srcs = ["reduction_layout_normalizer.cc"], + hdrs = ["reduction_layout_normalizer.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "reduction_layout_normalizer_test", + srcs = [ + "reduction_layout_normalizer_test.cc", + ], + backends = ["gpu"], + deps = [ + ":reduction_layout_normalizer", + "//xla:error_spec", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "reduction_splitter", + srcs = ["reduction_splitter.cc"], + hdrs = ["reduction_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:reduction_utils", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "reduction_splitter_test", + srcs = ["reduction_splitter_test.cc"], + deps = [ + ":reduction_splitter", + "//xla:shape_util", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + ], +) + +cc_library( + name = "rename_fusions", + srcs = ["rename_fusions.cc"], + hdrs = ["rename_fusions.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_cc_test( + name = "rename_fusions_test", + srcs = ["rename_fusions_test.cc"], + deps = [ + ":rename_fusions", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + ], +) + +cc_library( + name = "sanitize_constant_names", + srcs = ["sanitize_constant_names.cc"], + hdrs = ["sanitize_constant_names.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:name_uniquer", + "//xla/service/llvm_ir:buffer_assignment_util", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + ], +) + +xla_cc_test( + name = "sanitize_constant_names_test", + srcs = ["sanitize_constant_names_test.cc"], + deps = [ + ":sanitize_constant_names", + "//xla:literal_util", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "scatter_slice_simplifier", + srcs = ["scatter_slice_simplifier.cc"], + hdrs = ["scatter_slice_simplifier.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +# TODO(b/358278858): Currently lacking test coverage. +cc_library( + name = "scatter_expander", + srcs = ["scatter_expander.cc"], + hdrs = ["scatter_expander.h"], + deps = [ + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:scatter_expander", + "@com_google_absl//absl/strings:string_view", + ], +) + +xla_cc_test( + name = "scatter_slice_simplifier_test", + srcs = ["scatter_slice_simplifier_test.cc"], + deps = [ + ":scatter_slice_simplifier", + "//xla:shape_util", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "schedule_postprocessing", + srcs = ["schedule_postprocessing.cc"], + hdrs = ["schedule_postprocessing.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "schedule_postprocessing_test", + srcs = ["schedule_postprocessing_test.cc"], + deps = [ + ":schedule_postprocessing", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "scheduling_instruction_annotator", + srcs = ["scheduling_instruction_annotator.cc"], + hdrs = ["scheduling_instruction_annotator.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "scheduling_instruction_annotator_test", + srcs = ["scheduling_instruction_annotator_test.cc"], + deps = [ + ":scheduling_instruction_annotator", + "//xla/hlo/ir:hlo", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "softmax_rewriter_triton", + srcs = ["softmax_rewriter_triton.cc"], + hdrs = ["softmax_rewriter_triton.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_cost_analysis", + "//xla/service:hlo_pass", + "//xla/service:instruction_fusion", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:hlo_traversal", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/service/gpu/model:fusion_analysis_cache", + "//xla/service/gpu/model:gpu_indexing_performance_model", + "//xla/service/gpu/model:symbolic_tile_analysis", + "//xla/service/gpu/model:tiled_hlo_computation", + "//xla/service/gpu/model:triton_emitter_constraints", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "softmax_rewriter_triton_test", + srcs = ["softmax_rewriter_triton_test.cc"], + deps = [ + ":softmax_rewriter_triton", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:instruction_fusion", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_device_info_for_tests", + "//xla/service/gpu/fusions/triton:triton_support", + "//xla/service/gpu/model:gpu_hlo_cost_analysis", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # build_cleaner: keep + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "sort_rewriter", + srcs = if_gpu_is_configured( + ["sort_rewriter.cc"], + ["sort_rewriter_stub.cc"], + ), + hdrs = ["sort_rewriter.h"], + deps = [ + "//xla:comparison_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:stable_sort_expander", + "//xla/service/gpu:cublas_cudnn", + "//xla/service/gpu/runtime:cub_sort_thunk", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "sort_rewriter_test", + srcs = if_cuda_is_configured(["sort_rewriter_test.cc"]), + backends = ["gpu"], + tags = ["no_oss"], + deps = [ + ":sort_rewriter", + "//xla:error_spec", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:cublas_cudnn", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "stream_attribute_annotator", + srcs = ["stream_attribute_annotator.cc"], + hdrs = ["stream_attribute_annotator.h"], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:gpu_fusible", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "stream_attribute_annotator_test", + srcs = ["stream_attribute_annotator_test.cc"], + deps = [ + ":stream_attribute_annotator", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "stream_attribute_async_wrapper", + srcs = ["stream_attribute_async_wrapper.cc"], + hdrs = ["stream_attribute_async_wrapper.h"], + deps = [ + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu/runtime:thunk", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "stream_attribute_async_wrapper_test", + srcs = ["stream_attribute_async_wrapper_test.cc"], + deps = [ + ":stream_attribute_async_wrapper", + "//xla/hlo/ir:hlo", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "topk_specializer", + srcs = ["topk_specializer.cc"], + hdrs = ["topk_specializer.h"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:hlo_proto_cc", + "//xla/service:tuple_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + ], +) + +xla_test( + name = "topk_specializer_test", + srcs = ["topk_specializer_test.cc"], + backends = ["gpu"], + deps = [ + ":topk_specializer", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:platform_util", + "//xla/service:topk_rewriter", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "topk_splitter", + srcs = ["topk_splitter.cc"], + hdrs = ["topk_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "topk_splitter_test", + srcs = ["topk_splitter_test.cc"], + deps = [ + ":topk_splitter", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_dce", + "//xla/service:pattern_matcher", + "//xla/service:topk_rewriter", + "//xla/tests:hlo_test_base", + "//xla/tests:verified_hlo_module", + "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + ], +) + +cc_library( + name = "transpose_dimension_grouper", + srcs = ["transpose_dimension_grouper.cc"], + hdrs = ["transpose_dimension_grouper.h"], + deps = [ + "//xla:permutation_util", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "transpose_dimension_grouper_test", + srcs = [ + "transpose_dimension_grouper_test.cc", + ], + deps = [ + ":transpose_dimension_grouper", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +cc_library( + name = "tree_reduction_rewriter", + srcs = ["tree_reduction_rewriter.cc"], + hdrs = ["tree_reduction_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:collective_ops_utils", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service/gpu:reduction_utils", + "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "tree_reduction_rewriter_test", + srcs = [ + "tree_reduction_rewriter_test.cc", + ], + deps = [ + ":tree_reduction_rewriter", + "//xla/stream_executor:device_description", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +# TODO(b/358278858): Currently lacking test coverage. +cc_library( + name = "triangular_solve_rewriter", + srcs = ["triangular_solve_rewriter.cc"], + hdrs = ["triangular_solve_rewriter.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service/gpu:cublas_cudnn", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "triton_fusion_numerics_verifier", + srcs = ["triton_fusion_numerics_verifier.cc"], + hdrs = ["triton_fusion_numerics_verifier.h"], + tags = ["gpu"], + deps = [ + "//xla:shape_util", + "//xla:status_macros", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:executable", + "//xla/service:hlo_module_config", + "//xla/service:hlo_pass", + "//xla/service:shaped_buffer", + "//xla/service/gpu:backend_configs_cc", + "//xla/service/gpu:buffer_comparator", + "//xla/service/gpu:ir_emission_utils", + "//xla/service/gpu/autotuning:autotuner_compile_util", + "//xla/service/gpu/autotuning:autotuner_util", + "//xla/stream_executor:stream", + "//xla/tools:hlo_decomposer_lib", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_test( + name = "triton_fusion_numerics_verifier_test", + srcs = ["triton_fusion_numerics_verifier_test.cc"], + backend_tags = {"gpu": [ + "requires-gpu-sm80", + ]}, + tags = ["no_rocm"], + backends = ["gpu"], + deps = [ + ":triton_fusion_numerics_verifier", + "//xla:shape_util", + "//xla:test_helpers", + "//xla/hlo/ir:hlo", + "//xla/service:platform_util", + "//xla/service/gpu/autotuning:autotuner_compile_util", + "//xla/service/gpu/autotuning:autotuner_util", + "//xla/stream_executor:platform", + "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "variadic_op_splitter", + srcs = ["variadic_op_splitter.cc"], + hdrs = ["variadic_op_splitter.h"], + deps = [ + "//xla:shape_util", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "variadic_op_splitter_test", + srcs = ["variadic_op_splitter_test.cc"], + tags = [ + "nomsan", + ], + deps = [ + ":variadic_op_splitter", + "//xla:literal_util", + "//xla:shape_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:hlo_parser", + "//xla/service:pattern_matcher", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "windowed_einsum_handler", + srcs = ["windowed_einsum_handler.cc"], + hdrs = ["windowed_einsum_handler.h"], + deps = [ + "//xla:literal", + "//xla:literal_util", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/utils:hlo_query", + "//xla/service:hlo_creation_utils", + "//xla/service:hlo_pass", + "//xla/service:pattern_matcher", + "//xla/service:shape_inference", + "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", + ], +) + +xla_cc_test( + name = "windowed_einsum_handler_test", + srcs = ["windowed_einsum_handler_test.cc"], + deps = [ + ":windowed_einsum_handler", + "//xla/hlo/ir:hlo", + "//xla/service:pattern_matcher", + "//xla/service:pattern_matcher_gmock", + "//xla/service/gpu:backend_configs_cc", + "//xla/tests:filecheck", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + ], +) + +cc_library( + name = "pgle_accuracy_checker", + srcs = ["pgle_accuracy_checker.cc"], + hdrs = ["pgle_accuracy_checker.h"], + deps = [ + "//xla/hlo/ir:hlo", + "//xla/service:hlo_pass", + "//xla/service:profile_guided_latency_estimator", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:errors", + ], +) + +xla_cc_test( + name = "pgle_accuracy_checker_test", + srcs = ["pgle_accuracy_checker_test.cc"], + deps = [ + ":pgle_accuracy_checker", + "//xla/hlo/ir:hlo", + "//xla/service:latency_hiding_scheduler", + "//xla/service:profile_guided_latency_estimator", + "//xla/service/gpu:gpu_latency_hiding_scheduler", + "//xla/tests:hlo_test_base", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + ], +) diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc similarity index 94% rename from third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc rename to third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc index 21e8d6ca7c0bce..d59ae2b6a1d039 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.cc @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_algebraic_simplifier.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" #include "absl/log/check.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/xla_data.pb.h" diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h similarity index 93% rename from third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h rename to third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h index 855359654395a0..f29b31e8bb737b 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier.h +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_ -#define XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_ #include @@ -75,4 +75,4 @@ class GpuAlgebraicSimplifier : public AlgebraicSimplifier { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_GPU_ALGEBRAIC_SIMPLIFIER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALGEBRAIC_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc rename to third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc index 135ddb12ddf0db..c1e52e90a417c0 100644 --- a/third_party/xla/xla/service/gpu/gpu_algebraic_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/algebraic_simplifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_algebraic_simplifier.h" +#include "xla/service/gpu/transforms/algebraic_simplifier.h" #include diff --git a/third_party/xla/xla/service/gpu/algorithm_checker.cc b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc similarity index 98% rename from third_party/xla/xla/service/gpu/algorithm_checker.cc rename to third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc index 3104293f8d255d..664d7b2722f923 100644 --- a/third_party/xla/xla/service/gpu/algorithm_checker.cc +++ b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/algorithm_checker.h" +#include "xla/service/gpu/transforms/algorithm_checker.h" #include diff --git a/third_party/xla/xla/service/gpu/algorithm_checker.h b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.h similarity index 90% rename from third_party/xla/xla/service/gpu/algorithm_checker.h rename to third_party/xla/xla/service/gpu/transforms/algorithm_checker.h index f3b30c1c61f5f8..c2cf0d2c9f0f2a 100644 --- a/third_party/xla/xla/service/gpu/algorithm_checker.h +++ b/third_party/xla/xla/service/gpu/transforms/algorithm_checker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_ -#define XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_ #include @@ -51,4 +51,4 @@ class AlgorithmChecker : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_ALGORITHM_CHECKER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALGORITHM_CHECKER_H_ diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params.cc b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc similarity index 97% rename from third_party/xla/xla/service/gpu/alias_passthrough_params.cc rename to third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc index 5dea5bc548374c..0d6bff333a0400 100644 --- a/third_party/xla/xla/service/gpu/alias_passthrough_params.cc +++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/alias_passthrough_params.h" +#include "xla/service/gpu/transforms/alias_passthrough_params.h" #include diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params.h b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h similarity index 89% rename from third_party/xla/xla/service/gpu/alias_passthrough_params.h rename to third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h index 029068a6b5b5c2..4cd4ab4300961d 100644 --- a/third_party/xla/xla/service/gpu/alias_passthrough_params.h +++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_ -#define XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -47,4 +47,4 @@ class AliasPassthroughParams : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_ALIAS_PASSTHROUGH_PARAMS_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALIAS_PASSTHROUGH_PARAMS_H_ diff --git a/third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc similarity index 96% rename from third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc rename to third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc index d8141232ebbd3f..32c1e5b9fc6f64 100644 --- a/third_party/xla/xla/service/gpu/alias_passthrough_params_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/alias_passthrough_params_test.cc @@ -13,10 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/alias_passthrough_params.h" +#include "xla/service/gpu/transforms/alias_passthrough_params.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc rename to third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc index fe2d2d1e145140..2f7c130fab449e 100644 --- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_all_gather_optimizer.h" +#include "xla/service/gpu/transforms/all_gather_optimizer.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h similarity index 88% rename from third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h rename to third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h index e28e42246910f9..988c1f6a1bd5ba 100644 --- a/third_party/xla/xla/service/gpu/gpu_all_gather_optimizer.h +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_ -#define XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -43,4 +43,4 @@ class AllGatherOptimizer : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_ALL_GATHER_OPTIMIZER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_GATHER_OPTIMIZER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc rename to third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc index 5db5ffd47def70..27f6d65df781a1 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_all_gather_optimizer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_gather_optimizer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_all_gather_optimizer.h" +#include "xla/service/gpu/transforms/all_gather_optimizer.h" #include #include diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc similarity index 99% rename from third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc rename to third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc index 2e75ffaf55b12a..0e0e67ac063ed1 100644 --- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/all_reduce_blueconnect.h" +#include "xla/service/gpu/transforms/all_reduce_blueconnect.h" #include #include diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.h b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h similarity index 90% rename from third_party/xla/xla/service/gpu/all_reduce_blueconnect.h rename to third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h index 8633c77b0eba4b..6da0bbe4bfe377 100644 --- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect.h +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_ -#define XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_ #include @@ -53,4 +53,4 @@ class AllReduceBlueConnect : public HloModulePass { } // namespace xla -#endif // XLA_SERVICE_GPU_ALL_REDUCE_BLUECONNECT_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_BLUECONNECT_H_ diff --git a/third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc rename to third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc index a6a66c5189af42..cafbf24986ae7c 100644 --- a/third_party/xla/xla/service/gpu/all_reduce_blueconnect_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_blueconnect_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/all_reduce_blueconnect.h" +#include "xla/service/gpu/transforms/all_reduce_blueconnect.h" #include #include diff --git a/third_party/xla/xla/service/all_reduce_splitter.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc similarity index 99% rename from third_party/xla/xla/service/all_reduce_splitter.cc rename to third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc index ce1e0e2bc37fcd..51f71c06c800af 100644 --- a/third_party/xla/xla/service/all_reduce_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_splitter.h" +#include "xla/service/gpu/transforms/all_reduce_splitter.h" #include #include diff --git a/third_party/xla/xla/service/all_reduce_splitter.h b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h similarity index 94% rename from third_party/xla/xla/service/all_reduce_splitter.h rename to third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h index ac8dec7afa7833..91e081163035b1 100644 --- a/third_party/xla/xla/service/all_reduce_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_ALL_REDUCE_SPLITTER_H_ -#define XLA_SERVICE_ALL_REDUCE_SPLITTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -74,4 +74,4 @@ class AllReduceSplitter : public HloModulePass { } // namespace xla -#endif // XLA_SERVICE_ALL_REDUCE_SPLITTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ALL_REDUCE_SPLITTER_H_ diff --git a/third_party/xla/xla/service/all_reduce_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc similarity index 99% rename from third_party/xla/xla/service/all_reduce_splitter_test.cc rename to third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc index 6725a50bc35c6f..ec2e66d1b66100 100644 --- a/third_party/xla/xla/service/all_reduce_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/all_reduce_splitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/all_reduce_splitter.h" +#include "xla/service/gpu/transforms/all_reduce_splitter.h" #include #include @@ -29,12 +29,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/gpu_reduce_scatter_creator.h" +#include "xla/service/gpu/transforms/reduce_scatter_creator.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc similarity index 94% rename from third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc rename to third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc index c2f6c04e5c274a..aa76aff4dfec49 100644 --- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_async_collective_annotator.h" +#include "xla/service/gpu/transforms/async_collective_annotator.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -29,7 +29,7 @@ limitations under the License. namespace xla { namespace gpu { -absl::StatusOr GpuAsyncCollectiveAnnotator::Run( +absl::StatusOr AsyncCollectiveAnnotator::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h similarity index 78% rename from third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h rename to third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h index 4000fbcbdd4991..1b41d5056b29d2 100644 --- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator.h +++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_ -#define XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_ #include @@ -29,12 +29,12 @@ namespace xla { namespace gpu { // Annotate async collectives with CollectiveBackendConfig. -class GpuAsyncCollectiveAnnotator : public HloModulePass { +class AsyncCollectiveAnnotator : public HloModulePass { public: - explicit GpuAsyncCollectiveAnnotator(HloPredicate is_collective_async) + explicit AsyncCollectiveAnnotator(HloPredicate is_collective_async) : is_collective_async_(std::move(is_collective_async)) {} absl::string_view name() const override { - return "gpu-async-collective-annotator"; + return "async-collective-annotator"; } using HloPassInterface::Run; @@ -49,4 +49,4 @@ class GpuAsyncCollectiveAnnotator : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_ASYNC_COLLECTIVE_ANNOTATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_ASYNC_COLLECTIVE_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc similarity index 93% rename from third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc rename to third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc index f874a7e565ea73..6622a7b2d20035 100644 --- a/third_party/xla/xla/service/gpu/gpu_async_collective_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/async_collective_annotator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_async_collective_annotator.h" +#include "xla/service/gpu/transforms/async_collective_annotator.h" #include #include @@ -97,18 +97,18 @@ struct TestCase { absl::flat_hash_set expected_sync; }; -class GpuAsyncCollectiveAnnotatorTest +class AsyncCollectiveAnnotatorTest : public HloTestBase, public ::testing::WithParamInterface {}; -XLA_TEST_P(GpuAsyncCollectiveAnnotatorTest, Test) { +XLA_TEST_P(AsyncCollectiveAnnotatorTest, Test) { const TestCase& test_case = GetParam(); TF_ASSERT_OK_AND_ASSIGN( std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString, /*replica_count=*/2)); TF_ASSERT_OK_AND_ASSIGN( - bool changed, GpuAsyncCollectiveAnnotator(test_case.is_async_predicate) - .Run(module.get())); + bool changed, + AsyncCollectiveAnnotator(test_case.is_async_predicate).Run(module.get())); EXPECT_TRUE(changed); // Assert that all async collectives are annotated with the backend config. @@ -175,8 +175,8 @@ std::string TestCaseName(const ::testing::TestParamInfo& test_case) { return test_case.param.test_name; } -INSTANTIATE_TEST_SUITE_P(GpuAsyncCollectiveAnnotatorTest, - GpuAsyncCollectiveAnnotatorTest, +INSTANTIATE_TEST_SUITE_P(AsyncCollectiveAnnotatorTest, + AsyncCollectiveAnnotatorTest, ::testing::ValuesIn(TestCases()), TestCaseName); } // namespace } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc new file mode 100644 index 00000000000000..e80f225e027508 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.cc @@ -0,0 +1,85 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/async_wrapper.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +namespace xla::gpu { + +absl::StatusOr AsyncWrapper::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + + XLA_VLOG_LINES( + 1, absl::StrCat("AsyncWrapper will process the following module:\n", + module->ToString())); + + std::deque computations; + computations.push_back(module->entry_computation()); + while (!computations.empty()) { + HloComputation* computation = computations.front(); + computations.pop_front(); + + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (predicate_(instruction)) { + XLA_VLOG_LINES( + 1, absl::StrCat( + "AsyncWrapper will make the following instruction async:\n", + instruction->ToString())); + // If the predicate matches, then wrap the instructions in async blocks. + TF_RETURN_IF_ERROR( + computation + ->CreateAsyncInstructions(instruction, + {ShapeUtil::MakeScalarShape(U32)}) + .status()); + changed = true; + continue; + } + + // Otherwise, follow any `calls` to discover other instructions that can + // potentially be made async. + if (instruction->opcode() == HloOpcode::kCall) { + std::copy(instruction->called_computations().begin(), + instruction->called_computations().end(), + std::back_inserter(computations)); + } + } + } + + XLA_VLOG_LINES( + 1, + absl::StrCat("AsyncWrapper finished processing the following module:\n", + module->ToString())); + return changed; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper.h b/third_party/xla/xla/service/gpu/transforms/async_wrapper.h new file mode 100644 index 00000000000000..d6cefe812b24de --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper.h @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla::gpu { + +// AsyncWrappers wrap instructions that match a given `predicate` into async +// blocks (i.e. `async-start` and `async-stop` instructions) so that they run +// concurrently. +class AsyncWrapper : public HloModulePass { + public: + using Predicate = std::function; + explicit AsyncWrapper(Predicate predicate) + : predicate_(std::move(predicate)) {} + + absl::string_view name() const override { return "async-wrapper"; } + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + const Predicate predicate_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRANSFORMS_ASYNC_WRAPPER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc new file mode 100644 index 00000000000000..9d698991afcbd6 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/async_wrapper_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/async_wrapper.h" + +#include +#include + +#include +#include +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tests/literal_test_util.h" +#include "xla/tests/verified_hlo_module.h" +#include "tsl/platform/status_matchers.h" + +namespace xla::gpu { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +class AsyncWrapperTest : public HloTestBase {}; + +int CountAsyncInstructions(HloComputation* computation) { + int count = 0; + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->IsAsynchronous()) ++count; + } + return count; +} + +TEST_F(AsyncWrapperTest, BasicFusion) { + const char* hlo_text = R"( + HloModule m + + double1 { + p0 = f32[1] parameter(0) + ROOT add = f32[1] add(p0, p0) + } + + double2 { + p0 = f32[1] parameter(0) + ROOT add = f32[1] add(p0, p0) + } + + ENTRY main { + p0 = f32[1] parameter(0) + agg1 = f32[1] fusion(p0), kind=kLoop, calls=double1 + agg2 = f32[1] fusion(p0), kind=kLoop, calls=double2 + ROOT done = f32[1] add(agg1, agg2) + })"; + + std::unique_ptr module = + ParseAndReturnVerifiedModule(hlo_text).value(); + + AsyncWrapper wrapper([](const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kFusion; + }); + EXPECT_THAT(wrapper.HloModulePass::Run(module.get()), IsOkAndHolds(true)); + EXPECT_EQ(CountAsyncInstructions(module->entry_computation()), 4); + + Literal argument = LiteralUtil::CreateR1({1.0}); + Literal expected = LiteralUtil::CreateR1({4.0}); + + Literal result = ExecuteNoHloPasses(std::move(module), {&argument}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, result)); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc similarity index 99% rename from third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc rename to third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc index 9102d75ddbd1a8..07f52fd8950e2c 100644 --- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/collective_permute_cycle_decomposer.h" +#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h" #include #include diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h similarity index 92% rename from third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h rename to third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h index 508a8597ee42fe..cfacd6629e5519 100644 --- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer.h +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ -#define XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ #include @@ -70,4 +70,4 @@ class CollectivePermuteCycleDecomposer : public HloModulePass { } // namespace xla -#endif // XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_CYCLE_DECOMPOSER_H_ diff --git a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc similarity index 90% rename from third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc rename to third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc index 7f297ad1e615f1..ae537d9ac9b019 100644 --- a/third_party/xla/xla/service/gpu/collective_permute_cycle_decomposer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_cycle_decomposer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/collective_permute_cycle_decomposer.h" +#include "xla/service/gpu/transforms/collective_permute_cycle_decomposer.h" #include @@ -124,12 +124,11 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycle) { EXPECT_EQ(cp1->operand(0), cp2->operand(0)); EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value()); EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}")); - EXPECT_THAT(cp1->ToString(), - HasSubstr("_xla_send_recv_validation=\"{{3,10}}\"")); + EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{3,10}}")); EXPECT_THAT(cp2->ToString(), HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}")); EXPECT_THAT(cp2->ToString(), - HasSubstr("_xla_send_recv_validation=\"{{0,7},{1,8},{2,9}}\"")); + HasSubstr("_xla_send_recv_validation={{0,7},{1,8},{2,9}}")); check_metadata(cp1); check_metadata(cp2); } @@ -150,11 +149,14 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) { iter = u32[] get-tuple-element(param), index=0 data = f32[2,2] get-tuple-element(param), index=1 weights = f32[2,2] get-tuple-element(param), index=2 - matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, rhs_contracting_dims={0} - cp = f32[2,2] collective-permute(matmul), channel_id=1, source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + cp = f32[2,2] collective-permute(data), + channel_id=1, + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}}, + frontend_attributes={_xla_send_recv_validation="{{0,7},{1,8},{2,9},{3,10}}"} + matmul = f32[2,2] dot(weights, cp), lhs_contracting_dims={1}, rhs_contracting_dims={0} iter_increment = u32[] constant(1) next_iter = u32[] add(iter, iter_increment) - ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights) } ENTRY test_computation { @@ -178,8 +180,11 @@ TEST_F(CollectivePermuteCycleDecomposerTest, ForwardCycleWithMatmul) { DynCast( FindInstruction(module.get(), "collective-permute.1")); EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{3,0}}")); + EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{3,10}}")); EXPECT_THAT(cp2->ToString(), HasSubstr("source_target_pairs={{0,1},{1,2},{2,3}}")); + EXPECT_THAT(cp2->ToString(), + HasSubstr("_xla_send_recv_validation={{0,7},{1,8},{2,9}}")); } TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) { @@ -216,12 +221,11 @@ TEST_F(CollectivePermuteCycleDecomposerTest, BackwardCycle) { EXPECT_EQ(cp1->operand(0), cp2->operand(0)); EXPECT_GT(cp2->channel_id().value(), cp1->channel_id().value()); EXPECT_THAT(cp1->ToString(), HasSubstr("source_target_pairs={{0,3}}")); - EXPECT_THAT(cp1->ToString(), - HasSubstr("_xla_send_recv_validation=\"{{0,7}}\"")); + EXPECT_THAT(cp1->ToString(), HasSubstr("_xla_send_recv_validation={{0,7}}")); EXPECT_THAT(cp2->ToString(), HasSubstr("source_target_pairs={{1,0},{2,1},{3,2}}")); EXPECT_THAT(cp2->ToString(), - HasSubstr("_xla_send_recv_validation=\"{{1,8},{2,9},{3,10}}\"")); + HasSubstr("_xla_send_recv_validation={{1,8},{2,9},{3,10}}")); check_metadata(cp1); check_metadata(cp2); } diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc similarity index 98% rename from third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc rename to third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc index b1e8812c87f351..e9df22abeddcfe 100644 --- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.cc @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h" +#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h" #include "xla/literal_util.h" #include "xla/service/collective_ops_utils.h" diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h similarity index 90% rename from third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h rename to third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h index e6b04c953a7024..f8999a9a9ba9fc 100644 --- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator.h +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h @@ -10,8 +10,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_ -#define XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_ #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -55,4 +55,4 @@ class CollectivePermuteValidIterationAnnotator : public HloModulePass { } // namespace xla -#endif // XLA_SERVICE_GPU_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_COLLECTIVE_PERMUTE_VALID_ITERATION_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc rename to third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc index 3d1d0b4e3858cb..3585acc2a9fb2f 100644 --- a/third_party/xla/xla/service/gpu/collective_permute_valid_iteration_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/collective_permute_valid_iteration_annotator_test.cc @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/collective_permute_valid_iteration_annotator.h" +#include "xla/service/gpu/transforms/collective_permute_valid_iteration_annotator.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instructions.h" diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc similarity index 98% rename from third_party/xla/xla/service/gpu/command_buffer_scheduling.cc rename to third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc index d113a6b5d01ec8..ffa2866ac739db 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/command_buffer_scheduling.h" +#include "xla/service/gpu/transforms/command_buffer_scheduling.h" #include #include @@ -178,7 +178,7 @@ static bool IsCommand(const HloInstruction* hlo, const auto& custom_config = backend_config.custom_fusion_config(); if (custom_config.name() == "address_computation") { auto fusion_analysis = - HloFusionAnalysis::Create(fusion, &config.device_description); + HloFusionAnalysis::Create(*hlo, config.device_description); const HloFusionAdaptor& adaptor = fusion_analysis.fusion(); auto custom_call_adaptor = HloBfsFindIf( adaptor.GetRoots(), adaptor, @@ -232,6 +232,9 @@ static bool IsAsyncStartCommand(const HloInstruction* hlo, } if (hlo->opcode() == HloOpcode::kAsyncStart) { + if (IsCublasGemm(*hlo->async_wrapped_instruction())) { + return config.enabled_commands.contains(DebugOptions::CUBLAS); + } if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } @@ -248,6 +251,9 @@ static bool IsAsyncDoneCommand(const HloInstruction* hlo, } if (hlo->opcode() == HloOpcode::kAsyncDone) { + if (IsCublasGemm(*hlo->async_wrapped_instruction())) { + return config.enabled_commands.contains(DebugOptions::CUBLAS); + } if (hlo->async_wrapped_opcode() == HloOpcode::kReduceScatter) { return config.enabled_commands.contains(DebugOptions::COLLECTIVES); } diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h similarity index 96% rename from third_party/xla/xla/service/gpu/command_buffer_scheduling.h rename to third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h index 78590a80359e9e..30e0249ca04175 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling.h +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_ -#define XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_ #include #include @@ -140,4 +140,4 @@ class CommandBufferScheduling : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_COMMAND_BUFFER_SCHEDULING_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_COMMAND_BUFFER_SCHEDULING_H_ diff --git a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc similarity index 95% rename from third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc rename to third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc index bda31a05980b19..843428f6467909 100644 --- a/third_party/xla/xla/service/gpu/command_buffer_scheduling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/command_buffer_scheduling_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/command_buffer_scheduling.h" +#include "xla/service/gpu/transforms/command_buffer_scheduling.h" #include #include @@ -29,7 +29,7 @@ limitations under the License. #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -1013,6 +1013,38 @@ ENTRY e { }); } -} // namespace +TEST_F(CommandBufferSchedulingTest, AsyncCustomCall) { + const char* hlo = R"( + HloModule m, is_scheduled=true + + ENTRY %main (a: s32[], b: s32[]) -> f32[2,2] { + %p = f32[2,2]{1,0} parameter(0) + %start1 = ((f32[2,2], f32[2,2]), (f32[2,2], s8[4]), u32[]) custom-call-start(f32[2,2] %p, f32[2,2] %p), custom_call_target="__cublas$gemm" + %start2 = ((f32[2,2], f32[2,2]), (f32[2,2], s8[4]), u32[]) custom-call-start(f32[2,2] %p, f32[2,2] %p), custom_call_target="__cublas$gemm" + %done1 = (f32[2,2], s8[4]) custom-call-done(((f32[2,2], f32[2,2]), (f32[2,2], s8[4]), u32[]) %start1) + %done2 = (f32[2,2], s8[4]) custom-call-done(((f32[2,2], f32[2,2]), (f32[2,2], s8[4]), u32[]) %start2) + %result1 = f32[2,2] get-tuple-element((f32[2,2], s8[4]) %done1), index=0 + %result2 = f32[2,2] get-tuple-element((f32[2,2], s8[4]) %done2), index=0 + ROOT %sum = f32[2,2] add(f32[2,2] %result1, f32[2,2] %result2) + })"; + + const char* expected = R"( +// CHECK: %command_buffer ([[P:.+]]: f32[2,2]) -> ((f32[2,2], s8[4]), (f32[2,2], s8[4])) { +// CHECK: %[[P]] = f32[2,2]{1,0} parameter(0) +// CHECK: %[[S1:.+]] = ((f32[2,2]{1,0}, f32[2,2]{1,0}), (f32[2,2]{1,0}, s8[4]{0}), u32[]) custom-call-start(%[[P]], %[[P]]), custom_call_target="__cublas$gemm" +// CHECK: %[[S2:.+]] = ((f32[2,2]{1,0}, f32[2,2]{1,0}), (f32[2,2]{1,0}, s8[4]{0}), u32[]) custom-call-start(%[[P]], %[[P]]), custom_call_target="__cublas$gemm" +// CHECK: %[[D1:.+]] = (f32[2,2]{1,0}, s8[4]{0}) custom-call-done(%[[S1]]) +// CHECK: %[[D2:.+]] = (f32[2,2]{1,0}, s8[4]{0}) custom-call-done(%[[S2]]) +// CHECK: ROOT %[[T:.+]] = ((f32[2,2]{1,0}, s8[4]{0}), (f32[2,2]{1,0}, s8[4]{0})) tuple(%[[D1]], %[[D2]]) +// CHECK: })"; + RunAndFilecheckHloRewrite( + hlo, CommandBufferScheduling(device_desc(), kCudaVersion, kCudaVersion), + expected, [](HloModule* module) { + EXPECT_TRUE(module->has_schedule()); + TF_CHECK_OK(module->schedule().Verify()); + }); +} + +} // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc rename to third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc index 0b55f7d264ff00..f072a91307644b 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_conv_padding_legalization.h" +#include "xla/service/gpu/transforms/conv_padding_legalization.h" #include #include @@ -166,7 +166,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, } } // namespace -bool GpuConvPaddingLegalization::CanonicalizeForwardConvolution( +bool ConvPaddingLegalization::CanonicalizeForwardConvolution( HloInstruction* conv) { if (IsForwardConvolutionCanonical(*conv)) { return false; @@ -219,7 +219,7 @@ void IncreasePaddingHighBy(int64_t delta, WindowDimension* window_dim) { } } // namespace -bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( +bool ConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { CHECK_EQ(backward_conv->custom_call_target(), kCudnnConvBackwardFilterCallTarget); @@ -292,7 +292,7 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardFilterConvolution( return true; } -bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( +bool ConvPaddingLegalization::CanonicalizeBackwardInputConvolution( HloInstruction* backward_conv) { if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; @@ -418,7 +418,7 @@ bool GpuConvPaddingLegalization::CanonicalizeBackwardInputConvolution( return true; } -absl::StatusOr GpuConvPaddingLegalization::RunOnComputation( +absl::StatusOr ConvPaddingLegalization::RunOnComputation( HloComputation* computation) { bool changed = false; std::vector convs; @@ -445,7 +445,7 @@ absl::StatusOr GpuConvPaddingLegalization::RunOnComputation( return changed; } -absl::StatusOr GpuConvPaddingLegalization::Run( +absl::StatusOr ConvPaddingLegalization::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h similarity index 86% rename from third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h rename to third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h index 32e0238bed1b3d..1841c926d9545b 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization.h +++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_ -#define XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -30,10 +30,10 @@ namespace gpu { // An HLO pass that canonicalizes convolution instructions for GPU codegen. It // inserts Pad instructions before Convolution instructions with uncanonicalized // padding, so that they can be lowered to Cudnn/Miopen convolution. -class GpuConvPaddingLegalization : public HloModulePass { +class ConvPaddingLegalization : public HloModulePass { public: absl::string_view name() const override { - return "gpu-conv-padding-legalization"; + return "conv-padding-legalization"; } using HloPassInterface::Run; @@ -52,4 +52,4 @@ class GpuConvPaddingLegalization : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_CONV_PADDING_LEGALIZATION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CONV_PADDING_LEGALIZATION_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc similarity index 93% rename from third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc rename to third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc index edaf9d053d77c9..06682e7d1affd6 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_padding_legalization_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_padding_legalization_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_conv_padding_legalization.h" +#include "xla/service/gpu/transforms/conv_padding_legalization.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/gpu/cublas_cudnn.h" @@ -32,9 +32,9 @@ namespace { namespace m = ::xla::match; -using GpuConvPaddingLegalizationTest = HloTestBase; +using ConvPaddingLegalizationTest = HloTestBase; -TEST_F(GpuConvPaddingLegalizationTest, BackwardInputConvolve) { +TEST_F(ConvPaddingLegalizationTest, BackwardInputConvolve) { auto module = ParseAndReturnVerifiedModule(R"( HloModule convolution_module ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8[0]) { @@ -75,7 +75,7 @@ ENTRY %convolution (operand f64[2,2,2,3]{3,2,1,0}) -> (f64[2,2,4,4]{3,2,1,0}, u8 } )") .value(); - ASSERT_TRUE(GpuConvPaddingLegalization().Run(module.get()).value()); + ASSERT_TRUE(ConvPaddingLegalization().Run(module.get()).value()); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Tuple( m::Slice(m::GetTupleElement( diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc index cb5b1867241e58..e19622dc27911f 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include #include @@ -845,10 +845,10 @@ absl::StatusOr RunOnComputation(HloComputation* computation, } } // namespace -absl::StatusOr GpuConvRewriter::Run( +absl::StatusOr ConvRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), before:\n" + module->ToString()); + XLA_VLOG_LINES(2, "ConvRewriter::Run(), before:\n" + module->ToString()); bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { @@ -856,11 +856,11 @@ absl::StatusOr GpuConvRewriter::Run( RunOnComputation(computation, compute_capability_)); changed |= result; } - XLA_VLOG_LINES(2, "GpuConvRewriter::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(2, "ConvRewriter::Run(), after:\n" + module->ToString()); return changed; } -/*static*/ bool GpuConvRewriter::ConvIsLowerable(HloInstruction* conv) { +/*static*/ bool ConvRewriter::ConvIsLowerable(HloInstruction* conv) { return CanImplementAsGpuForwardConv(conv) || MatchBackwardFilter(conv) || MatchBackwardInput(conv); } diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h similarity index 82% rename from third_party/xla/xla/service/gpu/gpu_conv_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/conv_rewriter.h index 74b860f239872c..69369f1f5cb54a 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_ -#define XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -34,12 +34,12 @@ namespace gpu { // patterns of ops will be matched and fused into the custom call in // CudnnFusedConvRewriter. -class GpuConvRewriter : public HloModulePass { +class ConvRewriter : public HloModulePass { public: - explicit GpuConvRewriter(const se::GpuComputeCapability& compute_capability) + explicit ConvRewriter(const se::GpuComputeCapability& compute_capability) : compute_capability_(compute_capability) {}; - absl::string_view name() const override { return "gpu-conv-rewriter"; } + absl::string_view name() const override { return "conv-rewriter"; } static bool ConvIsLowerable(HloInstruction* conv); @@ -55,4 +55,4 @@ class GpuConvRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_CONV_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CONV_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc similarity index 95% rename from third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc index f83bae8fc54586..d01ffd1829b7f8 100644 --- a/third_party/xla/xla/service/gpu/gpu_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/conv_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include #include @@ -45,9 +45,9 @@ namespace { namespace m = ::xla::match; -class GpuConvRewriterTest : public HloTestBase { +class ConvRewriterTest : public HloTestBase { public: - GpuConvRewriterTest() + ConvRewriterTest() : HloTestBase(/*verifier_layout_sensitive=*/true, /*allow_mixed_precision_in_hlo_verifier=*/false) { for (int i = 0; i < 2; ++i) { @@ -103,7 +103,7 @@ class GpuConvRewriterTest : public HloTestBase { } bool RunPass(HloModule* module) { - return GpuConvRewriter(GetComputeCapability()).Run(module).value(); + return ConvRewriter(GetComputeCapability()).Run(module).value(); } // A convolution window with stride 1 and zero padding. The size fields are @@ -113,7 +113,7 @@ class GpuConvRewriterTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) { +TEST_F(ConvRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -154,8 +154,7 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolve) { << md_after_opt.DebugString() << " vs " << metadata.DebugString(); } -TEST_F(GpuConvRewriterTest, - BackwardFilterConvolveEquivalentToForwardConvolution) { +TEST_F(ConvRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -186,7 +185,7 @@ TEST_F(GpuConvRewriterTest, } // Extracted from block35 training. -TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { +TEST_F(ConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -216,7 +215,7 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedActivations) { } // Extracted from inception v3 training. -TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { +TEST_F(ConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -245,7 +244,7 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithPaddedGradients) { m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0))); } -TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { +TEST_F(ConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -274,7 +273,7 @@ TEST_F(GpuConvRewriterTest, BackwardFilterConvolveWithUnevenPadding) { m::CustomCall({kCudnnConvBackwardFilterCallTarget}), 0))); } -TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) { +TEST_F(ConvRewriterTest, BackwardInputConvolveEvenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -343,7 +342,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveEvenPadding) { // Convolve([abc], [x], base_dilation=2) // = Convolve([abc], Reverse([x]), base_dilation=2) // = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) { +TEST_F(ConvRewriterTest, BackwardInputConvolve1x1Filter) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. HloInstruction* output = @@ -381,7 +380,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolve1x1Filter) { // BackwardInputConvolve([abc], [x], stride=1) is equivalent to // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input // convolution. -TEST_F(GpuConvRewriterTest, +TEST_F(ConvRewriterTest, BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. @@ -427,7 +426,7 @@ TEST_F(GpuConvRewriterTest, // 20x10x10x192 // // Gradients are padded unevenly. -TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { +TEST_F(ConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -479,7 +478,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnGradients) { // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. -TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { +TEST_F(ConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -533,7 +532,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { // // We should fuse BC even though padding on activations is uneven, because // GpuConvPaddingLegalization will canonicalize the fusion HLO. -TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { +TEST_F(ConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. HloInstruction* output = @@ -590,7 +589,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { // We currently don't fuse BC because GpuConvPaddingLegalization // doesn't support negative padding on the gradients of backward convolution // (b/32744257). -TEST_F(GpuConvRewriterTest, +TEST_F(ConvRewriterTest, BackwardInputConvolveNegativePaddingHighOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. @@ -632,7 +631,7 @@ TEST_F(GpuConvRewriterTest, // Check that we will materialize a reversed version of a constant in order to // pattern-match a backwards input convolution. -TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) { +TEST_F(ConvRewriterTest, BackwardInputConvolveConstantFilter) { Array4D constant_arr(4, 4, 2, 2); constant_arr.FillIota(0); std::string constant_str = @@ -659,7 +658,7 @@ TEST_F(GpuConvRewriterTest, BackwardInputConvolveConstantFilter) { 0))); } -TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternMatch) { +TEST_F(ConvRewriterTest, TestBackwardFilterPatternMatch) { // All filter dimensions are larger than the corresponding output dimensions. // This must be a backward filter convolution. const std::string module_str = absl::StrFormat(R"( @@ -681,7 +680,7 @@ TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternMatch) { 0))); } -TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternNoMatch) { +TEST_F(ConvRewriterTest, TestBackwardFilterPatternNoMatch) { // At least one filter dimension is smaller than the corresponding output // dimension. This must be a forward convolution. const std::string module_str = absl::StrFormat(R"( @@ -703,7 +702,7 @@ TEST_F(GpuConvRewriterTest, TestBackwardFilterPatternNoMatch) { 0))); } -TEST_F(GpuConvRewriterTest, TestConv1dBackwardFilterPatternMatch) { +TEST_F(ConvRewriterTest, TestConv1dBackwardFilterPatternMatch) { // There exist one kernel dimension equal to output dimension, regard // it as backward filter if conv is 1d. const std::string module_str = absl::StrFormat(R"( @@ -726,7 +725,7 @@ TEST_F(GpuConvRewriterTest, TestConv1dBackwardFilterPatternMatch) { 0))); } -TEST_F(GpuConvRewriterTest, TestConv1dBackwardInputPatternMatch) { +TEST_F(ConvRewriterTest, TestConv1dBackwardInputPatternMatch) { // For conv1d backward input, filter may reverse first and then reshape. const std::string module_str = absl::StrFormat(R"( HloModule Test @@ -749,7 +748,7 @@ TEST_F(GpuConvRewriterTest, TestConv1dBackwardInputPatternMatch) { 0))); } -TEST_F(GpuConvRewriterTest, TestInvalidTypes) { +TEST_F(ConvRewriterTest, TestInvalidTypes) { const std::string module_str = absl::StrFormat(R"( HloModule Test @@ -766,8 +765,7 @@ TEST_F(GpuConvRewriterTest, TestInvalidTypes) { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_with_type)); - absl::Status s = - GpuConvRewriter(GetComputeCapability()).Run(m.get()).status(); + absl::Status s = ConvRewriter(GetComputeCapability()).Run(m.get()).status(); EXPECT_THAT( s, tsl::testing::StatusIs( absl::StatusCode::kUnimplemented, @@ -780,17 +778,14 @@ TEST_F(GpuConvRewriterTest, TestInvalidTypes) { absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fn"}}); TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_with_type)); - absl::Status s = GpuConvRewriter(se::CudaComputeCapability::Ampere()) - .Run(m.get()) - .status(); + absl::Status s = + ConvRewriter(se::CudaComputeCapability::Ampere()).Run(m.get()).status(); EXPECT_THAT(s, tsl::testing::StatusIs( absl::StatusCode::kUnimplemented, ::testing::HasSubstr( "FP8 convolutions are only supported on CUDA " "GPUs with compute capability at least 9.0"))); - s = GpuConvRewriter(se::RocmComputeCapability{"gfx942"}) - .Run(m.get()) - .status(); + s = ConvRewriter(se::RocmComputeCapability{"gfx942"}).Run(m.get()).status(); EXPECT_THAT(s, tsl::testing::StatusIs( absl::StatusCode::kUnimplemented, ::testing::HasSubstr( @@ -799,7 +794,7 @@ TEST_F(GpuConvRewriterTest, TestInvalidTypes) { // Test unsupported FP8 type module_with_type = absl::StrReplaceAll(module_str, {{"TYPE", "f8e4m3fnuz"}}); TF_ASSERT_OK_AND_ASSIGN(m, ParseAndReturnVerifiedModule(module_with_type)); - s = GpuConvRewriter(GetComputeCapability()).Run(m.get()).status(); + s = ConvRewriter(GetComputeCapability()).Run(m.get()).status(); EXPECT_THAT(s, tsl::testing::StatusIs( absl::StatusCode::kUnimplemented, diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc similarity index 97% rename from third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc rename to third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc index b8c87e2e1978a2..a7dc96ebefaf7e 100644 --- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.cc +++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h" +#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h similarity index 86% rename from third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h rename to third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h index ea56f7a91914ce..6507080a5fa49b 100644 --- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync.h +++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ -#define XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ #include @@ -44,4 +44,4 @@ class GpuConvertAsyncCollectivesToSync : public ConvertAsyncCollectivesToSync { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CONVERT_ASYNC_COLLECTIVES_TO_SYNC_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc rename to third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc index 4daeb62905f8a2..d38ab70864ac4c 100644 --- a/third_party/xla/xla/service/gpu/gpu_convert_async_collectives_to_sync_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/convert_async_collectives_to_sync_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_convert_async_collectives_to_sync.h" +#include "xla/service/gpu/transforms/convert_async_collectives_to_sync.h" #include @@ -25,8 +25,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/copy_fusion.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc similarity index 99% rename from third_party/xla/xla/service/gpu/copy_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/copy_fusion.cc index a83354ca5d0508..eb43ca2364f0c8 100644 --- a/third_party/xla/xla/service/gpu/copy_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/copy_fusion.h" +#include "xla/service/gpu/transforms/copy_fusion.h" #include #include diff --git a/third_party/xla/xla/service/gpu/copy_fusion.h b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h similarity index 90% rename from third_party/xla/xla/service/gpu/copy_fusion.h rename to third_party/xla/xla/service/gpu/transforms/copy_fusion.h index 973b671f5978ea..a6a1ae48319a04 100644 --- a/third_party/xla/xla/service/gpu/copy_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_COPY_FUSION_H_ -#define XLA_SERVICE_GPU_COPY_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -46,4 +46,4 @@ class CopyFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_COPY_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_COPY_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/copy_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/copy_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc index d2116eb68b0c2c..1bd2d11fe7ddc7 100644 --- a/third_party/xla/xla/service/gpu/copy_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/copy_fusion.h" +#include "xla/service/gpu/transforms/copy_fusion.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc similarity index 97% rename from third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc rename to third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc index f0da0e5855a1f4..82f88398c2bb0d 100644 --- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cublas_pad_for_gemms.h" +#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h" #include #include @@ -25,9 +25,9 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/literal_util.h" -#include "xla/service/gpu/fusions/triton/triton_support.h" -#include "xla/service/gpu/gemm_fusion.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/transforms/gemm_fusion.h" #include "xla/shape.h" #include "xla/stream_executor/device_description.h" #include "xla/util.h" diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h similarity index 91% rename from third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h rename to third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h index 2a1f9c6f161cd8..8c7d8e53f8b91d 100644 --- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms.h +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_ -#define XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_ #include @@ -60,4 +60,4 @@ class CublasPadForGemms : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUBLAS_PAD_FOR_GEMMS_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUBLAS_PAD_FOR_GEMMS_H_ diff --git a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc rename to third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc index d20dd94a06e7a1..77a32c935e412e 100644 --- a/third_party/xla/xla/service/gpu/cublas_pad_for_gemms_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cublas_pad_for_gemms_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cublas_pad_for_gemms.h" +#include "xla/service/gpu/transforms/cublas_pad_for_gemms.h" #include #include diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc new file mode 100644 index 00000000000000..bbc5ed42e6d9dd --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.cc @@ -0,0 +1,340 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/cudnn_custom_call_compiler.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_clone_context.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/stream_executor_util.h" +#include "xla/stream_executor/cuda/cuda_dnn.h" +#include "xla/stream_executor/cuda/cudnn_frontend_helpers.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +inline absl::StatusOr AsCudnnFmhaMaskKind( + CudnnfMHABackendConfig_MaskType mask_type) { + switch (mask_type) { + case CudnnfMHABackendConfig::NO_MASK: + return CudnnfMHAMaskKind::kNoMask; + case CudnnfMHABackendConfig::PADDING: + return CudnnfMHAMaskKind::kPadding; + case CudnnfMHABackendConfig::CAUSAL: + return CudnnfMHAMaskKind::kCausal; + case CudnnfMHABackendConfig::PADDING_CAUSAL: + return CudnnfMHAMaskKind::kPaddingCausal; + case CudnnfMHABackendConfig::ALIBI: + return CudnnfMHAMaskKind::kAlibi; + default: + return xla::Internal("Unknown fmha mask kind."); + } +} + +using se::dnn::DataType; +using se::dnn::MatmulTensorDescriptor; +using se::dnn::TensorDescriptor; + +absl::StatusOr TensorDescriptorFor(const Shape &shape) { + TF_ASSIGN_OR_RETURN(const DataType type, + GetDNNDataTypeFromPrimitiveType(shape.element_type())); + return TensorDescriptor::For(type, shape.dimensions(), + shape.layout().minor_to_major()); +} + +enum Side { LHS, RHS }; + +absl::StatusOr MatmulTensorDescriptorFor( + const Shape &shape, const DotDimensionNumbers &dnums, const Side side) { + TF_ASSIGN_OR_RETURN(const DataType type, + GetDNNDataTypeFromPrimitiveType(shape.element_type())); + return MatmulTensorDescriptor::For( + type, shape.dimensions(), shape.layout().minor_to_major(), + (side == LHS) ? dnums.lhs_batch_dimensions() + : dnums.rhs_batch_dimensions(), + (side == LHS) ? dnums.lhs_contracting_dimensions() + : dnums.rhs_contracting_dimensions()); +} + +absl::StatusOr HloCustomCallToCuDnnGraph( + se::dnn::DnnSupport &dnn_support, HloCustomCallInstruction *custom_call) { + if (IsFwdCustomCallTofMHA(*custom_call)) { + TF_ASSIGN_OR_RETURN(const xla::gpu::CudnnfMHAKind kind, + xla::gpu::GetCudnnfMHAKind(custom_call)); + TF_ASSIGN_OR_RETURN( + const auto gpu_config, + custom_call->backend_config()); + const xla::gpu::CudnnfMHABackendConfig &config = + gpu_config.cudnn_fmha_backend_config(); + + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor lhs_bmm1, + MatmulTensorDescriptorFor(custom_call->operand(0)->shape(), + config.bmm1_dot_dimension_numbers(), LHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor rhs_bmm1, + MatmulTensorDescriptorFor(custom_call->operand(1)->shape(), + config.bmm1_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor rhs_bmm2, + MatmulTensorDescriptorFor(custom_call->operand(2)->shape(), + config.bmm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + TensorDescriptor output, + TensorDescriptorFor(ShapeUtil::GetSubshape(custom_call->shape(), {0}))); + + std::optional activation; + const bool has_activation = + xla::ShapeUtil::TupleElementCount(custom_call->shape()) == 3; + if (has_activation) { + TF_ASSIGN_OR_RETURN( + activation, TensorDescriptorFor( + ShapeUtil::GetSubshape(custom_call->shape(), {1}))); + } + + std::optional bias; + if (kind == CudnnfMHAKind::kScaleBiasSoftmax || + kind == CudnnfMHAKind::kScaleBiasSoftmaxDropout) { + const HloInstruction &bias_hlo = *custom_call->operand(3); + TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(bias_hlo.shape())); + } + + const double dropout_rate = config.dropout_rate(); + + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionOperationGraph( + dnn_support, lhs_bmm1, rhs_bmm1, rhs_bmm2, output, bias, activation, + static_cast(config.fmha_scale()), dropout_rate > 0.0, + dropout_rate, dnn_mask_type)); + return std::move(graph); + } else { + TF_ASSIGN_OR_RETURN( + auto gpu_config, + custom_call->backend_config()); + xla::gpu::CudnnfMHABackendConfig &config = + *gpu_config.mutable_cudnn_fmha_backend_config(); + + int input_index = 0; + const Shape &bmm1_grad_gemm1_rhs_shape = + custom_call->operand(input_index++)->shape(); + const Shape &bmm1_grad_gemm2_rhs_shape = + custom_call->operand(input_index++)->shape(); + const Shape &bmm2_grad_gemm2_rhs_shape = + custom_call->operand(input_index++)->shape(); + const Shape bmm2_grad_gemm1_lhs_shape(config.intermediate_tensor_shape()); + ++input_index; + const Shape &d_output_shape = custom_call->operand(input_index++)->shape(); + + TF_ASSIGN_OR_RETURN(const CudnnfMHAKind kind, + GetCudnnfMHAKind(custom_call)); + + bool has_bias = (kind == CudnnfMHAKind::kBackwardScaleBiasSoftmax || + kind == CudnnfMHAKind::kBackwardScaleBiasSoftmaxDropout); + std::optional bias_shape; + if (has_bias) { + bias_shape = custom_call->operand(input_index++)->shape(); + } + + // Unused fwd_output_shape + ++input_index; + + if (config.mask_type() == xla::gpu::CudnnfMHABackendConfig::PADDING || + config.mask_type() == + xla::gpu::CudnnfMHABackendConfig::PADDING_CAUSAL) { + // skip q_seqlen and kv_seqlen + input_index += 2; + } + TF_RET_CHECK(input_index == custom_call->operand_count()); + + int output_index = 0; + const Shape &d_bmm1_lhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + const Shape &d_bmm1_rhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + const Shape &d_bmm2_rhs_shape = + ShapeUtil::GetSubshape(custom_call->shape(), {output_index++}); + bool has_dbias = custom_call->shape().tuple_shapes().size() == 5; + if (has_dbias) { + ++output_index; + } + // The last one is the workspace. + TF_RET_CHECK(output_index == + custom_call->shape().tuple_shapes().size() - 1); + + const DebugOptions &debug_options = + custom_call->GetModule()->config().debug_options(); + bool force_deterministic = + debug_options.xla_gpu_deterministic_ops() || + debug_options.xla_gpu_exclude_nondeterministic_ops(); + config.set_force_deterministic(force_deterministic); + TF_RETURN_IF_ERROR(custom_call->set_backend_config(gpu_config)); + + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm1_grad_gemm1_rhs, + MatmulTensorDescriptorFor( + bmm1_grad_gemm1_rhs_shape, + config.bmm1_grad_gemm1_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm1_grad_gemm2_rhs, + MatmulTensorDescriptorFor( + bmm1_grad_gemm2_rhs_shape, + config.bmm1_grad_gemm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm2_grad_gemm1_lhs, + MatmulTensorDescriptorFor( + bmm2_grad_gemm1_lhs_shape, + config.bmm2_grad_gemm1_dot_dimension_numbers(), LHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor bmm2_grad_gemm2_rhs, + MatmulTensorDescriptorFor( + bmm2_grad_gemm2_rhs_shape, + config.bmm2_grad_gemm2_dot_dimension_numbers(), RHS)); + TF_ASSIGN_OR_RETURN( + MatmulTensorDescriptor d_output, + MatmulTensorDescriptorFor( + d_output_shape, config.bmm2_grad_gemm1_dot_dimension_numbers(), + RHS)); + + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_lhs, + TensorDescriptorFor(d_bmm1_lhs_shape)); + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm1_rhs, + TensorDescriptorFor(d_bmm1_rhs_shape)); + TF_ASSIGN_OR_RETURN(TensorDescriptor d_bmm2_rhs, + TensorDescriptorFor(d_bmm2_rhs_shape)); + + std::optional bias; + if (bias_shape.has_value()) { + TF_ASSIGN_OR_RETURN(bias, TensorDescriptorFor(*bias_shape)); + } + + const double dropout_rate = config.dropout_rate(); + + TF_ASSIGN_OR_RETURN(CudnnfMHAMaskKind cudnn_mask_type, + AsCudnnFmhaMaskKind(config.mask_type())); + TF_ASSIGN_OR_RETURN( + se::dnn::FMHAMaskKind dnn_mask_type, + GetDNNFmhaMaskKindFromCudnnFmhaMaskKind(cudnn_mask_type)); + + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + se::gpu::GetCudnnFlashAttentionBackwardOperationGraph( + dnn_support, bmm1_grad_gemm1_rhs, bmm1_grad_gemm2_rhs, + bmm2_grad_gemm1_lhs, bmm2_grad_gemm2_rhs, d_output, d_bmm1_lhs, + d_bmm1_rhs, d_bmm2_rhs, bias, dropout_rate, config.seed(), + config.fmha_scale(), dropout_rate > 0.0, bias != std::nullopt, + dnn_mask_type, force_deterministic)); + return std::move(graph); + } +} + +class CuDnnCustomCallVisitor : public DfsHloRewriteVisitor { + public: + explicit CuDnnCustomCallVisitor(se::dnn::DnnSupport &dnn_support, + BinaryMap &compilation_results) + : dnn_support_(dnn_support), compilation_results_(compilation_results) {} + + void AddWorkspace(HloInstruction &hlo, int64_t workspace_size) { + if (workspace_size == 0) { + return; + } + VLOG(4) << "Applying workspace size " << workspace_size << " to " + << hlo.ToString(); + Shape *shape = hlo.mutable_shape(); + shape->mutable_tuple_shapes()->back().set_dimensions(0, workspace_size); + } + + absl::Status HandleCustomCall(HloInstruction *hlo) override { + if (!IsCustomCallTofMHA(*hlo)) { + return absl::OkStatus(); + } + + TF_ASSIGN_OR_RETURN(const std::string fingerprint_without_workspace, + FingerprintWithBackendConfig(*hlo)); + auto workspace_size_it = + workspace_sizes_.find(fingerprint_without_workspace); + if (workspace_size_it == workspace_sizes_.cend()) { + TF_ASSIGN_OR_RETURN( + se::gpu::CudnnGraph graph, + HloCustomCallToCuDnnGraph(dnn_support_, + DynCast(hlo))); + + const int64_t workspace_size = graph.Graph().get_workspace_size(); + workspace_sizes_.insert(workspace_size_it, + {fingerprint_without_workspace, workspace_size}); + AddWorkspace(*hlo, workspace_size); + + std::vector serialized_graph; + RETURN_IF_CUDNN_FRONTEND_ERROR(graph.Graph().serialize(serialized_graph)); + // Compute a new fingerprint with a potential workspace for the + // compilation results to match a fingerprint computed by the emitter. + TF_ASSIGN_OR_RETURN(const std::string fingerprint_with_workspace, + FingerprintWithBackendConfig(*hlo)); + compilation_results_[fingerprint_with_workspace] = + std::string(reinterpret_cast(serialized_graph.data()), + serialized_graph.size()); + } else { + VLOG(4) << "Cache hit."; + AddWorkspace(*hlo, workspace_size_it->second); + } + + MarkAsChanged(); + return absl::OkStatus(); + } + + private: + se::dnn::DnnSupport &dnn_support_; + BinaryMap &compilation_results_; + absl::flat_hash_map workspace_sizes_; +}; + +} // namespace + +absl::StatusOr CuDnnCustomCallCompiler::Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) { + XLA_SCOPED_LOGGING_TIMER_LEVEL("cuDNN custom call compiler", 8); + return CuDnnCustomCallVisitor(dnn_support_, compilation_results_) + .RunOnModule(module, execution_threads); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h similarity index 61% rename from third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h index de81d6d37f17fe..810286f91b8472 100644 --- a/third_party/xla/xla/service/gpu/cudnn_workspace_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_compiler.h @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_ -#define XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/hlo_pass_interface.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/stream_executor.h" @@ -27,14 +28,18 @@ limitations under the License. namespace xla { namespace gpu { -// Rewrite cuDNN custom call to have correct workspace size by build graph -// and serialize so we can use it later -class CuDnnWorkspaceRewriter : public HloModulePass { +// Compile cuDNN custom calls to binaries and serialize them. +// Also adjust them in HLO to have correct workspace size. +class CuDnnCustomCallCompiler : public HloModulePass { public: - explicit CuDnnWorkspaceRewriter(se::StreamExecutor& stream_exec) - : dnn_support_(*stream_exec.AsDnn()) {} + explicit CuDnnCustomCallCompiler(se::StreamExecutor& stream_exec, + BinaryMap& compilation_results) + : dnn_support_(*stream_exec.AsDnn()), + compilation_results_(compilation_results) {} - absl::string_view name() const override { return "cudnn-workspace-rewriter"; } + absl::string_view name() const override { + return "cudnn-custom-call-compiler"; + } using HloPassInterface::Run; absl::StatusOr Run( @@ -43,9 +48,10 @@ class CuDnnWorkspaceRewriter : public HloModulePass { private: se::dnn::DnnSupport& dnn_support_; + BinaryMap& compilation_results_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUDNN_WORKSPACE_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_COMPILER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc new file mode 100644 index 00000000000000..71ed08c41fab7b --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.cc @@ -0,0 +1,65 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/gpu/ir_emission_utils.h" +#include "tsl/platform/errors.h" + +namespace xla { +namespace gpu { +namespace { + +class CustomCallVisitor : public DfsHloRewriteVisitor { + public: + absl::Status HandleCustomCall(HloInstruction *hlo) override { + if (hlo->custom_call_target() != kCuDnnFusionKind) { + return absl::OkStatus(); + } + HloComputation *computation = hlo->GetModule()->AddEmbeddedComputation( + hlo->called_computations()[0]->Clone()); + HloInstruction *fusion = + hlo->parent()->AddInstruction(HloInstruction::CreateFusion( + hlo->shape(), HloInstruction::FusionKind::kCustom, hlo->operands(), + computation)); + GpuBackendConfig gpu_config; + FusionBackendConfig &backend_config = + *gpu_config.mutable_fusion_backend_config(); + backend_config.set_kind(hlo->custom_call_target()); + TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(hlo, fusion)); + return absl::OkStatus(); + } +}; + +} // namespace + +absl::StatusOr CuDnnCustomCallConverter::Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) { + return CustomCallVisitor().RunOnModule(module, execution_threads); +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h new file mode 100644 index 00000000000000..5397a4db80e91d --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter.h @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Converts custom calls with kCuDnnFusionKind backend config to +// fusions with the same backend config. Frameworks can pass computations +// outlined this way through StableHLO; after the conversion they can be +// processed by XLA using the existing pipeline for custom fusions. +class CuDnnCustomCallConverter : public HloModulePass { + public: + absl::string_view name() const override { + return "cudnn-custom-call-converter"; + } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_CUSTOM_CALL_CONVERTER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc new file mode 100644 index 00000000000000..ad29e1566330b8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_custom_call_converter_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/cudnn_custom_call_converter.h" + +#include +#include "xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +using ConverterTest = HloTestBase; + +TEST_F(ConverterTest, CustomCallGetsConvertedToCustomFusion) { + RunAndFilecheckHloRewrite(R"( +f { + a = s8[] parameter(0) + ROOT r = s8[] add(a, a) +} + +ENTRY e { + b = s8[] parameter(0) + ROOT c = s8[] custom-call(b), + custom_call_target="__cudnn$fusion", called_computations={f} +})", + CuDnnCustomCallConverter(), R"( +; CHECK: ROOT %fusion = s8[] fusion(%b), kind=kCustom, calls=%f +; CHECK-SAME: "fusion_backend_config":{"kind":"__cudnn$fusion"} + )"); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc index e9cb21b9fa6bd7..c51a76f63caf75 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_fused_conv_rewriter.h" +#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h similarity index 96% rename from third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h index 906a67a05dd3ad..5caf9c0d0a43af 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ -#define XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_ #include @@ -132,4 +132,4 @@ class CudnnFusedConvRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUDNN_FUSED_CONV_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_CONV_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc index ac03122baebd03..5e22a6f2ec1af3 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_conv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_conv_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_fused_conv_rewriter.h" +#include "xla/service/gpu/transforms/cudnn_fused_conv_rewriter.h" #include #include @@ -35,34 +35,33 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/stream_executor_util.h" -#include "xla/service/hlo_module_config.h" -#include "xla/stream_executor/device_description.h" -#include "xla/stream_executor/dnn.h" -#include "xla/tests/verified_hlo_module.h" -#include "tsl/platform/statusor.h" - -#if GOOGLE_CUDA -#include "third_party/gpus/cuda/include/cuda.h" -#elif TENSORFLOW_USE_ROCM -#include "rocm/rocm_config.h" -#endif // GOOGLE_CUDA - #include "xla/service/algebraic_simplifier.h" #include "xla/service/convert_mover.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/gpu_conv_rewriter.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/transforms/conv_rewriter.h" #include "xla/service/hlo_constant_folding.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_fix.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/service/reshape_mover.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/dnn.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" + +#if GOOGLE_CUDA +#include "third_party/gpus/cuda/include/cuda.h" +#elif TENSORFLOW_USE_ROCM +#include "rocm/rocm_config.h" +#endif // GOOGLE_CUDA namespace xla { namespace gpu { @@ -244,7 +243,7 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { // On older architectures, disregard layout information and only verify // the basic configuration of the convolution Custom Call using the number // of operands and the window_size and serialized graph attributes based - // on the GpuConvRewriter and CudnnFusedConvRewriter passes. + // on the ConvRewriter and CudnnFusedConvRewriter passes. std::string::size_type p0 = custom_call_string.find(':'); std::string::size_type p1 = custom_call_string.find("custom-call"); custom_call_string.erase(p0 + 1, p1 - p0 - 2); @@ -254,8 +253,8 @@ class CudnnFusedConvRewriterTest : public GpuCodegenTest { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(pre_hlo_string)); TF_ASSERT_OK_AND_ASSIGN( - bool changed, RunHloPass(GpuConvRewriter(GetCudaComputeCapability()), - module.get())); + bool changed, + RunHloPass(ConvRewriter(GetCudaComputeCapability()), module.get())); EXPECT_TRUE(changed); RunAndFilecheckHloRewrite( module->ToString(HloPrintOptions{}.set_print_operand_shape(false)), @@ -1317,7 +1316,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloat) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1351,7 +1350,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToInt8BiasSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1392,7 +1391,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestReluAfterConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1446,7 +1445,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, TestConvInt8ToFloatBiasSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1492,7 +1491,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, Int8SideInputWithScaleAndReshape) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1547,7 +1546,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseAlpha) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1587,7 +1586,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1628,7 +1627,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseReluIfMultipleUses) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1678,7 +1677,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseElu) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // elu fusion is only active on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -1727,7 +1726,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseEluIfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1780,7 +1779,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseRelu6) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // relu6 fusion is only enabled on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -1824,7 +1823,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseRelu6IfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1872,7 +1871,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseLeakyRelu) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); // Leaky-relu fusion is only enabled on Ampere+. CudnnFusedConvRewriter fuser{se::CudaComputeCapability(8, 0), GetDnnVersion(), @@ -1919,7 +1918,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseLeakyReluIfMultipleUses) { debug_opts.set_xla_gpu_use_runtime_fusion(true); m->mutable_config().set_debug_options(debug_opts); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -1967,7 +1966,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseAlphaIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2007,7 +2006,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2046,7 +2045,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputThroughRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2085,7 +2084,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseBiasThroughRelu) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2121,7 +2120,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseSideInputIfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2158,7 +2157,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseConvertToF16IfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2192,7 +2191,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontFuseToS8IfMultipleUsers) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2225,7 +2224,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS32ToF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2252,7 +2251,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingS8ToF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2279,7 +2278,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, RemoveConvertByFusingF32ToS8) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2307,7 +2306,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, DontRemoveConvertDuetoMultpleUser) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2337,7 +2336,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBias) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2369,7 +2368,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2412,7 +2411,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseScaledSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2455,7 +2454,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseBiasAndSideInput) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2493,7 +2492,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, EffectiveScalarBias) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2536,7 +2535,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, StrengthReduceF32ToF16) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2583,7 +2582,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, BroadcastReshapeTransposeAfterConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2636,7 +2635,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, NoStrengthReduceF32ToF16IfBiasIsF32) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2692,7 +2691,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32Constants) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2746,7 +2745,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, F32ConstantsNotLosslesslyConvertible) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2810,7 +2809,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, FuseReluBeforeConvert) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; @@ -2853,7 +2852,7 @@ TEST_F(CudnnFusedConvRewriterHloTest, BiasTypeMatchesConvTypeIfFp) { })"; TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - GpuConvRewriter rewriter{GetCudaComputeCapability()}; + ConvRewriter rewriter{GetCudaComputeCapability()}; TF_ASSERT_OK(RunHloPass(&rewriter, m.get()).status()); CudnnFusedConvRewriter fuser{GetCudaComputeCapability(), GetDnnVersion(), GetToolkitVersion()}; diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc index 23b238dad651a3..7fbd5898b27f28 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_fused_mha_rewriter.h" +#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h" #include #include @@ -292,10 +292,7 @@ bool IsComputeCapabilityAndCudnnSupported( stream_executor::CudaComputeCapability cc, stream_executor::dnn::VersionInfo cudnn_version, stream_executor::dnn::VersionInfo supported_cudnn_version) { - // Enforce capability minor == 0 because hardware with a non-zero minor - // number typically has insufficient shared memory for cuDNN FMHA. - if (cc.IsAtLeastAmpere() && cc.minor == 0 && - cudnn_version >= supported_cudnn_version) { + if (cc.IsAtLeastAmpere() && cudnn_version >= supported_cudnn_version) { return true; } VLOG(2) << absl::StrFormat( @@ -1636,7 +1633,7 @@ absl::StatusOr CudnnFusedMHARewriter::Run( if (!debug_options.xla_gpu_enable_cudnn_fmha() || !IsComputeCapabilityAndCudnnSupported( compute_capability_, cudnn_version, - stream_executor::dnn::VersionInfo(8, 9, 4))) { + stream_executor::dnn::VersionInfo(9, 0, 0))) { return false; } for (HloInstruction* instr : comp->MakeInstructionPostOrder()) { @@ -1723,9 +1720,8 @@ absl::StatusOr CudnnFusedMHARewriter::Run( } if (matched_bwd_result.matched_dbias && !(compute_capability_.IsAtLeastHopper() && - compute_capability_.minor == 0 && - cudnn_version >= stream_executor::dnn::VersionInfo(8, 9, 6))) { - VLOG(2) << "Flash attention dbias requires cudnn 8.9.6 + hopper."; + cudnn_version >= stream_executor::dnn::VersionInfo(9, 0, 0))) { + VLOG(2) << "Flash attention dbias requires cudnn 9.0.0 + hopper."; // restore fwd graph if bwd pattern match failed TF_RETURN_IF_ERROR( RestoreFwdGraph(comp, fwd_fmha_call, original_bmm2, activation, diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h similarity index 91% rename from third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h index f0aa6871caf90e..6a985eea5bf1ca 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_ -#define XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -56,4 +56,4 @@ class CudnnFusedMHARewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUDNN_FUSED_MHA_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc similarity index 95% rename from third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc index 2cf88b01a8fe8b..a64fd0624bea62 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_fused_mha_rewriter.h" +#include "xla/service/gpu/transforms/cudnn_fused_mha_rewriter.h" #include #include @@ -30,7 +30,7 @@ limitations under the License. #include "xla/service/computation_layout.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" -#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h" +#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h" #include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_module_config.h" @@ -44,9 +44,9 @@ limitations under the License. #include "xla/stream_executor/dnn.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA @@ -80,7 +80,7 @@ class CudnnFusedMhaRewriterTestHloTest : public HloTestBase { // Fake a supported compute capability to run tests, // we don't run any kernels in these tests so they should be safe // to run anywhere. - return se::dnn::VersionInfo(8, 9, 4); + return se::dnn::VersionInfo(9, 0, 0); } CudnnFusedMhaRewriterTestHloTest() @@ -1714,94 +1714,6 @@ ENTRY main { EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot())); } -TEST_F(CudnnFusedMhaRewriterTestHloTest, - F16TrainingBmm1ScaleBiasSoftmaxBmm2NonContractingDimNotDivisibleBy64) { - if (skip_reason_) GTEST_SKIP() << *skip_reason_; - const char* module_str = R"( -HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(f16[2,6,64,100]{3,2,1,0},f16[2,6,64,100]{3,2,1,0},f16[2,6,100,64]{3,2,1,0},f16[2,6,100,64]{3,2,1,0})->(f16[2,6,100,64]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}, f16[2,6,64,100]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0})}, allow_spmd_sharding_propagation_to_output={true,true,true,true} - -region_0.21 { - Arg_0.22 = f16[] parameter(0) - Arg_1.23 = f16[] parameter(1) - ROOT maximum = f16[] maximum(Arg_0.22, Arg_1.23) -} - -region_1.33 { - Arg_0.34 = f32[] parameter(0) - Arg_1.35 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0.34, Arg_1.35) -} - -region_2.55 { - Arg_0.56 = f16[] parameter(0) - Arg_1.57 = f16[] parameter(1) - ROOT add.1 = f16[] add(Arg_0.56, Arg_1.57) -} - -ENTRY main.82 { - Arg_0.1 = f16[2,6,64,100]{3,2,1,0} parameter(0), sharding={replicated} - Arg_1.2 = f16[2,6,64,100]{3,2,1,0} parameter(1), sharding={replicated} - dot.17 = f16[2,6,100,100]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - constant.22 = f16[] constant(2) - broadcast.24 = f16[2,6,100,100]{3,2,1,0} broadcast(constant.22), dimensions={} - multiply.2 = f16[2,6,100,100]{3,2,1,0} multiply(dot.17, broadcast.24) - constant.19 = f16[] constant(1) - broadcast.13 = f16[2,6,100,100]{3,2,1,0} broadcast(constant.19), dimensions={} - add.3 = f16[2,6,100,100]{3,2,1,0} add(multiply.2, broadcast.13) - constant.21 = f16[] constant(0) - constant.15 = f16[] constant(-inf) - reduce.25 = f16[2,6,100]{2,1,0} reduce(add.3, constant.15), dimensions={3}, to_apply=region_0.21 - broadcast.17 = f16[2,6,100,100]{3,2,1,0} broadcast(reduce.25), dimensions={0,1,2} - subtract.1 = f16[2,6,100,100]{3,2,1,0} subtract(add.3, broadcast.17) - exponential.1 = f16[2,6,100,100]{3,2,1,0} exponential(subtract.1) - convert.5 = f32[2,6,100,100]{3,2,1,0} convert(exponential.1) - constant.17 = f32[] constant(0) - reduce.37 = f32[2,6,100]{2,1,0} reduce(convert.5, constant.17), dimensions={3}, to_apply=region_1.33 - convert.9 = f16[2,6,100]{2,1,0} convert(reduce.37) - broadcast.26 = f16[2,6,100,100]{3,2,1,0} broadcast(convert.9), dimensions={0,1,2} - divide.5 = f16[2,6,100,100]{3,2,1,0} divide(exponential.1, broadcast.26) - Arg_2.3 = f16[2,6,100,64]{3,2,1,0} parameter(2), sharding={replicated} - dot.46 = f16[2,6,100,64]{3,2,1,0} dot(divide.5, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - Arg_3.4 = f16[2,6,100,64]{3,2,1,0} parameter(3), sharding={replicated} - dot.49 = f16[2,6,100,100]{3,2,1,0} dot(Arg_3.4, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} - divide.4 = f16[2,6,100,100]{3,2,1,0} divide(dot.49, broadcast.26) - broadcast.20 = f16[2,6,100]{2,1,0} broadcast(constant.19), dimensions={} - multiply.3 = f16[2,6,100]{2,1,0} multiply(convert.9, convert.9) - divide.3 = f16[2,6,100]{2,1,0} divide(broadcast.20, multiply.3) - broadcast.21 = f16[2,6,100,100]{3,2,1,0} broadcast(divide.3), dimensions={0,1,2} - multiply.4 = f16[2,6,100,100]{3,2,1,0} multiply(dot.49, broadcast.21) - multiply.5 = f16[2,6,100,100]{3,2,1,0} multiply(multiply.4, exponential.1) - reduce.59 = f16[2,6,100]{2,1,0} reduce(multiply.5, constant.21), dimensions={3}, to_apply=region_2.55 - negate.2 = f16[2,6,100]{2,1,0} negate(reduce.59) - broadcast.25 = f16[2,6,100,100]{3,2,1,0} broadcast(negate.2), dimensions={0,1,2} - add.5 = f16[2,6,100,100]{3,2,1,0} add(divide.4, broadcast.25) - multiply.8 = f16[2,6,100,100]{3,2,1,0} multiply(add.5, exponential.1) - multiply.9 = f16[2,6,100,100]{3,2,1,0} multiply(multiply.8, broadcast.24) - dot.80 = f16[2,6,100,64]{3,2,1,0} dot(multiply.9, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} - dot = f16[2,6,64,100]{3,2,1,0} dot(Arg_0.1, multiply.9), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - dot.1 = f16[2,6,100,64]{3,2,1,0} dot(divide.5, Arg_3.4), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} - ROOT tuple.81 = (f16[2,6,100,64]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}, f16[2,6,64,100]{3,2,1,0}, f16[2,6,100,64]{3,2,1,0}) tuple(dot.46, dot.80, dot, dot.1) -} -)"; - - TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(module_str)); - CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), - GetCudnnVersion()}; - const auto status_or = RunHloPass(&fusedMhaRewriter, m.get()); - TF_ASSERT_OK(status_or.status()); - EXPECT_FALSE(status_or.value()); - - HloDCE dce; - TF_ASSERT_OK(RunHloPass(&dce, m.get()).status()); - - ComputationLayout computation_layout( - m->entry_computation()->ComputeProgramShape()); - - SCOPED_TRACE(m->ToString()); - EXPECT_THAT(m->entry_computation()->root_instruction(), - GmockMatch(m::Tuple(m::Dot(), m::Dot(), m::Dot(), m::Dot()))); -} - TEST_F(CudnnFusedMhaRewriterTestHloTest, F16TrainingBmm2Grad1IncorrectPattern) { if (skip_reason_) GTEST_SKIP() << *skip_reason_; const char* module_str = R"( @@ -2975,80 +2887,6 @@ ENTRY main.164_spmd { })))))); } -constexpr absl::string_view hlo_head_dim_not_multiple_of_64 = R"( -HloModule jit__reference, entry_computation_layout={(f16[4,48,1024,16]{3,2,1,0}, f16[4,48,1024,16]{3,2,1,0}, f16[4,48,1024,16]{3,2,1,0})->f16[4,48,1024,16]{3,2,1,0}} - -region_0.26 { - Arg_0.27 = f32[] parameter(0) - Arg_1.28 = f32[] parameter(1) - ROOT maximum = f32[] maximum(Arg_0.27, Arg_1.28) -} - -region_1.37 { - Arg_0.38 = f32[] parameter(0) - Arg_1.39 = f32[] parameter(1) - ROOT add = f32[] add(Arg_0.38, Arg_1.39) -} - -ENTRY main.49 { - iota.2 = s32[1024,1024]{1,0} iota(), iota_dimension=0 - iota.3 = s32[1024,1024]{1,0} iota(), iota_dimension=1 - compare = pred[1024,1024]{1,0} compare(iota.2, iota.3), direction=GE - broadcast.4 = pred[4,48,1024,1024]{3,2,1,0} broadcast(compare), dimensions={2,3} - Arg_0.1 = f16[4,48,1024,16]{3,2,1,0} parameter(0) - Arg_1.2 = f16[4,48,1024,16]{3,2,1,0} parameter(1) - dot.9 = f16[4,48,1024,1024]{3,2,1,0} dot(Arg_0.1, Arg_1.2), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3} - constant.4 = f16[] constant(0.5) - broadcast.6 = f16[4,48,1024,1024]{3,2,1,0} broadcast(constant.4), dimensions={} - multiply = f16[4,48,1024,1024]{3,2,1,0} multiply(dot.9, broadcast.6) - convert.1 = f32[4,48,1024,1024]{3,2,1,0} convert(multiply) - constant.7 = f32[] constant(-inf) - reduce.30 = f32[4,48,1024]{2,1,0} reduce(convert.1, constant.7), dimensions={3}, to_apply=region_0.26 - broadcast.8 = f32[4,48,1024,1024]{3,2,1,0} broadcast(reduce.30), dimensions={0,1,2} - subtract = f32[4,48,1024,1024]{3,2,1,0} subtract(convert.1, broadcast.8) - exponential = f32[4,48,1024,1024]{3,2,1,0} exponential(subtract) - constant.6 = f32[] constant(0) - reduce.41 = f32[4,48,1024]{2,1,0} reduce(exponential, constant.6), dimensions={3}, to_apply=region_1.37 - broadcast.9 = f32[4,48,1024,1024]{3,2,1,0} broadcast(reduce.41), dimensions={0,1,2} - divide = f32[4,48,1024,1024]{3,2,1,0} divide(exponential, broadcast.9) - convert.2 = f16[4,48,1024,1024]{3,2,1,0} convert(divide) - Arg_2.3 = f16[4,48,1024,16]{3,2,1,0} parameter(2) - ROOT dot.48 = f16[4,48,1024,16]{3,2,1,0} dot(convert.2, Arg_2.3), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2} -} // main.49 -)"; - -TEST_F(CudnnFusedMhaRewriterTestHloTest, HeadDimNotMultipleOf64) { - if (skip_reason_) GTEST_SKIP() << *skip_reason_; - TF_ASSERT_OK_AND_ASSIGN( - auto m, ParseAndReturnVerifiedModule(hlo_head_dim_not_multiple_of_64, - GetModuleConfig())); - CudnnFusedMHARewriter fusedMhaRewriter{GetCudaComputeCapability(), - GetCudnnVersion()}; - TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); - - // head dim not a multiple of 64 should not be lowered with cuDNN < 8.9.6 - SCOPED_TRACE(m->ToString()); - EXPECT_THAT(m->entry_computation()->root_instruction(), GmockMatch(m::Dot())); - - // should be lowered with cuDNN >= 8.9.6 - CudnnFusedMHARewriter fusedMhaRewriterWithcuDNN8907{ - GetCudaComputeCapability(), se::dnn::VersionInfo(8, 9, 7)}; - TF_ASSERT_OK(RunHloPass(&fusedMhaRewriterWithcuDNN8907, m.get()).status()); - const HloInstruction* fmha; - - SCOPED_TRACE(m->ToString()); - EXPECT_THAT( - m->entry_computation()->root_instruction(), - GmockMatch(m::GetTupleElement( - m::CustomCall(&fmha, {kCudnnfMHASoftmaxCallTarget}), 0) - .WithShape(F16, {4, 48, 1024, 16}))); - TF_ASSERT_OK_AND_ASSIGN(auto gpu_config, - fmha->backend_config()); - const CudnnfMHABackendConfig& config = gpu_config.cudnn_fmha_backend_config(); - EXPECT_EQ(config.fmha_scale(), 0.5); - EXPECT_EQ(config.dropout_rate(), 0.0); -} - constexpr absl::string_view hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias = R"( HloModule jit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})->(bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[2,1024,4,64]{3,2,1,0}, bf16[4,1024,1024]{2,1,0})}, allow_spmd_sharding_propagation_to_parameters={true,true,true,true,true}, allow_spmd_sharding_propagation_to_output={true,true,true,true,true} @@ -3131,7 +2969,7 @@ TEST_F(CudnnFusedMhaRewriterTestHloTest, BF16Bmm1BiasSoftmaxBmm2PatternDbias) { ParseAndReturnVerifiedModule(hlo_BF16Bmm1BiasSoftmaxBmm2Pattern_dbias)); // require cudnn 8.9.6 + hopper for dbias CudnnFusedMHARewriter fusedMhaRewriter{se::CudaComputeCapability(9, 0), - se::dnn::VersionInfo(8, 9, 6)}; + se::dnn::VersionInfo(9, 0, 0)}; TF_ASSERT_OK(RunHloPass(&fusedMhaRewriter, m.get()).status()); ComputationLayout computation_layout( diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc index 665cc0bf824383..7299643818f5c7 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_fused_mha_transpose_fusion.h" +#include "xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h similarity index 85% rename from third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h index 94ec229d7709db..825d97ed926560 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_transpose_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fused_mha_transpose_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_ -#define XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -42,4 +42,4 @@ class CudnnFusedMHATransposeFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSED_MHA_TRANSPOSE_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc similarity index 90% rename from third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc index 18067960f80211..3ffd74e9e594b7 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_fusion_compiler.h" +#include "xla/service/gpu/transforms/cudnn_fusion_compiler.h" #include #include @@ -49,6 +49,7 @@ limitations under the License. #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/kernel_reuse_cache.h" #include "xla/service/gpu/matmul_utils.h" +#include "xla/service/gpu/stream_executor_util.h" #include "xla/service/gpu/triton_fusion_analysis.h" #include "xla/shape_util.h" #include "xla/stream_executor/cuda/cuda_dnn.h" @@ -212,10 +213,13 @@ class GemmDimensionAdapter { return GemmDimensionAdapter{*dot, std::move(analysis)}; } - bool DimensionsAndStrides(const HloInstruction& hlo, - const TritonFusionAnalysis::Scope scope, - std::vector& dimensions, - std::vector& strides) { + struct Result { + std::vector sizes; + std::vector strides; + }; + + std::optional DimensionsAndStrides( + const HloInstruction& hlo, const TritonFusionAnalysis::Scope scope) { const DotDimensionNumbers& dims = dot_.dot_dimension_numbers(); // GEMM fusions require a specific canonical order of dimensions. constexpr int kBatchDimensionIndex = 0; @@ -252,29 +256,33 @@ class GemmDimensionAdapter { case TritonFusionAnalysis::Scope::META: LOG(FATAL) << "Unsupported scope."; } - dimensions.reserve(dim_indices.size()); - strides.reserve(dim_indices.size()); + + Result result; + result.sizes.reserve(dim_indices.size()); + result.strides.reserve(dim_indices.size()); + for (const int index : dim_indices) { const auto* spec = analysis_.IterSpec(scope, &hlo, index); if (spec == nullptr) { - dimensions.push_back(1); - strides.push_back(strides.empty() ? 1 : strides.back()); + result.sizes.push_back(1); + result.strides.push_back( + result.strides.empty() ? 1 : result.strides.back()); continue; } else { if (spec->size() == 1) { // The dimension is not split, nothing to do. } else if (spec->size() == 2) { if (FusionLevel(hlo) < 3) { - return false; + return std::nullopt; } if (!dims.lhs_batch_dimensions().empty()) { VLOG(8) << "Noncontracting dimension split is not compatible with " "batch dimensions."; - return false; + return std::nullopt; } if (index != lhs_noncontracting_index) { VLOG(8) << "Only LHS noncontracting dimension can be split."; - return false; + return std::nullopt; } switch (scope) { case TritonFusionAnalysis::Scope::LHS: @@ -284,40 +292,40 @@ class GemmDimensionAdapter { if (lhs_noncontracting_split_ != spec->back().count) { VLOG(8) << "Output non-contracting dimension has to be split " "the same way as the LHS input one if it is split."; - return false; + return std::nullopt; } break; default: VLOG(8) << "Only LHS noncontracting dimension can be split."; - return false; + return std::nullopt; } // Assign the major part of the noncontracting dimension to the // unused batch one. - CHECK_EQ(dimensions[kBatchDimensionIndex], 1); - dimensions[kBatchDimensionIndex] = spec->back().count; - strides[kBatchDimensionIndex] = spec->back().stride; + CHECK_EQ(result.sizes[kBatchDimensionIndex], 1); + result.sizes[kBatchDimensionIndex] = spec->back().count; + result.strides[kBatchDimensionIndex] = spec->back().stride; } else { VLOG(8) << "The dimension is split multiple times."; - return false; + return std::nullopt; } - dimensions.push_back(spec->front().count); - strides.push_back(spec->front().stride); + result.sizes.push_back(spec->front().count); + result.strides.push_back(spec->front().stride); } } if (lhs_noncontracting_split_ > 1 && scope == TritonFusionAnalysis::Scope::OUTPUT && - dimensions[kBatchDimensionIndex] == 1) { + result.sizes[kBatchDimensionIndex] == 1) { // LHS input noncontracting dimension is split but the corresponding // output one is not. Assign part of the output one to the unused batch // dimension. - dimensions[kBatchDimensionIndex] = lhs_noncontracting_split_; - dimensions[kOutputLHSNonContractingDimensionIndex] /= + result.sizes[kBatchDimensionIndex] = lhs_noncontracting_split_; + result.sizes[kOutputLHSNonContractingDimensionIndex] /= lhs_noncontracting_split_; - strides[kBatchDimensionIndex] = - strides[kOutputLHSNonContractingDimensionIndex] * - dimensions[kOutputLHSNonContractingDimensionIndex]; + result.strides[kBatchDimensionIndex] = + result.strides[kOutputLHSNonContractingDimensionIndex] * + result.sizes[kOutputLHSNonContractingDimensionIndex]; } - return true; + return result; } private: @@ -396,8 +404,7 @@ absl::StatusOr> HloFusionToCuDnnGraph( return std::nullopt; } auto add_parameter = [&](const HloInstruction& parameter, - std::vector& dimensions, - std::vector strides) { + const GemmDimensionAdapter::Result& dims) { const std::optional data_type = ToCudnnDataType(parameter.shape().element_type()); if (!data_type.has_value()) { @@ -406,8 +413,8 @@ absl::StatusOr> HloFusionToCuDnnGraph( } hlo_to_cudnn[¶meter] = graph.tensor( graph::Tensor_attributes() - .set_dim(dimensions) - .set_stride(strides) + .set_dim(dims.sizes) + .set_stride(dims.strides) .set_data_type(*data_type) .set_name(std::string(parameter.name())) .set_uid(se::gpu::CuDnnTensorUID(parameter.parameter_number()))); @@ -418,14 +425,13 @@ absl::StatusOr> HloFusionToCuDnnGraph( TritonFusionAnalysis::Scope::OUTPUT}) { for (const HloInstruction* parameter : adapter->analysis_.ScopeParameters(scope)) { - std::vector dimensions; - std::vector strides; - if (!adapter->DimensionsAndStrides(*parameter, scope, dimensions, - strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*parameter, scope); + if (!dims.has_value()) { VLOG(3) << "Unsupported dimensions."; return std::nullopt; } - if (!add_parameter(*parameter, dimensions, strides)) { + if (!add_parameter(*parameter, *dims)) { return std::nullopt; } } @@ -506,19 +512,19 @@ absl::StatusOr> HloFusionToCuDnnGraph( // setting output of the unary shapes results in the rejection of // the cuDNN graph. if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) { - const auto scope = adapter->analysis_.QueryInstructionScope(*hlo); - std::vector dimensions; - std::vector strides; + const std::optional scope = + adapter->analysis_.QueryInstructionScope(*hlo); if (!scope.has_value()) { LOG(FATAL) << "No scope for instruction: " << hlo->ToShortString(); } - if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions, - strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*hlo, *scope); + if (!dims.has_value()) { VLOG(3) << "Unsupported hlo for querying dimensions: " << hlo->ToShortString(); } else { - hlo_to_cudnn[hlo]->set_dim(dimensions); + hlo_to_cudnn[hlo]->set_dim(dims->sizes); } } } else if (hlo->operand_count() == 2) { @@ -562,17 +568,17 @@ absl::StatusOr> HloFusionToCuDnnGraph( if (instructions.back()->shape().IsTuple()) { output = instructions.back()->operand(0); } - std::vector dimensions; - std::vector strides; - if (!adapter->DimensionsAndStrides( - *output, TritonFusionAnalysis::Scope::OUTPUT, dimensions, strides)) { + const std::optional dims = + adapter->DimensionsAndStrides(*output, + TritonFusionAnalysis::Scope::OUTPUT); + if (!dims.has_value()) { VLOG(3) << "Unsupported dimensions."; return std::nullopt; } hlo_to_cudnn[output] ->set_output(true) - .set_dim(dimensions) - .set_stride(strides) + .set_dim(dims->sizes) + .set_stride(dims->strides) .set_uid(se::gpu::CuDnnTensorUID(fusion.operand_count())); if (!fusion.GetModule()->config().debug_options().xla_dump_to().empty()) { json dump; @@ -596,7 +602,10 @@ absl::StatusOr PrepareGraph( if (!graph.has_value()) { return absl::InternalError("Construction of cuDNN graph failed."); } - TF_RETURN_IF_ERROR(graph->Prepare(dnn_support)); + TF_RETURN_IF_ERROR(graph->Prepare( + dnn_support, + se::NumericOptions{RequireDeterminism(hlo.GetModule()->config()), + /*allow_tf32=*/true})); return *graph; } @@ -726,8 +735,9 @@ int CuDnnFusionCompiler::GetAvailablePlanCount( if (!graph.ok()) { return 0; } - constexpr int64_t kMaxPlans = 10; - return std::min(graph->Graph().get_execution_plan_count(), kMaxPlans); + return std::min( + static_cast(graph->Graph().get_execution_plan_count()), + hlo.GetModule()->config().debug_options().xla_gpu_cudnn_gemm_max_plans()); } } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h similarity index 91% rename from third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h index f34bbb1086a0f3..4917914f1a4cc5 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_fusion_compiler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_ -#define XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -56,4 +56,4 @@ class CuDnnFusionCompiler : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUDNN_FUSION_COMPILER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_FUSION_COMPILER_H_ diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc index 5e78f4864ec334..5d5e089933fd88 100644 --- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_norm_rewriter.h" +#include "xla/service/gpu/transforms/cudnn_norm_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h similarity index 89% rename from third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h index 7b3ef8d66e15fb..a2332d30052c8e 100644 --- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_ -#define XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -45,4 +45,4 @@ class CudnnNormRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUDNN_NORM_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_NORM_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc similarity index 88% rename from third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc index 754563a535c23b..a3dbc71132949a 100644 --- a/third_party/xla/xla/service/gpu/cudnn_norm_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_norm_rewriter_test.cc @@ -47,6 +47,18 @@ class CudnnNormRewriterTest : public GpuCodegenTest { } protected: + void SetUp() override { +#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) + GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; +#endif + if (!(GetCudaComputeCapability().major == + se::CudaComputeCapability::AMPERE) && + !(GetCudaComputeCapability().major == + se::CudaComputeCapability::HOPPER)) { + GTEST_SKIP() + << "Layer norm kernels require Ampere or Hopper architectures."; + } + } void TestNorm(std::string hlo_text, std::string optimized_hlo) { EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-3, 1e-3})); MatchOptimizedHlo(hlo_text, optimized_hlo); @@ -56,16 +68,6 @@ class CudnnNormRewriterTest : public GpuCodegenTest { // The following tests evaluate LayerNormXDY configurations, with X the rank of // the input and Y the dimensions that are normalized. TEST_F(CudnnNormRewriterTest, LayerNorm2D1) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -125,16 +127,6 @@ TEST_F(CudnnNormRewriterTest, LayerNorm2D1) { } TEST_F(CudnnNormRewriterTest, LayerNorm4D3) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -194,16 +186,6 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3) { } TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -263,16 +245,6 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3Degenerate0) { } TEST_F(CudnnNormRewriterTest, LayerNorm4D2) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -314,7 +286,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,4,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[8,8,6]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]]) @@ -326,23 +298,14 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[8,6,8]{2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: ROOT {{.*}} = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION]]) )"; TestNorm(hlo_text, optimized_hlo); } TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -384,7 +347,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,1,6,8], {{.*}}: f32[6], {{.*}}: f32[6]) -> f32[2,1,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,6]{3,2,1,0} transpose([[P0]]), dimensions={1,0,3,2} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,6]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]]) @@ -396,23 +359,14 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D2Degenerate1) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,6,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,1,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,6,8]{2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: ROOT {{.*}} = f32[2,1,6,8]{3,2,1,0} bitcast([[FUSION]]) )"; TestNorm(hlo_text, optimized_hlo); } TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -454,7 +408,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> f32[2,4,6,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,24]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]]) @@ -466,23 +420,14 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,24,8]{2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: ROOT {{.*}} = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION]]) )"; TestNorm(hlo_text, optimized_hlo); } TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -524,7 +469,7 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> f32[2,4,1,8] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]]) @@ -536,23 +481,14 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D12Degenerate2) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: ROOT [[FUSION:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: [[FUSION:%[^ ]+]] = f32[2,4,8]{2,1,0} fusion([[GTE]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-NEXT: ROOT {{.*}} = f32[2,4,1,8]{3,2,1,0} bitcast([[FUSION]]) )"; TestNorm(hlo_text, optimized_hlo); } TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -600,16 +536,6 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3IncorrectScaleBroadcast) { } TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -658,16 +584,6 @@ TEST_F(CudnnNormRewriterTest, LayerNorm4D3InputOutputTypeMismatch) { } TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -735,16 +651,6 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain2D1) { } TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -812,16 +718,6 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D3) { } TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -865,7 +761,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6]) -> (f32[2,4,6,8], f32[2,8], f32[2,8], f32[2,8]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,24]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]]) @@ -877,29 +773,20 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,24,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION0]]) ; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1 ; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]]) ; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2 ; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]]) ; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) )"; TestNorm(hlo_text, optimized_hlo); } TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -943,7 +830,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1]) -> (f32[2,4,1,8], f32[2,8], f32[2,8], f32[2,8]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[TRANSPOSE:%[^ ]+]] = f32[2,8,4]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]]) @@ -955,29 +842,20 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrain4D12Degenerate2) { ; CHECK-DAG: "epsilon":0.001 ; CHECK: } ; CHECK-NEXT: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=0 -; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[FUSION0:%[^ ]+]] = f32[2,4,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-NEXT: [[BITCAST:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} bitcast([[FUSION0]]) ; CHECK-NEXT: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=1 ; CHECK-NEXT: [[GTE1_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE1]]) ; CHECK-NEXT: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC]]), index=2 ; CHECK-NEXT: [[GTE2_BITCAST:%[^ ]+]] = f32[2,8]{1,0} bitcast([[GTE2]]) ; CHECK-NEXT: [[FUSION1:%[^ ]+]] = f32[2,8]{1,0} fusion([[GTE2]]), kind=kLoop, calls=[[FUSED_COMPUTATION1:%[^ ]+]] -; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[FUSION0]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) +; CHECK-NEXT: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}, f32[2,8]{1,0}) tuple([[BITCAST]], [[GTE1_BITCAST]], [[GTE2_BITCAST]], [[FUSION1]]) )"; TestNorm(hlo_text, optimized_hlo); } TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -1083,16 +961,6 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward2D1) { } TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -1198,16 +1066,6 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D3) { } TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -1277,7 +1135,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[6], {{.*}}: f32[6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[6], f32[6]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P0]]), dimensions={0,1,3,2} +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[8,8,6]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[6]{0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} bitcast([[P1]]) @@ -1290,9 +1148,11 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { ; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" ; CHECK: } ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[TRANSPOSE1:%[^ ]+]] = f32[8,6,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[TRANSPOSE1]]) ; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) -; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8,6]{3,2,1,0} transpose([[P3]]), dimensions={0,1,3,2} -; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-NEXT: [[TRANSPOSE2:%[^ ]+]] = f32[8,8,6]{2,1,0} fusion([[P3]]), kind=kLoop, calls{{.*}} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} bitcast([[TRANSPOSE2]]) ; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[64,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 ; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[64,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, f32[1,6,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), @@ -1302,30 +1162,19 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D2) { ; CHECK-DAG: "kind":"LAYER_BWD" ; CHECK: } ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[64,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 -; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] -; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0 -; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1 +; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[8,6,8]{2,1,0} fusion([[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-DAG: [[BITCAST2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION]]) ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 ; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE4]]) ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,6,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 ; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[6]{0} bitcast([[GTE5]]) -; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[6]{0}, f32[6]{0}) tuple([[BITCAST]], [[BITCAST2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) )"; TestNorm(hlo_text, optimized_hlo); } TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -1395,7 +1244,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,6,8], {{.*}}: f32[4,6], {{.*}}: f32[4,6], {{.*}}: f32[2,4,6,8]) -> (f32[2,4,6,8], f32[2,4,6,8], f32[4,6], f32[4,6]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P0]]), dimensions={0,3,1,2} +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,24]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,6]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} bitcast([[P1]]) @@ -1408,9 +1257,11 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { ; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" ; CHECK: } ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[TRANSPOSE1:%[^ ]+]] = f32[2,24,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[BITCAST:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[TRANSPOSE1]]) ; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} parameter(3) -; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[2,8,4,6]{3,2,1,0} transpose([[P3]]), dimensions={0,3,1,2} -; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-NEXT: [[TRANSPOSE2:%[^ ]+]] = f32[2,8,24]{2,1,0} fusion([[P3]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} bitcast([[TRANSPOSE2]]) ; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 ; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, f32[1,4,6,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), @@ -1420,30 +1271,19 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12) { ; CHECK-DAG: "kind":"LAYER_BWD" ; CHECK: } ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 -; CHECK-DAG: [[FUSION:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] -; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=0 -; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} get-tuple-element([[FUSION]]), index=1 +; CHECK-DAG: [[FUSION:%[^ ]+]] = f32[2,24,8]{2,1,0} fusion([[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION:%[^ ]+]] +; CHECK-DAG: [[BITCAST2:%[^ ]+]] = f32[2,4,6,8]{3,2,1,0} bitcast([[FUSION]]) ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 ; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE4]]) ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,6,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 ; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,6]{1,0} bitcast([[GTE5]]) -; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,6,8]{3,2,1,0}, f32[2,4,6,8]{3,2,1,0}, f32[4,6]{1,0}, f32[4,6]{1,0}) tuple([[BITCAST]], [[BITCAST2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) )"; TestNorm(hlo_text, optimized_hlo); } TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -1513,7 +1353,7 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { ; CHECK-LABEL: ENTRY %test ({{.*}}: f32[2,4,1,8], {{.*}}: f32[4,1], {{.*}}: f32[4,1], {{.*}}: f32[2,4,1,8]) -> (f32[2,4,1,8], f32[2,4,1,8], f32[4,1], f32[4,1]) { ; CHECK-NEXT: [[P0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(0) -; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P0]]), dimensions={2,0,3,1} +; CHECK-NEXT: [[TRANSPOSE0:%[^ ]+]] = f32[2,8,4]{2,1,0} fusion([[P0]]), kind=kLoop, calls={{.*}} ; CHECK-NEXT: [[P0_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE0]]) ; CHECK-NEXT: [[P1:%[^ ]+]] = f32[4,1]{1,0} parameter(1) ; CHECK-NEXT: [[P1_BITCAST:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} bitcast([[P1]]) @@ -1526,9 +1366,11 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { ; CHECK-DAG: "kind":"LAYER_FWD_TRAIN" ; CHECK: } ; CHECK-DAG: [[GTE0:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=0 +; CHECK-DAG: [[TRANSPOSE1:%[^ ]+]] = f32[2,4,8]{2,1,0} fusion([[GTE0]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[BITCAST:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} bitcast([[TRANSPOSE1]]) ; CHECK-DAG: [[P3:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} parameter(3) -; CHECK-NEXT: [[TRANSPOSE1:%[^ ]+]] = f32[1,2,8,4]{3,2,1,0} transpose([[P3]]), dimensions={2,0,3,1} -; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE1]]) +; CHECK-NEXT: [[TRANSPOSE2:%[^ ]+]] = f32[2,8,4]{2,1,0} fusion([[P3]]), kind=kLoop, calls={{.*}} +; CHECK-DAG: [[P3_BITCAST:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} bitcast([[TRANSPOSE2]]) ; CHECK-DAG: [[GTE1:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=1 ; CHECK-DAG: [[GTE2:%[^ ]+]] = f32[16,1,1,1]{3,2,1,0} get-tuple-element([[CC0]]), index=2 ; CHECK-NEXT: [[CC1:%[^ ]+]] = (f32[16,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, f32[1,4,1,1]{3,2,1,0}, u8[{{.*}}]{0}) custom-call([[P0_BITCAST]], [[P1_BITCAST]], [[P3_BITCAST]], [[GTE1]], [[GTE2]]), @@ -1538,14 +1380,13 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { ; CHECK-DAG: "kind":"LAYER_BWD" ; CHECK: } ; CHECK-DAG: [[GTE3:%[^ ]+]] = f32[16,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=0 -; CHECK-DAG: [[FUSION0:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}) fusion([[GTE0]], [[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] -; CHECK-DAG: [[GTEF0:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=0 -; CHECK-DAG: [[GTEF1:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} get-tuple-element([[FUSION0]]), index=1 +; CHECK-DAG: [[FUSION0:%[^ ]+]] = f32[2,4,8]{2,1,0} fusion([[GTE3]]), kind=kLoop, calls=[[FUSED_COMPUTATION0:%[^ ]+]] +; CHECK-DAG: [[BITCAST2:%[^ ]+]] = f32[2,4,1,8]{3,2,1,0} bitcast([[FUSION0]]) ; CHECK-DAG: [[GTE4:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=1 ; CHECK-DAG: [[GTE4_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE4]]) ; CHECK-DAG: [[GTE5:%[^ ]+]] = f32[1,4,1,1]{3,2,1,0} get-tuple-element([[CC1]]), index=2 ; CHECK-DAG: [[GTE5_BITCAST:%[^ ]+]] = f32[4,1]{1,0} bitcast([[GTE5]]) -; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[GTEF0]], [[GTEF1]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) +; CHECK-DAG: ROOT [[OUT:%[^ ]+]] = (f32[2,4,1,8]{3,2,1,0}, f32[2,4,1,8]{3,2,1,0}, f32[4,1]{1,0}, f32[4,1]{1,0}) tuple([[BITCAST]], [[BITCAST2]], [[GTE4_BITCAST]], [[GTE5_BITCAST]]) )"; TestNorm(hlo_text, optimized_hlo); @@ -1554,16 +1395,6 @@ TEST_F(CudnnNormRewriterTest, LayerNormTrainBackward4D12Degenerate2) { // TODO(b/343124533) Reenable when fixed TEST_F(CudnnNormRewriterTest, DISABLED_LayerNormTrainBackward4D1DoutputReshapeSplit) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test @@ -1675,16 +1506,6 @@ TEST_F(CudnnNormRewriterTest, // TODO(b/343124533) Reenable when fixed TEST_F(CudnnNormRewriterTest, DISABLED_LayerNormTrainBackward4D1DoutputReshapeCombine) { -#if (CUDA_VERSION < 12000 || CUDNN_VERSION < 8905) - GTEST_SKIP() << "Layer norm kernels require CUDA 12 and cuDNN 8.9.5."; -#endif - if (!(GetCudaComputeCapability().major == - se::CudaComputeCapability::AMPERE) && - !(GetCudaComputeCapability().major == - se::CudaComputeCapability::HOPPER)) { - GTEST_SKIP() - << "Layer norm kernels require Ampere or Hopper architectures."; - } const char* hlo_text = R"( HloModule test diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc index ed83a622c9da7c..2acdf9aa6e2a31 100644 --- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_pad_for_convolutions.h" +#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h similarity index 89% rename from third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h index be7fae26d6cd08..719efecb4bb39f 100644 --- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_ -#define XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -50,4 +50,4 @@ class CudnnPadForConvolutions : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUDNN_PAD_FOR_CONVOLUTIONS_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_PAD_FOR_CONVOLUTIONS_H_ diff --git a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc index 2bae2393581827..7cee2c54f166e7 100644 --- a/third_party/xla/xla/service/gpu/cudnn_pad_for_convolutions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_pad_for_convolutions_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_pad_for_convolutions.h" +#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc index c8f87f7103e3ae..30e0a4b3bef621 100644 --- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_simplify_padding.h" +#include "xla/service/gpu/transforms/cudnn_simplify_padding.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.h b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h similarity index 93% rename from third_party/xla/xla/service/gpu/cudnn_simplify_padding.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h index 5811d26144c4fb..67580b44869d4f 100644 --- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_ -#define XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -64,4 +64,4 @@ class CudnnSimplifyPadding : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_CUDNN_SIMPLIFY_PADDING_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_SIMPLIFY_PADDING_H_ diff --git a/third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc index 4cd9b72ef8ea65..e924ccaed110c7 100644 --- a/third_party/xla/xla/service/gpu/cudnn_simplify_padding_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_simplify_padding_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_simplify_padding.h" +#include "xla/service/gpu/transforms/cudnn_simplify_padding.h" #include #include @@ -27,8 +27,8 @@ limitations under the License. #include "xla/literal.h" #include "xla/service/algebraic_simplifier.h" #include "xla/service/call_inliner.h" -#include "xla/service/gpu/cudnn_pad_for_convolutions.h" -#include "xla/service/gpu/cudnn_vectorize_convolutions.h" +#include "xla/service/gpu/transforms/cudnn_pad_for_convolutions.h" +#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h" #include "xla/service/hlo_pass_fix.h" #include "xla/service/hlo_pass_pipeline.h" #include "xla/service/pattern_matcher.h" @@ -38,8 +38,8 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/dnn.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc index 3846f01136a81f..698b8fb73dd579 100644 --- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_vectorize_convolutions.h" +#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h similarity index 91% rename from third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h rename to third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h index 43165f24c25623..6fd2f6e7445d7f 100644 --- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions.h +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_ -#define XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -70,4 +70,4 @@ class CudnnVectorizeConvolutions : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUDNN_VECTORIZE_CONVOLUTIONS_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUDNN_VECTORIZE_CONVOLUTIONS_H_ diff --git a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc rename to third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc index aa15fc73093fc6..7528870af4c605 100644 --- a/third_party/xla/xla/service/gpu/cudnn_vectorize_convolutions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/cudnn_vectorize_convolutions_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cudnn_vectorize_convolutions.h" +#include "xla/service/gpu/transforms/cudnn_vectorize_convolutions.h" #include #include diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc index 814ccf05003804..af9591a16c67ad 100644 --- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/custom_kernel_fusion_rewriter.h" +#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h similarity index 93% rename from third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h index cb19d91a3dd571..849fdbbc4a10d8 100644 --- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_ -#define XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -83,4 +83,4 @@ class CustomKernelFusionRewriter : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_CUSTOM_KERNEL_FUSION_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_CUSTOM_KERNEL_FUSION_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc index f2c824cac7e1f7..235e9ded150bfd 100644 --- a/third_party/xla/xla/service/gpu/custom_kernel_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/custom_kernel_fusion_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/custom_kernel_fusion_rewriter.h" +#include "xla/service/gpu/transforms/custom_kernel_fusion_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc similarity index 98% rename from third_party/xla/xla/service/gpu/dot_dimension_sorter.cc rename to third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc index 38920ee2abdbb7..b1e0b98c319340 100644 --- a/third_party/xla/xla/service/gpu/dot_dimension_sorter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dot_dimension_sorter.h" +#include "xla/service/gpu/transforms/dot_dimension_sorter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter.h b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h similarity index 90% rename from third_party/xla/xla/service/gpu/dot_dimension_sorter.h rename to third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h index 5eadeb14ceb50a..872fb725b53004 100644 --- a/third_party/xla/xla/service/gpu/dot_dimension_sorter.h +++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_ -#define XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -49,4 +49,4 @@ class DotDimensionSorter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_DOT_DIMENSION_SORTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_DIMENSION_SORTER_H_ diff --git a/third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc rename to third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc index fedd1eae6b65c5..364c1405f09267 100644 --- a/third_party/xla/xla/service/gpu/dot_dimension_sorter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_dimension_sorter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dot_dimension_sorter.h" +#include "xla/service/gpu/transforms/dot_dimension_sorter.h" #include diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc similarity index 97% rename from third_party/xla/xla/service/gpu/dot_operand_converter.cc rename to third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc index 2a298e67eaf70e..d9e095e2c57ce0 100644 --- a/third_party/xla/xla/service/gpu/dot_operand_converter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dot_operand_converter.h" +#include "xla/service/gpu/transforms/dot_operand_converter.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter.h b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h similarity index 88% rename from third_party/xla/xla/service/gpu/dot_operand_converter.h rename to third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h index d277a24100c0b3..b269bed8b6a6f3 100644 --- a/third_party/xla/xla/service/gpu/dot_operand_converter.h +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ -#define XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_ #include @@ -43,4 +43,4 @@ class DotOperandConverter : public OpExpanderPass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_DOT_OPERAND_CONVERTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_OPERAND_CONVERTER_H_ diff --git a/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/dot_operand_converter_test.cc rename to third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc index 63b0017012f419..be05b6767abbfd 100644 --- a/third_party/xla/xla/service/gpu/dot_operand_converter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_operand_converter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dot_operand_converter.h" +#include "xla/service/gpu/transforms/dot_operand_converter.h" #include diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc similarity index 98% rename from third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc index 0f410916039242..637689a8063a93 100644 --- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dot_sparsity_rewriter.h" +#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h" #include diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h similarity index 87% rename from third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h index b4221978b74f71..b912e2bd323d48 100644 --- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_ -#define XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -39,4 +39,4 @@ class DotSparsityRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_DOT_SPARSITY_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_DOT_SPARSITY_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc index c608f8d814410a..28f813fb5be4c4 100644 --- a/third_party/xla/xla/service/gpu/dot_sparsity_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dot_sparsity_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dot_sparsity_rewriter.h" +#include "xla/service/gpu/transforms/dot_sparsity_rewriter.h" #include diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc similarity index 99% rename from third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc rename to third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc index 9cd9113c176f83..a4f901f09a286b 100644 --- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.cc +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/double_buffer_loop_unrolling.h" +#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h" #include #include diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h similarity index 93% rename from third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h rename to third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h index 120070dbccd452..26bb178db155bf 100644 --- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling.h +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_ -#define XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -72,4 +72,4 @@ class DoubleBufferLoopUnrolling : public HloModulePass { } // end namespace gpu } // end namespace xla -#endif // XLA_SERVICE_GPU_DOUBLE_BUFFER_LOOP_UNROLLING_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_DOUBLE_BUFFER_LOOP_UNROLLING_H_ diff --git a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc rename to third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc index 8fed3192b08598..05e704d6935f07 100644 --- a/third_party/xla/xla/service/gpu/double_buffer_loop_unrolling_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/double_buffer_loop_unrolling_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/double_buffer_loop_unrolling.h" +#include "xla/service/gpu/transforms/double_buffer_loop_unrolling.h" #include #include @@ -853,10 +853,10 @@ ENTRY main { VLOG(0) << module->ToString(); EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body {{.+}} { - // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"} + // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}} // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"} + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}} // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) // CHECK: } // CHECK: ENTRY %main {{.+}} { @@ -907,13 +907,13 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body - // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6},{3,6}{{[}]}}"} + // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6},{3,6}{{[}]}}} // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]) - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6}{{[}]}}"} + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{0,3},{1,4},{1,4},{2,5},{2,5},{3,6}{{[}]}}} // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) // CHECK: ENTRY %main {{.+}} { - // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}"} + // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0}{{[}]}}} // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}}) // CHECK: %[[while:.+]] = {{.+}} while({{.+}} %[[out_peeled]]) // CHECK: } @@ -961,10 +961,10 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body - // CHECK: %[[cp1:.+]] = f32[] collective-permute(f32[] %param_0), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{4,6},{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3}{{[}]}}"} + // CHECK: %[[cp1:.+]] = f32[] collective-permute(f32[] %param_0), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{4,6},{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3}{{[}]}}} // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3},{0,2}{{[}]}}"} + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,5},{2,5},{2,4},{1,4},{1,3},{0,3},{0,2}{{[}]}}} // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) // CHECK: ENTRY %main // CHECK-NOT: collective-permute @@ -1013,14 +1013,14 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body - // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3}{{[}]}}"} + // CHECK: %[[cp1:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3}{{[}]}}} // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0 - // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3},{0,2}{{[}]}}"} + // CHECK: %[[cp2:.+]] = {{.+}} collective-permute({{.+}} %[[param2]]), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{3,6},{2,5},{2,5},{1,4},{1,4},{0,3},{0,3},{0,2}{{[}]}}} // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) // CHECK: } // CHECK: ENTRY %main - // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}"} + // CHECK: %[[cp_peeled:.+]] = {{.+}} collective-permute({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{1,0},{0,0}{{[}]}}} // CHECK: %[[out_peeled:.+]] = {{.+}} tuple({{.+}} %[[cp_peeled]], {{.+}}) // CHECK: ROOT {{.+}} = {{.+}} while({{.+}} %[[out_peeled]]) // CHECK: } @@ -1069,11 +1069,11 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body - // CHECK: %[[cp_start1:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}"} + // CHECK: %[[cp_start1:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}} // CHECK: %[[cp1:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start1]]) // CHECK: %[[out1:.+]] = {{.+}} tuple({{.+}} %[[cp1]], {{.+}}) // CHECK: %[[param2:.+]] = {{.+}} get-tuple-element({{.+}} %[[out1]]), index=0 - // CHECK: %[[cp_start2:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}"} + // CHECK: %[[cp_start2:.+]] = {{.+}} collective-permute-start({{.+}}), {{.+}}, frontend_attributes={_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}} // CHECK: %[[cp2:.+]] = {{.+}} collective-permute-done({{.+}} %[[cp_start2]]) // CHECK: ROOT {{.+}} = {{.+}} tuple({{.+}} %[[cp2]], {{.+}}) // CHECK: } @@ -1133,8 +1133,8 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body - // CHECK: %[[recv1:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}" - // CHECK: %[[recv2:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}" + // CHECK: %[[recv1:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}} + // CHECK: %[[recv2:.+]] = {{.+}} recv({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}} // CHECK: ENTRY %main // CHECK-NOT: recv // CHECK: } @@ -1190,8 +1190,8 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: %body - // CHECK: %[[send1:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}}" - // CHECK: %[[send2:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation="{{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}}" + // CHECK: %[[send1:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,3},{1,3},{1,4},{2,4}{{[}]}} + // CHECK: %[[send2:.+]] = {{.+}} send({{.+}}), {{.+}},_xla_send_recv_validation={{[{]}}{0,2},{0,3},{1,3},{1,4}{{[}]}} // CHECK: ENTRY %main // CHECK-NOT: send // CHECK: } diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc similarity index 78% rename from third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc index 09192416db5217..1eadef692a6839 100644 --- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h" +#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h" #include #include @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/gpu/gpu_constants.h" #include "xla/service/gpu/hlo_traversal.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" @@ -56,6 +57,8 @@ namespace gpu { namespace { +namespace m = ::xla::match; + // A dataflow path flowing from a definition to a user. using DefUseDataflowPath = absl::InlinedVector; @@ -149,6 +152,98 @@ bool IsAlignedSlice(const HloInstruction* slice) { return true; } +// Pattern matches the following IR (generated by `jax.lax.scan`) to check if +// the offset is a loop iteration number: + +// clang-format off +// param = (s32[], s32[], s32[16]{0}, s32[16]{0}) parameter(0) +// // the index in `gte` has to be the loop iteration index +// gte = s32[] get-tuple-element(param), index=0 +// c0 = s32[] constant(0) compare = pred[] compare(gte, c0), direction=LT +// c_trip_count = s32[] constant(16) +// add = s32[] add(gte, c_trip_count) select = s32[] select(compare, add, gte) +// clang-format on + +bool IsLoopIterationNumber(const HloInstruction& offset) { + const HloComputation* parent = offset.parent(); + if (!parent->IsWhileBodyComputation()) return false; + + // Scan loops trip count must be known at compile time as it iterates over the + // leading dimension of the statically shaped input. + const HloInstruction* while_instr = parent->WhileCallInstruction(); + auto config = while_instr->backend_config(); + if (!config.ok() || !config->has_known_trip_count()) return false; + int32_t trip_count = config->known_trip_count().n(); + + // First lets check the offset computation pattern + if (!Match(&offset, m::Select(m::Lt(m::GetTupleElement(m::Parameter(0)), + m::ConstantScalar(0)), + m::Add(m::GetTupleElement(m::Parameter(0)), + m::ConstantScalar(trip_count)), + m::GetTupleElement(m::Parameter())))) { + return false; + } + + // Next, we check that the parameter used in offset computation is the loop + // induction variable + int64_t param_idx = offset.operand(2)->tuple_index(); + const HloInstruction* root = offset.parent()->root_instruction(); + if (root->opcode() != HloOpcode::kTuple) { + return false; + } + // Check the update operation + const HloInstruction* updated_var = + offset.parent()->root_instruction()->operand(param_idx); + if (!Match(updated_var, m::Add(m::GetTupleElement(m::Parameter(0), param_idx), + m::ConstantScalar(1)))) { + return false; + } + // Check that the condition considers this. + const HloInstruction* condition_root = + while_instr->while_condition()->root_instruction(); + if (!Match(condition_root, + m::Lt(m::GetTupleElement(m::Parameter(0), param_idx), + m::ConstantScalar(trip_count)))) { + return false; + } + // Check init + const HloInstruction* init_loop_iter = + while_instr->operand(0)->operand(param_idx); + if (!Match(init_loop_iter, m::ConstantScalar(0))) { + return false; + } + + return true; +} + +// This returns true for the constants that are handled in the dynamic slice +// fusion runtime. These constants do not force a D2H copy and hence preserve +// the cuda graph. +bool IsHandledConstantForDynamicSliceFusion(const HloInstruction& offset) { + if (auto* cst = DynCast(&offset)) { + switch (cst->shape().element_type()) { + case PrimitiveType::S32: + case PrimitiveType::S64: + case PrimitiveType::U32: + case PrimitiveType::U64: + return true; + default: + return false; + }; + } + return false; +} + +// This checks whether a dynamic index operation has all offsets that are either +// constant or loop iteration offsets. +bool HasConstantOrLoopIterationOffsets( + const HloDynamicIndexInstruction& instr) { + return llvm::all_of(instr.index_operands(), [](const HloInstruction* offset) { + return IsLoopIterationNumber(*offset) || + IsHandledConstantForDynamicSliceFusion(*offset); + }); +} + UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { UseDefDataflowPaths sliced_operand_paths; @@ -193,8 +288,15 @@ UseDefDataflowPaths GetSlicedOperandPaths(const HloInstruction* instr) { }); if (maybe_slice_instr == std::nullopt) continue; - - if (slice_found || processed_instrs.contains(maybe_slice_instr.value())) { + auto dynamic_index_operation = + DynCast(maybe_slice_instr.value()); + bool valid_slice_found = + slice_found && + ((dynamic_index_operation && + HasConstantOrLoopIterationOffsets(*dynamic_index_operation)) || + (*maybe_slice_instr)->opcode() == HloOpcode::kSlice); + if (valid_slice_found || + processed_instrs.contains(maybe_slice_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced operand path // during the latest traversal. @@ -241,7 +343,12 @@ DefUseDataflowPaths GetSlicedUserPaths(const HloInstruction* instr) { }, /*visit_operands=*/false); if (maybe_dus_instr == std::nullopt) return; - if (dus_found || processed_instrs.contains(maybe_dus_instr.value())) { + auto dynamic_index_operation = + DynCast(maybe_dus_instr.value()); + bool valid_dus_found = + dus_found && dynamic_index_operation && + HasConstantOrLoopIterationOffsets(*dynamic_index_operation); + if (valid_dus_found || processed_instrs.contains(maybe_dus_instr.value())) { // Even in the case of stopping at a match that has been processed, we // still need to add instructions encountered in the sliced user path // during the latest traversal. @@ -405,17 +512,21 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( const absl::flat_hash_set& execution_threads) { absl::flat_hash_map> - matches; + matches_kv; + std::vector matches; // Collect all potential custom call matches in the non-fusion computations. for (HloComputation* computation : module->computations()) { if (computation->IsFusionComputation()) continue; for (HloInstruction* instr : computation->instructions()) { - if (IsLegacyCublasMatmul(*instr) || - (IsCustomCall(instr, platform_name_))) { - UseDefDataflowPaths sliced_operand_paths = GetSlicedOperandPaths(instr); - bool has_sliced_operand_paths = sliced_operand_paths.size() > 1; - + UseDefDataflowPaths sliced_operand_paths = {instr}; + bool has_sliced_operand_paths = false; + if (IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) { + sliced_operand_paths = GetSlicedOperandPaths(instr); + has_sliced_operand_paths = sliced_operand_paths.size() > 1; + } + if (instr->opcode() == HloOpcode::kReduceScatter || + IsLegacyCublasMatmul(*instr) || IsCustomCall(instr, platform_name_)) { DefUseDataflowPaths sliced_user_paths = GetSlicedUserPaths(instr); bool has_sliced_user_paths = absl::c_any_of( sliced_user_paths, @@ -430,8 +541,9 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( } if (has_sliced_operand_paths || has_sliced_user_paths) { - matches[instr] = std::make_pair(std::move(sliced_operand_paths), - std::move(sliced_user_paths)); + matches_kv[instr] = std::make_pair(std::move(sliced_operand_paths), + std::move(sliced_user_paths)); + matches.push_back(instr); } } } @@ -439,7 +551,8 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( if (matches.empty()) return false; - for (auto& [hero, paths] : matches) { + for (HloInstruction* hero : matches) { + auto& paths = matches_kv[hero]; auto& [sliced_operand_paths, sliced_user_paths] = paths; std::vector matched_instrs; absl::c_copy(sliced_operand_paths, std::back_inserter(matched_instrs)); @@ -503,6 +616,10 @@ absl::StatusOr DynamicSliceFusionRewriter::Run( } TF_RETURN_IF_ERROR( parent->ReplaceInstruction(instr_to_be_replaced, fusion)); + // This is required for collective operations which will not be removed. + if (hero->parent()) { + TF_RETURN_IF_ERROR(hero->parent()->RemoveInstruction(hero)); + } } } diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h similarity index 93% rename from third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h index 15da28410f1382..ad996deb60995e 100644 --- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_ -#define XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_ #include #include @@ -88,4 +88,4 @@ class DynamicSliceFusionRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_DYNAMIC_SLICE_FUSION_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_DYNAMIC_SLICE_FUSION_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc similarity index 84% rename from third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc index 3d3eef1e4a3687..2bd7168adfc06c 100644 --- a/third_party/xla/xla/service/gpu/dynamic_slice_fusion_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/dynamic_slice_fusion_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/dynamic_slice_fusion_rewriter.h" +#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h" #include #include @@ -942,7 +942,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCall) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -990,7 +990,7 @@ TEST_F(DynamicSliceFusionRewriterTest, SimpleCustomCallLegacy) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1050,7 +1050,7 @@ TEST_F(DynamicSliceFusionRewriterTest, TupleSliceCustomCallLegacy) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1122,7 +1122,7 @@ TEST_F(DynamicSliceFusionRewriterTest, TupledOutputCustomCallLegacy) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1183,7 +1183,7 @@ TEST_F(DynamicSliceFusionRewriterTest, UnalignedSlice) { xla::ProgramShape(computation.proto().host_program_shape()), /*ignore_layouts=*/false); DebugOptions debug_options = GetDebugOptionsForTest(); - debug_options.set_xla_gpu_enable_address_computation_fusion(false); + debug_options.set_xla_gpu_enable_dynamic_slice_fusion(false); hlo_config.set_debug_options(debug_options); TF_ASSERT_OK_AND_ASSIGN(auto hlo, xla::HloModule::CreateFromProto( computation.proto(), hlo_config)); @@ -1752,4 +1752,313 @@ TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmWorkspaceIgnored) { RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); } +TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSConstantOffset) { + const char* hlo = R"( + HloModule test, replica_count=2 + + add { + param_0 = f16[] parameter(0) + param_1 = f16[] parameter(1) + ROOT add.1 = f16[] add(param_0, param_1) + } + + ENTRY main.9 { + param_0 = f16[128,128]{1,0} parameter(0) + param_1 = f16[128,128]{1,0} parameter(1) + constant_20 = u32[] constant(20) + constant_0 = u32[] constant(0) + reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add + ROOT loop_dynamic_update_slice_fusion = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, constant_20, constant_0) + } + )"; + + const char* expected = R"( + // CHECK: %address-computation{{.+}} { + // CHECK: %[[RS:.+]] = f16[64,128]{1,0} reduce-scatter({{.+}}) + // CHECK: ROOT %{{.+}} = f16[128,128]{1,0} dynamic-update-slice(%{{.+}}, %[[RS]], %{{.+}}, %{{.+}}) + // CHECK: } + // CHECK: ENTRY {{.+}} { + // CHECK-NOT: reduce-scatter + // CHECK: ROOT %{{.+}} = {{.+}} fusion(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}), kind=kCustom, calls=%address-computation, {{.+}}"name":"dynamic_address_computation" + // CHECK: } + )"; + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); +} + +TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSParameterOffset) { + const char* hlo = R"( + HloModule test, replica_count=2 + + add.clone { + x.1 = f16[] parameter(0) + y.1 = f16[] parameter(1) + ROOT add.462 = f16[] add(x.1, y.1) + } + + ENTRY %main.9 { + param_0 = f16[128,128]{1,0} parameter(0) + param_1 = f16[128,128]{1,0} parameter(1) + param_2 = u32[] parameter(2) + constant_0 = u32[] constant(0) + reduce-scatter = f16[64,128]{1,0} reduce-scatter(param_0), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add.clone + ROOT dynamic-update-slice = f16[128,128]{1,0} dynamic-update-slice(param_1, reduce-scatter, param_2, constant_0) + })"; + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), + std::nullopt); +} + +TEST_F(DynamicSliceFusionRewriterTest, ReduceScatterDUSLoopIterationOffset) { + const char* hlo = R"( + HloModule jit_scan, replica_count=2 + + add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT add.6 = f32[] add(param_0, param_1) + } + + Body { + arg_tuple.1 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + get-tuple-element.5 = s32[] get-tuple-element(arg_tuple.1), index=0 + constant.1 = s32[] constant(1) + add.7 = s32[] add(get-tuple-element.5, constant.1) + get-tuple-element.6 = f32[128,128]{1,0} get-tuple-element(arg_tuple.1), index=3 + get-tuple-element.7 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.1), index=2 + reduce-scatter.0 = f32[64,128]{1,0} reduce-scatter(get-tuple-element.6), channel_id=64, replica_groups={{0,1}}, use_global_device_ids=true, dimensions={0}, to_apply=add + bitcast.63 = f32[1,64,128]{2,1,0} bitcast(reduce-scatter.0) + constant.2 = s32[] constant(0) + compare.4 = pred[] compare(get-tuple-element.5, constant.2), direction=LT + constant.3 = s32[] constant(128) + add.8 = s32[] add(get-tuple-element.5, constant.3) + select.2 = s32[] select(compare.4, add.8, get-tuple-element.5) + dynamic-update-slice.2 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.7, bitcast.63, select.2, constant.2, constant.2) + ROOT tuple.1 = tuple(add.7, get-tuple-element.6, dynamic-update-slice.2, get-tuple-element.6) + } // Body + + Cond { + arg_tuple.0 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(arg_tuple.0), index=0 + constant.0 = s32[] constant(128) + ROOT compare.5 = pred[] compare(get-tuple-element.4, constant.0), direction=LT + } + + ENTRY main.55 { + Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2) + constant.4 = s32[] constant(0) + Arg_1.2 = f32[128,128]{1,0} parameter(1) + constant.5 = f32[] constant(0) + broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={} + Arg_0.1 = f32[128,128]{1,0} parameter(0) + tuple = tuple(constant.4, Arg_1.2, broadcast.1, Arg_0.1) + while = while(tuple), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}} + get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while), index=1 + get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while), index=2 + ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51) + })"; + const char* expected = R"( + // CHECK: %address-computation{{.*}}{ + // CHECK: {{.+}} = {{.*}}reduce-scatter({{.+}}) + // CHECK: {{.+}} = {{.*}}dynamic-update-slice({{.+}}) + // CHECK: } + // CHECK: Body{{.+}}{ + // CHECK-NOT: {{.+}} = {{.*}}reduce-scatter({{.+}}) + // CHECK: {{.+}} = {{.+}}fusion({{.+}}), kind=kCustom, calls=%address-computation{{.*}}"name":"dynamic_address_computation" + // CHECK: } + )"; + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); +} + +TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLoopIteration) { + const char* hlo = R"( + HloModule test + + %Body { + param = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0) + p0 = get-tuple-element(param), index=0 + p1 = get-tuple-element(param), index=1 + p2 = get-tuple-element(param), index=2 + loop_iter = get-tuple-element(param), index=3 + + bitcast.41 = f16[8,8]{1,0} bitcast(p0) + bitcast.42 = f16[8,8]{1,0} bitcast(p1) + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), custom_call_target="__cublas$gemm", backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1) + c0 = u32[] constant(0) + c_trip_count = u32[] constant(11) + compare = pred[] compare(loop_iter, c0), direction=LT + add = u32[] add(loop_iter, c_trip_count) + offset = u32[] select(compare, add, loop_iter) + dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, offset, c0, c0) + c1 = u32[] constant(1) + add2 = u32[] add(loop_iter, c1) + ROOT tuple = tuple(p0, p1, dus, u32[] add2) + } + + %Cond { + %param.1 = (f16[1,8,8]{2,1,0}, f16[1,8,8]{2,1,0}, f16[4,8,8]{2,1,0}, u32[]) parameter(0) + %i.1 = u32[] get-tuple-element(%param.1), index=3 + %trip_count = u32[] constant(11) + ROOT %done = pred[] compare(u32[] %i.1, u32[] %trip_count), direction=LT + } + + ENTRY %test { + %p0.1 = f16[1,8,8]{2,1,0} parameter(0) + %p1.1 = f16[1,8,8]{2,1,0} parameter(1) + %p2.1 = f16[4,8,8]{2,1,0} parameter(2) + %c0.1 = u32[] constant(0) + %initial_tuple = tuple(%p0.1, %p1.1, %p2.1, u32[] %c0.1) + ROOT %while = while(%initial_tuple), condition=%Cond, body=%Body, backend_config={"known_trip_count":{"n":"11"}} + })"; + + const char* expected = R"( + // CHECK: %Body{{.+}}{ + // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0) + // CHECK: %[[LOOP_ITER:.+]] = u32[] get-tuple-element(%[[PARAM]]), index=3 + // CHECK: %[[OFFSET:.+]] = u32[] select({{.+}}) + // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, {{.+}}, {{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%address-computation, {{.+}}"name":"dynamic_address_computation" + // CHECK: ROOT %tuple = {{.+}} tuple(%{{.+}}, %{{.+}}, %[[ADDRESS_COMPUTATION]], %{{.+}}) + // CHECK: } + // CHECK: ENTRY %test{{.+}}{ + // CHECK: ROOT %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"11"}} + } + )"; + + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); +} + +TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmParameterOffset) { + const char* hlo = R"( + HloModule test + + ENTRY main.9 { + p0 = f16[1,8,8]{2,1,0} parameter(0) + p1 = f16[1,8,8]{2,1,0} parameter(1) + p2 = f16[4,8,8]{2,1,0} parameter(2) + p3 = s32[] parameter(3) + c1_s32 = s32[] constant(1) + c0_s32 = s32[] constant(0) + bitcast.41 = f16[8,8]{1,0} bitcast(p0) + bitcast.42 = f16[8,8]{1,0} bitcast(p1) + + custom-call.1 = f16[8,8]{1,0} custom-call(bitcast.41, bitcast.42), + custom_call_target="__cublas$gemm", + backend_config={"gemm_backend_config":{ + "alpha_real":1, + "beta":0, + "dot_dimension_numbers":{ + "lhs_contracting_dimensions":["1"], + "rhs_contracting_dimensions":["0"], + "lhs_batch_dimensions":[], + "rhs_batch_dimensions":[] + }, + "alpha_imag":0, + "precision_config":{"operand_precision":["DEFAULT","DEFAULT"]}, + "epilogue":"DEFAULT", + "lhs_stride":"64", + "rhs_stride":"64", + "grad_x":false, + "grad_y":false + }} + bitcast.43 = f16[1,8,8]{2,1,0} bitcast(custom-call.1) + ROOT dus = f16[4,8,8]{2,1,0} dynamic-update-slice(p2, bitcast.43, p3, c0_s32, c0_s32) + })"; + + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), + std::nullopt); +} + +TEST_F(DynamicSliceFusionRewriterTest, DUSSimpleGemmLaxScan) { + const char* hlo = R"( + HloModule lax_scan + + // This is the HLO generated for the following: + // + // inp = jax.random.uniform(jax.random.key(128), (128, 128, 128)) + // init = jnp.identity(128) + // ans = jax.lax.scan(lambda carry, x : (init, x@carry), init, inp) + + Body { + arg_tuple.15 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + get-tuple-element.16 = s32[] get-tuple-element(arg_tuple.15), index=0 + constant.21 = s32[] constant(1) + add.2 = s32[] add(get-tuple-element.16, constant.21) + get-tuple-element.30 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=4 + get-tuple-element.18 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=2 + get-tuple-element.19 = f32[128,128,128]{2,1,0} get-tuple-element(arg_tuple.15), index=3 + constant.23 = s32[] constant(0) + compare.2 = pred[] compare(get-tuple-element.16, constant.23), direction=LT + constant.22 = s32[] constant(128) + add.3 = s32[] add(get-tuple-element.16, constant.22) + select.1 = s32[] select(compare.2, add.3, get-tuple-element.16) + dynamic-slice.1 = f32[1,128,128]{2,1,0} dynamic-slice(get-tuple-element.19, select.1, constant.23, constant.23), dynamic_slice_sizes={1,128,128} + bitcast.72 = f32[128,128]{1,0} bitcast(dynamic-slice.1) + get-tuple-element.17 = f32[128,128]{1,0} get-tuple-element(arg_tuple.15), index=1 + custom-call.1 = (f32[128,128]{1,0}, s8[131072]{0}) custom-call(bitcast.72, get-tuple-element.17), custom_call_target="__cublas$gemm" + get-tuple-element = f32[128,128]{1,0} get-tuple-element(custom-call.1), index=0 + bitcast.77 = f32[1,128,128]{2,1,0} bitcast(get-tuple-element) + dynamic-update-slice.1 = f32[128,128,128]{2,1,0} dynamic-update-slice(get-tuple-element.18, bitcast.77, select.1, constant.23, constant.23) + ROOT tuple.38 = tuple(add.2, get-tuple-element.30, dynamic-update-slice.1, get-tuple-element.19, get-tuple-element.30) + } // Body + + Cond { + arg_tuple.40 = (s32[], f32[128,128]{1,0}, f32[128,128,128]{2,1,0}, f32[128,128,128]{2,1,0}, f32[128,128]{1,0}) parameter(0) + get-tuple-element.41 = s32[] get-tuple-element(arg_tuple.40), index=0 + constant.46 = s32[] constant(128) + ROOT compare.3 = pred[] compare(get-tuple-element.41, constant.46), direction=LT + } + + ENTRY main { + constant.4 = s32[] constant(0) + Arg_1.2 = f32[128,128]{1,0} parameter(1) + constant.5 = f32[] constant(0) + broadcast.1 = f32[128,128,128]{2,1,0} broadcast(constant.5), dimensions={} + Arg_2.3 = f32[128,128,128]{2,1,0} parameter(2) + Arg_0.1 = f32[128,128]{1,0} parameter(0) + tuple.7 = tuple(constant.4, Arg_1.2, broadcast.1, Arg_2.3, Arg_0.1) + while.48 = while(tuple.7), condition=Cond, body=Body, backend_config={"known_trip_count":{"n":"128"}} + get-tuple-element.50 = f32[128,128]{1,0} get-tuple-element(while.48), index=1 + get-tuple-element.51 = f32[128,128,128]{2,1,0} get-tuple-element(while.48), index=2 + ROOT tuple.54 = (f32[128,128]{1,0}, f32[128,128,128]{2,1,0}) tuple(get-tuple-element.50, get-tuple-element.51) + } // main.55 + +)"; + auto device = TestGpuDeviceInfo::RTXA6000DeviceInfo(); + const char* expected = R"( + // CHECK: %address-computation{{.*}} {{.+}} { + // CHECK: {{.+}} = {{.+}}dynamic-slice + // CHECK: {{.+}} = {{.+}}custom-call + // CHECK: {{.+}} = {{.+}}dynamic-update-slice + // CHECK: } + // CHECK: %Body{{.+}}{ + // CHECK: %[[PARAM:.+]] = {{.+}} parameter(0) + // CHECK: %[[LOOP_ITER:.+]] = s32[] get-tuple-element(%[[PARAM]]), index=0 + // CHECK: %[[OFFSET:.+]] = s32[] select({{.+}}) + // CHECK: %[[ADDRESS_COMPUTATION:.+]] = {{.+}} fusion({{.+}}, %[[OFFSET]], %{{.+}}), kind=kCustom, calls=%address-computation{{.+}}"name":"dynamic_address_computation" + // CHECK: %[[GTE:.+]] = {{.+}} get-tuple-element(%[[ADDRESS_COMPUTATION]]), index=0 + // CHECK: ROOT %{{.+}} = {{.+}} tuple(%{{.+}}, %[[GTE]], %{{.+}}) + // CHECK: } + // CHECK: ENTRY %main{{.+}}{ + // CHECK: %{{.+}} = {{.+}} while(%{{.+}}), condition=%{{.+}}, body=%Body{{.*}}, backend_config={"known_trip_count":{"n":"128"}} + // CHECK: } + )"; + RunAndFilecheckHloRewrite(hlo, DynamicSliceFusionRewriter("gpu"), expected); +} + } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/fusion_merger.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusion_merger.cc rename to third_party/xla/xla/service/gpu/transforms/fusion_merger.cc index 3703c98da4d911..5a09bf5359c86e 100644 --- a/third_party/xla/xla/service/gpu/fusion_merger.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusion_merger.h" +#include "xla/service/gpu/transforms/fusion_merger.h" #include #include diff --git a/third_party/xla/xla/service/gpu/fusion_merger.h b/third_party/xla/xla/service/gpu/transforms/fusion_merger.h similarity index 95% rename from third_party/xla/xla/service/gpu/fusion_merger.h rename to third_party/xla/xla/service/gpu/transforms/fusion_merger.h index acbc93e7781fbb..15ea9600de9273 100644 --- a/third_party/xla/xla/service/gpu/fusion_merger.h +++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSION_MERGER_H_ -#define XLA_SERVICE_GPU_FUSION_MERGER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -82,4 +82,4 @@ class FusionMerger : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSION_MERGER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_FUSION_MERGER_H_ diff --git a/third_party/xla/xla/service/gpu/fusion_merger_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc similarity index 96% rename from third_party/xla/xla/service/gpu/fusion_merger_test.cc rename to third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc index de45a4b9d3273e..5068f65a49b867 100644 --- a/third_party/xla/xla/service/gpu/fusion_merger_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusion_merger.h" +#include "xla/service/gpu/transforms/fusion_merger.h" #include #include @@ -135,42 +135,42 @@ f32add { } comp0 { - p = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(0) - gte0 = f32[100000000] get-tuple-element(p), index=0 - gte1 = f32[100000000] get-tuple-element(p), index=1 - add.9 = f32[100000000] add(gte0, gte1) - gte2 = f32[100000000] get-tuple-element(p), index=2 - add.10 = f32[100000000] add(add.9, gte2) - gte3 = f32[100000000] get-tuple-element(p), index=3 - add.11 = f32[100000000] add(add.10, gte3) - p1 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(1) - gte4 = f32[100000000] get-tuple-element(p1), index=0 - gte5 = f32[100000000] get-tuple-element(p1), index=1 - add.12 = f32[100000000] add(gte4, gte5) - gte6 = f32[100000000] get-tuple-element(p1), index=2 - add.13 = f32[100000000] add(add.12, gte6) - gte7 = f32[100000000] get-tuple-element(p1), index=3 - add.14 = f32[100000000] add(add.13, gte7) - ROOT r = f32[100000000] add(add.14, add.11) + p = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(0) + gte0 = f32[2048] get-tuple-element(p), index=0 + gte1 = f32[2048] get-tuple-element(p), index=1 + add.9 = f32[2048] add(gte0, gte1) + gte2 = f32[2048] get-tuple-element(p), index=2 + add.10 = f32[2048] add(add.9, gte2) + gte3 = f32[2048] get-tuple-element(p), index=3 + add.11 = f32[2048] add(add.10, gte3) + p1 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(1) + gte4 = f32[2048] get-tuple-element(p1), index=0 + gte5 = f32[2048] get-tuple-element(p1), index=1 + add.12 = f32[2048] add(gte4, gte5) + gte6 = f32[2048] get-tuple-element(p1), index=2 + add.13 = f32[2048] add(add.12, gte6) + gte7 = f32[2048] get-tuple-element(p1), index=3 + add.14 = f32[2048] add(add.13, gte7) + ROOT r = f32[2048] add(add.14, add.11) } comp1 { - p = f32[100000000] parameter(0) + p = f32[2048] parameter(0) c0 = f32[] constant(0) ROOT r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add } comp2 { - p = f32[100000000] parameter(0) + p = f32[2048] parameter(0) c0 = f32[] constant(0) r = f32[] reduce(p, c0), dimensions={0}, to_apply=f32add ROOT n = f32[] negate(r) } ENTRY m.Computation2 { - p0 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(0) - p1 = (f32[100000000], f32[100000000], f32[100000000], f32[100000000]) parameter(1) - fusion.0 = f32[100000000] fusion(p0, p1), kind=kLoop, calls=comp0 + p0 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(0) + p1 = (f32[2048], f32[2048], f32[2048], f32[2048]) parameter(1) + fusion.0 = f32[2048] fusion(p0, p1), kind=kLoop, calls=comp0 fusion.1 = f32[] fusion(fusion.0), kind=kLoop, calls=comp1 fusion.2 = f32[] fusion(fusion.0), kind=kLoop, calls=comp2 ROOT tuple = (f32[], f32[]) tuple(fusion.1, fusion.2) @@ -362,14 +362,14 @@ TEST_F(FusionMergerTest, WillMergeReduceNotTooUnfriendlyLayouts) { f2_computation { f2_p0 = f32[16,16,256]{2,1,0} parameter(0) f2_zero = f32[] constant(0) - ROOT f2_root = f32[] reduce(f2_p0, f2_zero), dimensions={0,1,2}, + ROOT f2_root = f32[16,16] reduce(f2_p0, f2_zero), dimensions={2}, to_apply=add_computation } ENTRY entry { p0 = f32[16,16,256]{0,1,2} parameter(0) f1 = f32[16,16,256]{2,1,0} fusion(p0), kind=kLoop, calls=f1_computation - ROOT f2 = f32[] fusion(f1), kind=kInput, calls=f2_computation + ROOT f2 = f32[16,16] fusion(f1), kind=kInput, calls=f2_computation })") .value(); EXPECT_TRUE(fusion_merger_.Run(module.get()).value()); @@ -685,6 +685,12 @@ ENTRY entry { } )") .value(); + auto& debug_options = module->mutable_config().mutable_debug_options(); + // For some reason, we would not merge any fusions when using the MLIR + // reduction emitter. The cost model queries the reduction emitter regarding + // the launch dimensions, so it seems likely that it is caused by different + // launch dimensions. + debug_options.set_xla_gpu_mlir_emitter_level(3); EXPECT_TRUE(fusion_merger_.Run(module.get()).value()); } @@ -995,6 +1001,8 @@ ENTRY e { } )") .value(); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); EXPECT_FALSE(fusion_merger_.Run(module.get()).value()); } diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc similarity index 98% rename from third_party/xla/xla/service/gpu/fusion_wrapper.cc rename to third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc index 2cb847183d9f4a..d7f8505c420fb4 100644 --- a/third_party/xla/xla/service/gpu/fusion_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusion_wrapper.h" +#include "xla/service/gpu/transforms/fusion_wrapper.h" #include diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper.h b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h similarity index 89% rename from third_party/xla/xla/service/gpu/fusion_wrapper.h rename to third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h index fc466925ce086b..30b1c8abc804f2 100644 --- a/third_party/xla/xla/service/gpu/fusion_wrapper.h +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSION_WRAPPER_H_ -#define XLA_SERVICE_GPU_FUSION_WRAPPER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -39,4 +39,4 @@ class FusionWrapper : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FUSION_WRAPPER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_FUSION_WRAPPER_H_ diff --git a/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fusion_wrapper_test.cc rename to third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc index 397fe754843b68..a46338f93ea0a0 100644 --- a/third_party/xla/xla/service/gpu/fusion_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/fusion_wrapper.h" +#include "xla/service/gpu/transforms/fusion_wrapper.h" #include diff --git a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc index a6cbbf11c94fd0..3395ed8a7db1c8 100644 --- a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h" +#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h" #include diff --git a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h similarity index 88% rename from third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h index bac14bc9711387..8606136551212a 100644 --- a/third_party/xla/xla/service/gpu/gemm_broadcast_folding_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_ -#define XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -48,4 +48,4 @@ class GemmBroadcastFoldingRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GEMM_BROADCAST_FOLDING_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_BROADCAST_FOLDING_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc rename to third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc index 7e98e974e12d5a..57e68fc463de2f 100644 --- a/third_party/xla/xla/service/gpu/tests/gemm_broadcast_folding_rewrite_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_broadcast_folding_rewriter_test.cc @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h" + #include #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/gemm_broadcast_folding_rewriter.h" -#include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" +#include "xla/service/gpu/transforms/gemm_rewriter.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gemm_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc index 4e37a4b7fc59b6..65326580472470 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemm_fusion.h" +#include "xla/service/gpu/transforms/gemm_fusion.h" #include #include @@ -44,6 +44,7 @@ limitations under the License. #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_padding_requirements.h" #include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/triton_fusion_analysis.h" @@ -783,7 +784,6 @@ absl::StatusOr RunOnComputation( return visitor.changed(); } - } // namespace bool ShouldTritonHandleGEMM(HloDotInstruction& dot, diff --git a/third_party/xla/xla/service/gpu/gemm_fusion.h b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.h similarity index 92% rename from third_party/xla/xla/service/gpu/gemm_fusion.h rename to third_party/xla/xla/service/gpu/transforms/gemm_fusion.h index c858b43822f194..7f8fe6f94778f9 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GEMM_FUSION_H_ -#define XLA_SERVICE_GPU_GEMM_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_ // This file contains the code for fusing dots and other operations into Triton // GEMM fusions. @@ -54,4 +54,4 @@ class GemmFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GEMM_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc similarity index 94% rename from third_party/xla/xla/service/gpu/gemm_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc index 44430f514cd99a..f72650ef7ff9aa 100644 --- a/third_party/xla/xla/service/gpu/gemm_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_fusion_test.cc @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemm_fusion.h" +#include "xla/service/gpu/transforms/gemm_fusion.h" #include +#include #include #include @@ -1329,6 +1330,85 @@ ENTRY main { EXPECT_FALSE(result.ok()); } +constexpr auto kInt4Dot = R"( +ENTRY e { + p0 = s8[16,16] parameter(0) + p1 = s4[16,16] parameter(1) + p1c = bf16[16,16] convert(p1) + ROOT dot = bf16[16,16] dot(p0, p1c), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + +TEST_F(SmallDotGemmFusionTest, Int4DotIsRewritten) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kInt4Dot)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_triton_gemm_int4(true); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); +} + +TEST_F(SmallDotGemmFusionTest, Int4DotIsNotRewritten) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kInt4Dot)); + EXPECT_FALSE(GemmFusion(gpu_version_).Run(module.get()).value()); +} + +TEST_F(SmallDotGemmFusionTest, Int4ConcatPlusConvertIsRewritten) { + const std::string kInt4Dot = R"( + ENTRY main { + lhs1 = s4[4,1024]{1,0} parameter(0) + lhs2 = s4[4,1024]{1,0} parameter(1) + rhs = bf16[1024,4]{1,0} parameter(2) + lhs_concat = s4[8,1024]{1,0} concatenate(lhs1, lhs2), dimensions={0} + lhs_converted = bf16[8,1024]{1,0} convert(lhs_concat) + ROOT dot = bf16[8,4]{1,0} dot(lhs_converted, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kInt4Dot)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_triton_gemm_int4(true); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + + // Check that the fusion is present and that the lhs is not converted. + MatchHloModule(*module, R"( +CHECK: gemm_fusion_dot_computation +CHECK: %parameter_0 = s4[8,1024]{1,0} parameter(0) +CHECK: ENTRY +CHECK-DAG: ROOT {{.*}} = bf16[8,4]{1,0} fusion(s4[8,1024]{1,0} %lhs_concat, bf16[1024,4]{1,0} %rhs) +})"); +} + +TEST_F(SmallDotGemmFusionTest, Int4ConvertPlusNegateIsRewritten) { + const std::string kInt4Dot = R"( + ENTRY main { + lhs = s4[8,1024]{1,0} parameter(0) + rhs = f32[1024,4]{1,0} parameter(1) + lhs_converted = f32[8,1024]{1,0} convert(lhs) + lhs_negated = f32[8,1024]{1,0} negate(lhs_converted) + ROOT dot = f32[8,4]{1,0} dot(lhs_negated, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kInt4Dot)); + module->mutable_config() + .mutable_debug_options() + .set_xla_gpu_enable_triton_gemm_int4(true); + EXPECT_TRUE(GemmFusion(gpu_version_).Run(module.get()).value()); + // Check that the fusion is present and that convert and negation is fused in + // it. + MatchHloModule(*module, R"( +CHECK: gemm_fusion_dot_computation +CHECK: %parameter_0 = s4[8,1024]{1,0} parameter(0) +CHECK: ENTRY +CHECK-DAG: ROOT {{.*}} = f32[8,4]{1,0} fusion(s4[8,1024]{1,0} %lhs, f32[1024,4]{1,0} %rhs) +})"); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc similarity index 92% rename from third_party/xla/xla/service/gpu/gemm_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc index e3dd0cfa5fc75f..82895a5b3ae967 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.cc @@ -14,7 +14,7 @@ limitations under the License. = =============================================================================*/ -#include "xla/service/gpu/gemm_rewriter.h" +#include "xla/service/gpu/transforms/gemm_rewriter.h" #include #include @@ -24,6 +24,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -44,6 +45,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" @@ -550,10 +552,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { public: explicit GemmRewriterVisitor(const se::GpuComputeCapability &gpu_version, const int32_t toolkit_version, - const bool f8_rewrite) + const GemmRewriterOptions options) : gpu_version_(gpu_version), toolkit_version_(toolkit_version), - f8_rewrite_(f8_rewrite) {} + options_(options) {} absl::Status HandleDot(HloInstruction *instr) override { if (!IsMatrixMultiplication(*instr) && @@ -618,50 +620,54 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { gemm_backend_config.set_lhs_stride(lhs_stride); gemm_backend_config.set_rhs_stride(rhs_stride); - if (f8_rewrite_) { - // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call. - TF_ASSIGN_OR_RETURN( - bool supported_by_cublaslt, - GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); - std::optional a, b; - if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot && - (a = MatchFp8Param( - const_cast(instr->operand(0)))) && - (b = MatchFp8Param( - const_cast(instr->operand(1))))) { - if (IsRocm(gpu_version_) && toolkit_version_ < 60200 && - instr->shape().element_type() != F16 && - instr->shape().element_type() != F32) { - TF_ASSIGN_OR_RETURN(instr, - TurnF8DotWithUnsupportedOutputTypeIntoF32(instr)); + switch (options_.dtype) { + case GemmRewriterOptions::DType::kFp8Only: { + // Rewrite FP8 GEMMs into a type-specific cublasLT Custom Call. + TF_ASSIGN_OR_RETURN( + bool supported_by_cublaslt, + GemmIsSupportedByCublasLt(*instr, gemm_backend_config)); + std::optional a, b; + if (supported_by_cublaslt && instr->opcode() == HloOpcode::kDot && + (a = MatchFp8Param( + const_cast(instr->operand(0)))) && + (b = MatchFp8Param( + const_cast(instr->operand(1))))) { + if (IsRocm(gpu_version_) && toolkit_version_ < 60200 && + instr->shape().element_type() != F16 && + instr->shape().element_type() != F32) { + TF_ASSIGN_OR_RETURN( + instr, TurnF8DotWithUnsupportedOutputTypeIntoF32(instr)); + } + TF_ASSIGN_OR_RETURN(bool created_call, + CreateF8CustomCall(instr, gpu_backend_config, + a.value(), b.value())); + if (created_call) { + return absl::OkStatus(); + } } - TF_ASSIGN_OR_RETURN(bool created_call, - CreateF8CustomCall(instr, gpu_backend_config, - a.value(), b.value())); - if (created_call) { - return absl::OkStatus(); + if (IsF8Type(instr->operand(0))) { + // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt + // custom call, so turn into an FP16 dot which may be rewritten as an + // FP16 Triton, cublas or cublasLt call. + TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); } + break; } - if (IsF8Type(instr->operand(0))) { - // FP8 rewriter couldn't rewrite dot with FP8 inputs into cublasLt - // custom call, so turn into an FP16 dot which may be rewritten as an - // FP16 Triton, cublas or cublasLt call. - TF_ASSIGN_OR_RETURN(instr, TurnF8DotIntoF16Dot(instr)); - } - } else { - // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. - TF_ASSIGN_OR_RETURN( - absl::string_view gemm_custom_call_target, - GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); - const Shape &output_shape = instr->shape(); - HloInstruction *gemm_call = - instr->AddInstruction(HloInstruction::CreateCustomCall( - output_shape, - {instr->mutable_operand(0), instr->mutable_operand(1)}, - gemm_custom_call_target)); - TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); - TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); - } + case GemmRewriterOptions::DType::kNonFp8Only: { + // Rewrite non-FP8 GEMMs into a cublas or cublasLT Custom Call. + TF_ASSIGN_OR_RETURN( + absl::string_view gemm_custom_call_target, + GetNonFp8GemmCustomCallTarget(*instr, gemm_backend_config)); + const Shape &output_shape = instr->shape(); + HloInstruction *gemm_call = + instr->AddInstruction(HloInstruction::CreateCustomCall( + output_shape, + {instr->mutable_operand(0), instr->mutable_operand(1)}, + gemm_custom_call_target)); + TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gpu_backend_config)); + TF_RETURN_IF_ERROR(ReplaceInstruction(instr, gemm_call)); + } break; + }; return absl::OkStatus(); } @@ -757,6 +763,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } absl::Status HandleAdd(HloInstruction *instr) override { + if (options_.bias_mode == GemmRewriterOptions::BiasMode::kNoBias) { + // See comments for `GemmRewriterOptions::BiasMode` for details. + return absl::OkStatus(); + } + HloInstruction *bias, *existing_gemm = nullptr; HloInstruction *optional_slice = nullptr; HloInstruction *optional_convert = nullptr; @@ -1062,8 +1073,11 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } } - absl::Span batch_dims = + absl::Span a_batch_dims = + gemm_backend_config.dot_dimension_numbers().lhs_batch_dimensions(); + absl::Span b_batch_dims = gemm_backend_config.dot_dimension_numbers().rhs_batch_dimensions(); + const size_t num_batch_dims = a_batch_dims.size(); // cuBLASLt FP8 GEMM kernels require the scaling factors to be in F32 // format. Set the factors to one when no scaling factors were captured. @@ -1129,22 +1143,16 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { "dimension."; return false; } - if ((a.commutative_ops.empty() ? a.fp8_input - : a.commutative_ops.back().first) - ->shape() - .dimensions_size() - - batch_dims.size() != - 2 || - (b.commutative_ops.empty() ? b.fp8_input - : b.commutative_ops.back().first) - ->shape() - .dimensions_size() - - batch_dims.size() != - 2) { - VLOG(1) << "Failed to rewrite " << instr->ToShortString() - << "into FP8 Custom Call. A and B must have one non-contracting " - "dimension."; - return false; + for (const MatchedFp8Param ¶m : {a, b}) { + const HloInstruction *input = param.commutative_ops.empty() + ? param.fp8_input + : param.commutative_ops.back().first; + if (input->shape().rank() != num_batch_dims + 2) { + VLOG(1) << "Failed to rewrite " << instr->ToShortString() + << "into FP8 Custom Call. Inputs must have exactly one " + "contracting and one non-contracting dimension."; + return false; + } } // Sequentially apply the collected unary, dynamic-slice, pad and select ops @@ -1192,49 +1200,49 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { shift_ops(a.fp8_input, a.commutative_ops); shift_ops(b.fp8_input, b.commutative_ops); - TF_ASSIGN_OR_RETURN(bool a_is_col_major, - MatrixIsColumnMajor(*instr, gemm_backend_config, "a")); - TF_ASSIGN_OR_RETURN(bool b_is_col_major, - MatrixIsColumnMajor(*instr, gemm_backend_config, "b")); + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, + GemmConfig::For(instr, gemm_backend_config)); DotDimensionNumbers *dim_nums = gemm_backend_config.mutable_dot_dimension_numbers(); - int batch_dim_offset = batch_dims.size(); // cuBLASLt FP8 GEMM kernels currently require the first operand, i.e. A, to // be row-major. If A is column-major, swap the contracting and // non-contracting dimension and transpose the matrix to effectively make it // column-major. // TODO(philipphack): Remove once cuBLASLt supports A being column-major - if (a_is_col_major) { - CHECK(a_contracting_dims[0] == batch_dim_offset || - a_contracting_dims[0] == batch_dim_offset + 1); - if (a_contracting_dims[0] == batch_dim_offset) { - dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset + 1); + if (gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor) { + CHECK(a_contracting_dims[0] == num_batch_dims || + a_contracting_dims[0] == num_batch_dims + 1); + if (a_contracting_dims[0] == num_batch_dims) { + dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims + 1); } else { - dim_nums->set_lhs_contracting_dimensions(0, batch_dim_offset); + dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims); } a.fp8_input = - TransposeMatrix(a.fp8_input, a_contracting_dims[0], batch_dims); + TransposeMatrix(a.fp8_input, a_contracting_dims[0], a_batch_dims); } // Similarly, cuBLASLt requires the second operand to be column-major, so // make it column-major if it is currently row-major. - if (!b_is_col_major) { - CHECK(b_contracting_dims[0] == batch_dim_offset || - b_contracting_dims[0] == batch_dim_offset + 1); - if (b_contracting_dims[0] == batch_dim_offset) { - dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset + 1); + if (gemm_config.rhs_layout.order == MatrixLayout::Order::kRowMajor) { + CHECK(b_contracting_dims[0] == num_batch_dims || + b_contracting_dims[0] == num_batch_dims + 1); + if (b_contracting_dims[0] == num_batch_dims) { + dim_nums->set_rhs_contracting_dimensions(0, num_batch_dims + 1); } else { - dim_nums->set_rhs_contracting_dimensions(0, batch_dim_offset); + dim_nums->set_rhs_contracting_dimensions(0, num_batch_dims); } b.fp8_input = - TransposeMatrix(b.fp8_input, b_contracting_dims[0], batch_dims); + TransposeMatrix(b.fp8_input, b_contracting_dims[0], b_batch_dims); } - a.fp8_input = PadOperandToMultipleOf16(batch_dims, a.fp8_input); - b.fp8_input = PadOperandToMultipleOf16(batch_dims, b.fp8_input); - Shape new_output_shape = PadShapeToMultipleOf16(instr->shape(), batch_dims); + a.fp8_input = PadOperandToMultipleOf16(a_batch_dims, a.fp8_input); + b.fp8_input = PadOperandToMultipleOf16(b_batch_dims, b.fp8_input); + std::vector out_batch_dims(num_batch_dims); + std::iota(out_batch_dims.begin(), out_batch_dims.end(), 0); + Shape new_output_shape = + PadShapeToMultipleOf16(instr->shape(), out_batch_dims); std::vector operands_list = { a.fp8_input, b.fp8_input, scales_f32[0], scales_f32[1], one, one}; @@ -1820,7 +1828,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { private: se::GpuComputeCapability gpu_version_; int32_t toolkit_version_; - bool f8_rewrite_; + GemmRewriterOptions options_; // Choose cublas or cublasLt for the target of the custom call that instr will // be rewritten into. @@ -2120,47 +2128,10 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { output_dtype)); } - absl::StatusOr MatrixIsColumnMajor( - const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config, - const std::string matrix_name = "output") const { - const HloInstruction *lhs = instr.operand(0); - const HloInstruction *rhs = instr.operand(1); - - const DotDimensionNumbers &dot_dims = - gemm_backend_config.dot_dimension_numbers(); - // We use ALG_UNSET and kDefaultComputePrecision because we don't care about - // the precision, just the layout, since we're just checking if the matrix - // is column-major. - TF_ASSIGN_OR_RETURN( - GemmConfig gemm_config, - GemmConfig::For( - lhs->shape(), dot_dims.lhs_batch_dimensions(), - dot_dims.lhs_contracting_dimensions(), rhs->shape(), - dot_dims.rhs_batch_dimensions(), - dot_dims.rhs_contracting_dimensions(), - /*output_shape=*/instr.shape(), gemm_backend_config.alpha_real(), - gemm_backend_config.alpha_imag(), gemm_backend_config.beta(), - /*precision_algorithm=*/PrecisionConfig::ALG_UNSET, - /*algorithm*/ std::nullopt, se::blas::kDefaultComputePrecision, - gemm_backend_config.grad_x(), gemm_backend_config.grad_y())); - - if (matrix_name == "lhs" || matrix_name == "a") { - return gemm_config.lhs_layout.order == MatrixLayout::Order::kColumnMajor; - } else if (matrix_name == "rhs" || matrix_name == "b") { - return gemm_config.rhs_layout.order == MatrixLayout::Order::kColumnMajor; - } else if (matrix_name == "output" || matrix_name == "d") { - return gemm_config.output_layout.order == - MatrixLayout::Order::kColumnMajor; - } else { - return Internal("Invalid matrix name."); - } - } - absl::StatusOr GemmIsSupportedByCublasLt( const HloInstruction &instr, const GemmBackendConfig &gemm_backend_config) const { const HloInstruction *lhs = instr.operand(0); - const HloInstruction *rhs = instr.operand(1); const Shape &output_shape = instr.shape(); TF_ASSIGN_OR_RETURN( @@ -2187,9 +2158,6 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return false; } - TF_ASSIGN_OR_RETURN(bool output_is_column_major, - MatrixIsColumnMajor(instr, gemm_backend_config)); - if (auto isrocm = std::get_if(&gpu_version_); isrocm) { if (!isrocm->has_hipblaslt()) { @@ -2206,10 +2174,7 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { } if (std::holds_alternative(gpu_version_)) { - auto cuda_compute_capability_ = - std::get(gpu_version_); - if (cuda_compute_capability_.IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + if (std::get(gpu_version_).IsAtLeastAmpere()) { // cuBlasLt has an implementation for complex data with compute type // 32F_FAST_32TF that uses tensor cores and that is free from the // restriction. This implementation only works on Ampere @@ -2217,36 +2182,12 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return true; } } - // Get the rhs non-contracting dimensions as they will eventually be at the - // cublasLt level. - std::vector rhs_non_contracting_dims; - const DotDimensionNumbers &dot_dims = - gemm_backend_config.dot_dimension_numbers(); - - if (!output_is_column_major) { - // cublasLt's matmul output is column major by default. This gemm requires - // the output to be in row major. Later we will swap lhs & rhs (and - // transpose each operand) of this gemm. Since we care about the rhs at - // the cublasLt level, this swap means that we care about the lhs right - // here. - TF_ASSIGN_OR_RETURN( - rhs_non_contracting_dims, - GetNonContractingDims(lhs->shape(), dot_dims.lhs_batch_dimensions(), - dot_dims.lhs_contracting_dimensions())); - } else { - TF_ASSIGN_OR_RETURN( - rhs_non_contracting_dims, - GetNonContractingDims(rhs->shape(), dot_dims.rhs_batch_dimensions(), - dot_dims.rhs_contracting_dimensions())); - } - const auto lhs_non_contracting_dimension_size = absl::c_accumulate( - rhs_non_contracting_dims, 1, [&](int64_t size, int64_t dim) { - return size * lhs->shape().dimensions(dim); - }); + TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, + GemmConfig::For(&instr, gemm_backend_config)); // Check that the size of the non-contracting dimension is not too large. - return lhs_non_contracting_dimension_size <= kMaxDimensionSize; + return gemm_config.rhs_layout.num_cols <= kMaxDimensionSize; } // Turns an F8 dot with unsupported output type into an F8 dot with F32 @@ -2263,16 +2204,20 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor { return f32_dot; } - // Turns an F8 dot into an F16 dot, converting operands to F16 and + // Turns an F8 dot into an F16 dot, converting operands to F16 (or BF16) and // converting the output back to F8. absl::StatusOr TurnF8DotIntoF16Dot(HloInstruction *instr) { DCHECK(IsF8Type(instr->operand(0))); DCHECK(IsF8Type(instr->operand(1))); - // Convert operands to F16 + // If the output type is BF16, the input types have to be BF16 as well. + PrimitiveType conv_type = + instr->shape().element_type() == BF16 ? BF16 : F16; + + // Convert operands to F16 (or BF16). for (int i = 0; i < 2; ++i) { Shape operand_f16_shape = instr->operand(i)->shape(); - operand_f16_shape.set_element_type(F16); + operand_f16_shape.set_element_type(conv_type); HloInstruction *convert = instr->AddInstruction(HloInstruction::CreateConvert( operand_f16_shape, instr->mutable_operand(i))); @@ -2395,8 +2340,8 @@ class GemmWorkspaceRewriteVisitor : public DfsHloRewriteVisitor { absl::StatusOr RunOnComputation(HloComputation *computation, se::GpuComputeCapability gpu_version, int32_t toolkit_version, - bool f8_rewrite) { - GemmRewriterVisitor visitor(gpu_version, toolkit_version, f8_rewrite); + GemmRewriterOptions options) { + GemmRewriterVisitor visitor(gpu_version, toolkit_version, options); TF_RETURN_IF_ERROR(computation->Accept(&visitor)); GemmWorkspaceRewriteVisitor workspace_visitor(gpu_version); TF_RETURN_IF_ERROR(computation->Accept(&workspace_visitor)); @@ -2406,10 +2351,10 @@ absl::StatusOr RunOnComputation(HloComputation *computation, } // anonymous namespace GemmRewriter::GemmRewriter(se::GpuComputeCapability gpu_version, - int32_t toolkit_version, bool f8_rewrite) + int32_t toolkit_version, GemmRewriterOptions options) : gpu_version_(gpu_version), toolkit_version_(toolkit_version), - f8_rewrite_(f8_rewrite) {} + options_(options) {} absl::StatusOr GemmRewriter::Run( HloModule *module, @@ -2419,7 +2364,7 @@ absl::StatusOr GemmRewriter::Run( module->MakeNonfusionComputations(execution_threads)) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation, gpu_version_, - toolkit_version_, f8_rewrite_)); + toolkit_version_, options_)); changed |= result; } return changed; diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h similarity index 64% rename from third_party/xla/xla/service/gpu/gemm_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h index 161a29a3b26bbf..cce09c45c464f6 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GEMM_REWRITER_H_ -#define XLA_SERVICE_GPU_GEMM_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_ #include @@ -45,12 +45,40 @@ namespace gpu { // (we assume transposes are already folded), and rewrites it into a custom call // where (A, B, C) are three operands respectively, and `alpha` and `beta` are // stored in the backend config. + +struct GemmRewriterOptions { + // The DType of the GEMM to rewrite. + enum class DType { kFp8Only, kNonFp8Only }; + DType dtype = DType::kNonFp8Only; + + // Disabling bias prevents using the `beta * C` term the GEMM, which can + // remove dependencies between multiple matrix multiplications. This, in + // turn, can improve the performance of overall computation by allowing + // multiple GEMMs to be scheduled in parallel. + // + // As an example, consider the following computation: `(A * A) + (B * B)`. + // With bias enabled, the `GemmRewriter` will emit the following GEMMs: + // + // AA := GEMM(A * A) + // ROOT := GEMM(B * B + AA) + // + // Because the second GEMM depends on the first, they cannot be scheduled in + // parallel. Instead, with bias disabled, the `GemmRewriter` will emit the + // following: + // + // AA := GEMM(A * A) + // BB := GEMM(B * B) + // ROOT := AA + BB + // + // In this case, the two GEMMs can be scheduled in parallel. + enum class BiasMode { kBias, kNoBias }; + BiasMode bias_mode = BiasMode::kBias; +}; + class GemmRewriter : public HloModulePass { public: - // When f8_rewrite is true, only FP8 GEMMs are rewritten. Otherwise, non-FP8 - // GEMMs are rewritten. GemmRewriter(se::GpuComputeCapability gpu_version, int32_t toolkit_version, - bool f8_rewrite = false); + GemmRewriterOptions options = {}); absl::string_view name() const override { return "cublas-gemm-rewriter"; } using HloPassInterface::Run; @@ -61,10 +89,10 @@ class GemmRewriter : public HloModulePass { private: se::GpuComputeCapability gpu_version_; int32_t toolkit_version_; - bool f8_rewrite_; + GemmRewriterOptions options_; }; } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GEMM_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc rename to third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc index cd423569345b0f..1bcf7aed0b9689 100644 --- a/third_party/xla/xla/service/gpu/tests/gemm_rewrite_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemm_rewriter_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/service/gpu/transforms/gemm_rewriter.h" + #include #include #include @@ -35,7 +37,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/buffer_assignment.h" #include "xla/service/executable.h" -#include "xla/service/gpu/gemm_rewriter.h" #include "xla/service/gpu/gpu_executable.h" #include "xla/service/gpu/tests/gpu_codegen_test.h" #include "xla/service/hlo_module_config.h" @@ -48,8 +49,8 @@ limitations under the License. #include "xla/test.h" #include "xla/tests/filecheck.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #if GOOGLE_CUDA @@ -4809,9 +4810,41 @@ class ParameterizedFp8GemmRewriteTest : public ParameterizedGemmRewriteTest { static constexpr const char* kF8E4M3AmaxPlaceholder{"<>"}; }; +TEST_P(ParameterizedFp8GemmRewriteTest, SupportsF8NonMajorBatchDim) { + const char* hlo_text = R"( +HloModule t + +ENTRY main { + %bitcast.73421 = f8e4m3fn[16,8,640]{2,1,0} parameter(0) + %parameter_1.5 = f8e4m3fn[8,640,5120]{2,1,0} parameter(1) + %parameter_2 = f8e4m3fn[8,640,5120]{2,1,0} parameter(2) + %concatenate.2145 = f8e4m3fn[8,640,10240]{2,1,0} concatenate( + f8e4m3fn[8,640,5120]{2,1,0} %parameter_1.5, + f8e4m3fn[8,640,5120]{2,1,0} %parameter_2), + dimensions={2} + %dot.6237 = f32[8,16,10240]{2,1,0} dot( + f8e4m3fn[16,8,640]{2,1,0} %bitcast.73421, + f8e4m3fn[8,640,10240]{2,1,0} %concatenate.2145), + lhs_batch_dims={1}, + lhs_contracting_dims={2}, + rhs_batch_dims={0}, + rhs_contracting_dims={1} + ROOT %convert.20480 = bf16[8,16,10240]{2,1,0} convert( + f32[8,16,10240]{2,1,0} %dot.6237) +})"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-2, 1e-2})); + MatchOptimizedHlo(hlo_text, + R"( +; CHECK: custom-call({{.*}}"lhs_batch_dimensions":["1"],"rhs_batch_dimensions":["0"] + )"); +} + TEST_P(ParameterizedFp8GemmRewriteTest, DoNotRewriteToF8OnPreAda) { + if (!IsCuda()) { + GTEST_SKIP() << "FP8 Rewrite pattern is different on ROCM-6.2 "; + } if (HasFp8Support()) { - GTEST_SKIP() << "Test requires a pre-Ada GPU or an AMD GPU prior to MI300."; + GTEST_SKIP() << "Test requires a pre-Ada GPU"; } const char* hlo_text = R"( HloModule test @@ -4883,7 +4916,8 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnsupportedTypesF8) { ErrorSpec{1e-2, 1e-2})); RunAndFilecheckHloRewrite( hlo_text, - GemmRewriter(Capability(), GetToolkitVersion(), /*f8_rewrite=*/true), + GemmRewriter(Capability(), GetToolkitVersion(), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %unsupported_types ({{.*}}: <>[16,16], {{.*}}: <>[16,16]) -> <>[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,16]{1,0} parameter(0) @@ -4919,7 +4953,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> <>[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -4982,7 +5016,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: <>[16,16]) -> <>[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5044,7 +5078,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5105,7 +5139,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDPaddedF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[13,17], {{.*}}: <>[17,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[13,31] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[13,17]{1,0} parameter(0) @@ -5172,7 +5206,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDBitcastF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -5210,7 +5244,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDWithConvertF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5274,7 +5308,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDUnaryOpsF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -5342,7 +5376,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3], {{.*}}: <>[32,16]) -> f32[16,16] { @@ -5408,7 +5442,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -5416,7 +5450,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDDynamicSliceF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[32,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[32,32]{1,0} parameter(0) @@ -5480,7 +5514,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -5488,7 +5522,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDSelectF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: pred[16,32]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5556,7 +5590,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_FALSE(changed); } @@ -5593,7 +5627,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, BatchedScaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[10,16,32], {{.*}}: <>[10,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[10,16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[10,16,32]{2,1,0} parameter(0) @@ -5657,7 +5691,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABAlphaDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -5722,7 +5756,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDReluActivationF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -5808,7 +5842,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: bf16[], {{.*}}: bf16[], {{.*}}: bf16[16]) -> bf16[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5909,7 +5943,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: bf16[], {{.*}}: bf16[]) -> bf16[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -5987,7 +6021,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, InvScaledABUnscaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -6030,7 +6064,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { @@ -6098,7 +6132,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDMatrixBiasPaddedF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[14,31], {{.*}}: <>[31,14], {{.*}}: f32[14,14], {{.*}}: f32[], {{.*}}: f32[]) -> f32[14,14] { @@ -6170,7 +6204,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> <>[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6228,7 +6262,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6285,7 +6319,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABInvScaledF32DF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6344,7 +6378,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABScaledF32DMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6417,7 +6451,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6491,7 +6525,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABInvScaledDF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-NOT: divide @@ -6543,7 +6577,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDReluActivationF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> <>[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6631,7 +6665,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDMatrixBiasWithDAmaxF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<>[16,16], f16[]) { @@ -6709,7 +6743,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDVectorBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> <>[16,16] { @@ -6784,7 +6818,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF32VectorBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6853,7 +6887,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -6920,7 +6954,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -6934,7 +6968,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDVectorBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[32], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,16,32] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,16,16]{2,1,0} parameter(0) @@ -7006,7 +7040,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -7022,7 +7056,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,15,15], {{.*}}: <>[15,31], {{.*}}: f32[31], {{.*}}: f16[], {{.*}}: f16[]) -> f16[4,15,31] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,15,15]{2,1,0} parameter(0) @@ -7098,7 +7132,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -7112,7 +7146,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, Rank3ScaledABUnscaledDMatrixBiasF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[4,16,16], {{.*}}: <>[16,32], {{.*}}: f32[4,16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[4,16,32] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[4,16,16]{2,1,0} parameter(0) @@ -7180,7 +7214,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); @@ -7196,7 +7230,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[3,15,15], {{.*}}: <>[15,31], {{.*}}: f32[3,15,31], {{.*}}: f32[], {{.*}}: f32[]) -> f32[3,15,31] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[3,15,15]{2,1,0} parameter(0) @@ -7272,14 +7306,14 @@ TEST_P(ParameterizedFp8GemmRewriteTest, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_TRUE(changed); RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[48,16], {{.*}}: <>[16,32], {{.*}}: f32[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[32,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = <>[48,16]{1,0} parameter(0) @@ -7348,7 +7382,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllGatherF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,32] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7416,7 +7450,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDWithAllToAllF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7482,7 +7516,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[16,32], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7550,7 +7584,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[16], {{.*}}: f16[16,16], {{.*}}: f16[], {{.*}}: f16[]) -> f16[16,16] { ; CHECK-DAG: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7637,7 +7671,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABScaledDWithDAmaxF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7724,7 +7758,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f16[], {{.*}}: f16[], {{.*}}: f16[]) -> (<>[16,16], f16[]) { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7814,7 +7848,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: <>[16,32], {{.*}}: <>[32,16], {{.*}}: f32[], {{.*}}: f32[], {{.*}}: f32[]) -> (<>[16,16], f32[]) { ; CHECK: [[P0:%[^ ]+]] = <>[16,32]{1,0} parameter(0) @@ -7955,7 +7989,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8Parameterized) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -8031,7 +8065,7 @@ ENTRY f { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -8070,7 +8104,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDF8TF32E5M2) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK: custom_call_target="__cublas$lt$matmul$f8", )"); @@ -8107,7 +8141,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo_text)); GemmRewriter pass(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true); + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}); TF_ASSERT_OK_AND_ASSIGN(bool changed, this->RunHloPass(&pass, module.get())); EXPECT_FALSE(changed); #endif @@ -8116,7 +8150,7 @@ TEST_P(ParameterizedFp8GemmRewriteTest, FnuzTypeF8) { RunAndFilecheckHloRewrite( hlo_text, GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(), - /*f8_rewrite=*/true), + GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}), R"( ; CHECK-LABEL: ENTRY %test ({{.*}}: f8e4m3fnuz[16,32], {{.*}}: f8e4m3fnuz[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] { ; CHECK-NEXT: [[P0:%[^ ]+]] = f8e4m3fnuz[16,32]{1,0} parameter(0) @@ -8212,6 +8246,119 @@ ENTRY main { )"); } +TEST_F(GemmRewriteTest, DotWithBias) { + const char* hlo = R"( + HloModule m + + ENTRY main { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + p2 = f32[1024,1024] parameter(2) + p3 = f32[1024,1024] parameter(3) + dot0 = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot1 = f32[1024,1024] dot(p2, p3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT root = f32[1024,1024] add(dot0, dot1) + })"; + + const char* expected = R"() + // CHECK: %[[P0:.*]] = f32[1024,1024]{1,0} parameter(0) + // CHECK: %[[P1:.*]] = f32[1024,1024]{1,0} parameter(1) + // CHECK: %[[P2:.*]] = f32[1024,1024]{1,0} parameter(2) + // CHECK: %[[P3:.*]] = f32[1024,1024]{1,0} parameter(3) + // CHECK: %[[TUPLE0:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P2]], %[[P3]]) + // CHECK: %[[S0:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE0]]), index=0 + // CHECK: %[[TUPLE1:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]], %[[S0]]) + // CHECK: ROOT %[[S1:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE1]]), index=0 + })"; + + RunAndFilecheckHloRewrite( + hlo, + GemmRewriter( + se::CudaComputeCapability{}, /*toolkit_version=*/0, + GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only}), + expected); +} + +TEST_F(GemmRewriteTest, DotWithoutBias) { + const char* hlo = R"( + HloModule m + + ENTRY main { + p0 = f32[1024,1024] parameter(0) + p1 = f32[1024,1024] parameter(1) + p2 = f32[1024,1024] parameter(2) + p3 = f32[1024,1024] parameter(3) + dot0 = f32[1024,1024] dot(p0, p1), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + dot1 = f32[1024,1024] dot(p2, p3), + lhs_contracting_dims={1}, rhs_contracting_dims={0} + ROOT root = f32[1024,1024] add(dot0, dot1) + })"; + + const char* expected = R"() + // CHECK: %[[P0:.*]] = f32[1024,1024]{1,0} parameter(0) + // CHECK: %[[P1:.*]] = f32[1024,1024]{1,0} parameter(1) + // CHECK: %[[TUPLE0:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P0]], %[[P1]]) + // CHECK: %[[S0:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE0]]), index=0 + // CHECK: %[[P2:.*]] = f32[1024,1024]{1,0} parameter(2) + // CHECK: %[[P3:.*]] = f32[1024,1024]{1,0} parameter(3) + // CHECK: %[[TUPLE1:.*]] = (f32[1024,1024]{1,0}, s8[4194304]{0}) custom-call(%[[P2]], %[[P3]]) + // CHECK: %[[S1:.*]] = f32[1024,1024]{1,0} get-tuple-element(%[[TUPLE1]]), index=0 + // CHECK: ROOT %[[S2:.*]] = f32[1024,1024]{1,0} add(%[[S0]], %[[S1]]) + })"; + + RunAndFilecheckHloRewrite( + hlo, + GemmRewriter(se::CudaComputeCapability{}, /*toolkit_version=*/0, + GemmRewriterOptions{GemmRewriterOptions::DType::kNonFp8Only, + GemmRewriterOptions::BiasMode::kNoBias}), + expected); +} + +TEST_F(CublasLtGemmRewriteTest, CublasLtSuccessfullyMatchesLargeC64Lhs) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + p0 = c64[2000,3000,3]{2,1,0} parameter(0) + p1 = c64[3,6]{1,0} parameter(1) + ROOT dot = c64[2000,3000,6]{2,1,0} dot(p0, p1), lhs_contracting_dims={2}, rhs_contracting_dims={0} +} +)"; + // Large lhs is fine for cuBLASlt. + if (IsCuda()) { + MatchOptimizedHlo(hlo_text, + R"(; CHECK: custom_call_target="__cublas$lt$matmul")"); + } else { + MatchOptimizedHlo(hlo_text, + R"(; CHECK: custom_call_target="__cublas$gemm")"); + } +} + +TEST_F(CublasLtGemmRewriteTest, CublasLtOnlyMatchesLargeC64RhsPostAmpere) { + const char* hlo_text = R"( +HloModule test + +ENTRY test { + p0 = c64[6,3]{1,0} parameter(0) + p1 = c64[3,2000,3000]{2,1,0} parameter(1) + ROOT dot = c64[6,2000,3000]{2,1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + if (HasCudaComputeCapability(se::CudaComputeCapability::Ampere())) { + // From Ampere onwards, cuBLASlt supports large rhs. + MatchOptimizedHlo(hlo_text, + R"(; CHECK: custom_call_target="__cublas$lt$matmul")"); + } else { + // Rhs with non-contracting dimensions > 4194240 (combined) is not fine for + // C64 type. + MatchOptimizedHlo( + hlo_text, R"(; CHECK-NOT: custom_call_target="__cublas$lt$matmul")"); + } +} + class GemmRewriteAllocationTest : public GpuCodegenTest { public: void CheckNumberOfAllocations(const std::string& hlo, diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gemv_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc index 21e5f477e4b059..fddb9e662a0aad 100644 --- a/third_party/xla/xla/service/gpu/gemv_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemv_rewriter.h" +#include "xla/service/gpu/transforms/gemv_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h similarity index 90% rename from third_party/xla/xla/service/gpu/gemv_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h index a041138b8af5c6..933910106c7b4b 100644 --- a/third_party/xla/xla/service/gpu/gemv_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GEMV_REWRITER_H_ -#define XLA_SERVICE_GPU_GEMV_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -41,4 +41,4 @@ class GemvRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GEMV_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMV_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gemv_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc index 2a8b8103e0a94e..d2555286297d81 100644 --- a/third_party/xla/xla/service/gpu/gemv_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/gemv_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gemv_rewriter.h" +#include "xla/service/gpu/transforms/gemv_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cusolver_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/cusolver_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc index ddfda663821eb4..ef78dbd575d7cf 100644 --- a/third_party/xla/xla/service/gpu/cusolver_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/cusolver_rewriter.h" +#include "xla/service/gpu/transforms/gpusolver_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/cusolver_rewriter.h b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h similarity index 89% rename from third_party/xla/xla/service/gpu/cusolver_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h index fd1d84dfa99368..cdc0ff24f9fb1a 100644 --- a/third_party/xla/xla/service/gpu/cusolver_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/gpusolver_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ -#define XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -44,4 +44,4 @@ class GpusolverRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_CUSOLVER_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_GPUSOLVER_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc similarity index 97% rename from third_party/xla/xla/service/gpu/horizontal_input_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc index c6938569686611..befe869ac072df 100644 --- a/third_party/xla/xla/service/gpu/horizontal_input_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/horizontal_input_fusion.h" +#include "xla/service/gpu/transforms/horizontal_input_fusion.h" #include #include @@ -169,13 +169,13 @@ absl::StatusOr HorizontalInputFusionImpl::Run() { } // namespace -absl::StatusOr GpuHorizontalInputFusion::RunOnComputation( +absl::StatusOr HorizontalInputFusion::RunOnComputation( HloComputation* computation) { HorizontalInputFusionImpl horizontal_fusion_impl(computation, device_info_); return horizontal_fusion_impl.Run(); } -absl::StatusOr GpuHorizontalInputFusion::Run( +absl::StatusOr HorizontalInputFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h similarity index 74% rename from third_party/xla/xla/service/gpu/horizontal_input_fusion.h rename to third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h index 370ce7bd0509af..a08168d4c3f5a5 100644 --- a/third_party/xla/xla/service/gpu/horizontal_input_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ -#define XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -29,24 +29,22 @@ namespace gpu { // This optimization pass horizontally fuses kInput fusions to both reduce the // kernel launch overhead and increase parallelism degree. See -// GpuHorizontalFusion for general description and motivation about horizontal -// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals +// HorizontalLoopFusion for general description and motivation about horizontal +// fusion. HorizontalLoopFusion deals with kLoop fusions while this pass deals // with kInput fusions. // -// Following GpuHorizontalFusion, a simple yet effective heuristic is used +// Following HorizontalLoopFusion, a simple yet effective heuristic is used // to search the fusion candidates while avoiding creating cycles. That is, // we simply search for fusion candidates by looking for instructions whose // outputs are all consumed by the same instruction. This catches the typical // target cases; often, the candidate instructions are just consumed by the // ROOT tuple of the entry computation. -class GpuHorizontalInputFusion : public HloModulePass { +class HorizontalInputFusion : public HloModulePass { public: - explicit GpuHorizontalInputFusion(const se::DeviceDescription& d) + explicit HorizontalInputFusion(const se::DeviceDescription& d) : device_info_(d) {} - absl::string_view name() const override { - return "gpu_horizontal_input_fusion"; - } + absl::string_view name() const override { return "horizontal_input_fusion"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -62,4 +60,4 @@ class GpuHorizontalInputFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_INPUT_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc index 2d458f9db452d1..5fc1a54acd8d53 100644 --- a/third_party/xla/xla/service/gpu/horizontal_input_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/horizontal_input_fusion.h" +#include "xla/service/gpu/transforms/horizontal_input_fusion.h" #include #include @@ -42,7 +42,7 @@ class HorizontalInputFusionTest : public GpuCodegenTest { public: se::DeviceDescription device_description_{ TestGpuDeviceInfo::RTXA6000DeviceInfo()}; - GpuHorizontalInputFusion horizontal_input_fusion_{device_description_}; + HorizontalInputFusion horizontal_input_fusion_{device_description_}; }; TEST_F(HorizontalInputFusionTest, BasicTest) { diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc similarity index 99% rename from third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index 80c46cb7a5d5af..0a3d705103c416 100644 --- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/horizontal_loop_fusion.h" #include #include @@ -713,13 +713,13 @@ absl::StatusOr HorizontalLoopFusionImpl::Run() { } // namespace -absl::StatusOr GpuHorizontalLoopFusion::RunOnComputation( +absl::StatusOr HorizontalLoopFusion::RunOnComputation( HloComputation* computation) { HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_); return horizontal_fusion_impl.Run(); } -absl::StatusOr GpuHorizontalLoopFusion::Run( +absl::StatusOr HorizontalLoopFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { VLOG(2) << "Run horizontal fusion."; diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h similarity index 92% rename from third_party/xla/xla/service/gpu/horizontal_loop_fusion.h rename to third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h index 5daed0378aa903..f29bcd31044991 100644 --- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ -#define XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_ #include @@ -122,15 +122,12 @@ namespace gpu { // outputs of Mul and Add are row-major. // // Note, reshapes are added only if the tensors isn't already a vector. -class GpuHorizontalLoopFusion : public HloModulePass { +class HorizontalLoopFusion : public HloModulePass { public: - GpuHorizontalLoopFusion() = default; - explicit GpuHorizontalLoopFusion(absl::string_view prefix) - : prefix_(prefix) {} + HorizontalLoopFusion() = default; + explicit HorizontalLoopFusion(absl::string_view prefix) : prefix_(prefix) {} - absl::string_view name() const override { - return "gpu_horizontal_loop_fusion"; - } + absl::string_view name() const override { return "horizontal_loop_fusion"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -145,4 +142,4 @@ class GpuHorizontalLoopFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_HORIZONTAL_LOOP_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_HORIZONTAL_LOOP_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc index 935c21c6e23fed..781d27a64d716c 100644 --- a/third_party/xla/xla/service/gpu/horizontal_loop_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/horizontal_loop_fusion.h" +#include "xla/service/gpu/transforms/horizontal_loop_fusion.h" #include #include @@ -27,7 +27,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/instruction_fusion.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_pass_fix.h" @@ -39,7 +39,7 @@ limitations under the License. #include "xla/stream_executor/device_description.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace gpu { @@ -85,7 +85,7 @@ TEST_F(HorizontalLoopFusionTest, BasicTest) { )") .value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -136,7 +136,7 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) { )") .value(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { @@ -172,7 +172,7 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { )") .value(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { @@ -259,7 +259,7 @@ TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { )") .value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); int input_fusion_count = 0; int loop_fusion_count = 0; @@ -308,7 +308,7 @@ TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) { fusion.AddPass(/*may_duplicate=*/true, device_info); EXPECT_TRUE(fusion.Run(module.get()).value()); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); VLOG(2) << "Dump after horizontal fusion:"; @@ -415,7 +415,7 @@ TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) { )") .value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -545,7 +545,7 @@ TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) { })") .value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -586,7 +586,7 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) { )") .value(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { @@ -627,7 +627,7 @@ TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { .value(); HloPassFix iterative_h_fusion("iterative_h_fusion"); - iterative_h_fusion.AddPass(); + iterative_h_fusion.AddPass(); iterative_h_fusion.AddPass(); EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value()); @@ -699,7 +699,7 @@ TEST_F(HorizontalLoopFusionTest, TraversalOrder) { .value(); HloPassFix iterative_h_fusion("iterative_h_fusion"); - iterative_h_fusion.AddPass(); + iterative_h_fusion.AddPass(); EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value()); // Verify that the total number of fusion instructions is 2 so that we @@ -773,7 +773,7 @@ ENTRY main { )"; auto module = ParseAndReturnUnverifiedModule(hlo_text).value(); - EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); VLOG(2) << module->ToString(); @@ -843,7 +843,7 @@ TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) { })") .value(); - EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); } } // namespace diff --git a/third_party/xla/xla/service/gpu/instruction_fusion.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc similarity index 99% rename from third_party/xla/xla/service/gpu/instruction_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc index 8751d44f8972ea..5e32f2ec0c2ee1 100644 --- a/third_party/xla/xla/service/gpu/instruction_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/instruction_fusion.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" #include #include diff --git a/third_party/xla/xla/service/gpu/instruction_fusion.h b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h similarity index 94% rename from third_party/xla/xla/service/gpu/instruction_fusion.h rename to third_party/xla/xla/service/gpu/transforms/instruction_fusion.h index 29eb0325e1a23b..d7fb7f2cb47ded 100644 --- a/third_party/xla/xla/service/gpu/instruction_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ -#define XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_ #include @@ -79,4 +79,4 @@ class GpuInstructionFusion : public InstructionFusion { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_INSTRUCTION_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/instruction_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/instruction_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc index fa96edfd364aa2..140cc6e52641ea 100644 --- a/third_party/xla/xla/service/gpu/instruction_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/instruction_fusion.h" +#include "xla/service/gpu/transforms/instruction_fusion.h" #include #include @@ -126,12 +126,14 @@ TEST_F(InstructionFusionTest, TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { HloComputation::Builder builder(TestName()); - HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); - HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( - ShapeUtil::MakeShape(F32, {}), HloOpcode::kExp, const0)); - HloInstruction* transpose2 = builder.AddInstruction( - HloInstruction::CreateTranspose(ShapeUtil::MakeShape(F32, {}), exp1, {})); + Shape operand_shape = ShapeUtil::MakeShape(F32, {64, 32}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, operand_shape, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateUnary(operand_shape, HloOpcode::kExp, param)); + HloInstruction* transpose2 = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {32, 64}), exp1, {1, 0})); auto module = CreateNewVerifiedModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -464,7 +466,7 @@ TEST_F(InstructionFusionTest, MultiOutputFusion) { .value(); // Multi-output fusion is disabled here and performed in the - // GpuMultiOutputFusion pass instead. + // MultiOutputFusion pass instead. ASSERT_FALSE(duplicating_instruction_fusion_.Run(module.get()).value()); } diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gpu_layout_assignment.cc rename to third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index b260353292de47..9c62c35417708d 100644 --- a/third_party/xla/xla/service/gpu/gpu_layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_layout_assignment.h" +#include "xla/service/gpu/transforms/layout_assignment.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment.h b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h similarity index 94% rename from third_party/xla/xla/service/gpu/gpu_layout_assignment.h rename to third_party/xla/xla/service/gpu/transforms/layout_assignment.h index 70741fea030efb..efa58f3f8c3c72 100644 --- a/third_party/xla/xla/service/gpu/gpu_layout_assignment.h +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ -#define XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_ #include #include @@ -78,4 +78,4 @@ class GpuLayoutAssignment : public LayoutAssignment { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_LAYOUT_ASSIGNMENT_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc rename to third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 81f9e00548d9da..dd1cbc65bb3fde 100644 --- a/third_party/xla/xla/service/gpu/gpu_layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_layout_assignment.h" +#include "xla/service/gpu/transforms/layout_assignment.h" #include #include diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users.cc b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc similarity index 99% rename from third_party/xla/xla/service/gpu/move_copy_to_users.cc rename to third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc index acc10db6af6927..ae66093da4507d 100644 --- a/third_party/xla/xla/service/gpu/move_copy_to_users.cc +++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/move_copy_to_users.h" +#include "xla/service/gpu/transforms/move_copy_to_users.h" #include diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users.h b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h similarity index 87% rename from third_party/xla/xla/service/gpu/move_copy_to_users.h rename to third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h index 4a7dfb43bbf6ec..698db0460602f1 100644 --- a/third_party/xla/xla/service/gpu/move_copy_to_users.h +++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_ -#define XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -36,4 +36,4 @@ class MoveCopyToUsers : public HloModulePass { } // end namespace xla -#endif // XLA_SERVICE_GPU_MOVE_COPY_TO_USERS_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_MOVE_COPY_TO_USERS_H_ diff --git a/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/move_copy_to_users_test.cc rename to third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc index 10179c1b32cacd..85999dbf63a5b5 100644 --- a/third_party/xla/xla/service/gpu/move_copy_to_users_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/move_copy_to_users_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/move_copy_to_users.h" +#include "xla/service/gpu/transforms/move_copy_to_users.h" #include diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc similarity index 96% rename from third_party/xla/xla/service/gpu/multi_output_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc index 6ac1217151aa65..35bfe8eb092038 100644 --- a/third_party/xla/xla/service/gpu/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/multi_output_fusion.h" +#include "xla/service/gpu/transforms/multi_output_fusion.h" #include #include @@ -307,13 +307,13 @@ FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, } // namespace -void GpuMultiOutputFusion::RecomputeReachability() { +void MultiOutputFusion::RecomputeReachability() { reachability_ = HloDfsReachability::Build(computation_); } -bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, - FusionInfoCache* fusion_info_cache, - GpuHloCostAnalysis* cost_analysis) { +bool MultiOutputFusion::FuseSiblings(HloInstruction* parent, + FusionInfoCache* fusion_info_cache, + GpuHloCostAnalysis* cost_analysis) { const HloComputation* computation = parent->parent(); const HloModule* module = computation->parent(); bool dump_fusion = @@ -402,7 +402,7 @@ bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent, return changed; } -absl::StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { +absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { bool changed = false; RecomputeReachability(); GpuHloCostAnalysis cost_analysis({shape_size_function_, @@ -494,9 +494,9 @@ absl::StatusOr GpuMultiOutputFusion::DoMultiOutputFusion() { return changed; } -void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer, - absl::string_view label, - const HloInstruction* producer) { +void MultiOutputFusion::DumpFusionState(const HloInstruction& consumer, + absl::string_view label, + const HloInstruction* producer) { if (consumer.GetModule() ->config() .debug_options() @@ -505,7 +505,7 @@ void GpuMultiOutputFusion::DumpFusionState(const HloInstruction& consumer, } } -absl::StatusOr GpuMultiOutputFusion::Run( +absl::StatusOr MultiOutputFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion.h b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h similarity index 93% rename from third_party/xla/xla/service/gpu/multi_output_fusion.h rename to third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h index 82789d3be5791d..9ebabe6b460000 100644 --- a/third_party/xla/xla/service/gpu/multi_output_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ -#define XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_ #include @@ -74,7 +74,7 @@ namespace gpu { // Note that sibling (1) and producer-consumer (2) multi-output fusion can be // combined. // -// The GpuMultiOutputFusion pass modifies the HLO in reverse post-order (defs +// The MultiOutputFusion pass modifies the HLO in reverse post-order (defs // before uses). First, it attempts to fuse the consumer ops of the current op, // which are siblings (1). Hereafter, it attempts to fuse the current op with // one of its consumers (2). This order avoids a phase ordering issue (described @@ -83,7 +83,7 @@ namespace gpu { // order of traversal, and hence, not get into the way of subsequent fusion // attempts. // -// The GpuMultiOutputFusion pass ensures several conditions are met for fusion. +// The MultiOutputFusion pass ensures several conditions are met for fusion. // Some of them are relevant for correctness. In particular, no cycles must be // introduced into the HLO module. Moreover, the code emitters for multi-output // fusion must support the combination of ops and their shapes. Other @@ -92,9 +92,9 @@ namespace gpu { // * Sibling fusion (1) does not fuse kInput fusions with kLoop fusions, i.e. // the fusion kinds must match. -class GpuMultiOutputFusion : public HloModulePass { +class MultiOutputFusion : public HloModulePass { public: - explicit GpuMultiOutputFusion( + explicit MultiOutputFusion( const se::DeviceDescription& device_info, HloCostAnalysis::ShapeSizeFunction shape_size_function) : device_info_(device_info), shape_size_function_(shape_size_function) {} @@ -131,4 +131,4 @@ class GpuMultiOutputFusion : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_MULTI_OUTPUT_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_MULTI_OUTPUT_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc similarity index 95% rename from third_party/xla/xla/service/gpu/multi_output_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc index b333a04a841882..4b6920464c8b51 100644 --- a/third_party/xla/xla/service/gpu/multi_output_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/multi_output_fusion.h" +#include "xla/service/gpu/transforms/multi_output_fusion.h" #include #include @@ -48,17 +48,15 @@ class MultiOutputFusionTest : public HloTestBase { } public: - GpuMultiOutputFusion mof_{ - TestGpuDeviceInfo::RTXA6000DeviceInfo(), - ShapeSizeBytesFunction()}; + MultiOutputFusion mof_{TestGpuDeviceInfo::RTXA6000DeviceInfo(), + ShapeSizeBytesFunction()}; - void CheckGpuMultiOutputFusion(absl::string_view hlo, - std::optional expected) { + void CheckMultiOutputFusion(absl::string_view hlo, + std::optional expected) { RunAndFilecheckHloRewrite( hlo, - GpuMultiOutputFusion{ - TestGpuDeviceInfo::RTXA6000DeviceInfo(), - ShapeSizeBytesFunction()}, + MultiOutputFusion{TestGpuDeviceInfo::RTXA6000DeviceInfo(), + ShapeSizeBytesFunction()}, expected); } }; @@ -179,7 +177,7 @@ ENTRY entry { ROOT root = (f32[512]{0}, f16[512]{0}) tuple(reduce.1, fusion) })"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation // CHECK-NEXT: [[param_0_2_0:%[^ ]+]] = f32[128,512,28,28]{3,2,1,0} parameter(0) // CHECK-NEXT: [[c_1_1:%[^ ]+]] = f16[128,512,28,28]{3,2,1,0} convert([[param_0_2_0]]) @@ -1529,6 +1527,8 @@ ENTRY main { } )") .value(); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); EXPECT_FALSE(mof_.Run(module.get()).value()); } @@ -1779,7 +1779,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) { // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) // CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) @@ -1797,28 +1797,28 @@ TEST_F(TransposeMultiOutputFusionTest, MultipleTransposes) { HloModule module fused_computation { - param_0.1 = f32[16,32]{1,0} parameter(0) - s.1 = f32[16,32]{1,0} sqrt(param_0.1) - ROOT c.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0} + param_0.1 = f32[1,16,32]{2,1,0} parameter(0) + s.1 = f32[1,16,32]{2,1,0} sqrt(param_0.1) + ROOT c.1 = f32[1,32,16]{2,1,0} transpose(s.1), dimensions={0,2,1} } ENTRY main { - p = f32[16,32]{1,0} parameter(0) - fusion = f32[32,16]{1,0} fusion(p), kind=kInput, calls=fused_computation - c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0} - ROOT t = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple(fusion, c1) + p = f32[1,16,32]{2,1,0} parameter(0) + fusion = f32[1,32,16]{2,1,0} fusion(p), kind=kInput, calls=fused_computation + c1 = f32[1,32,16]{2,1,0} transpose(p), dimensions={0,2,1} + ROOT t = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) tuple(fusion, c1) } )"; - CheckGpuMultiOutputFusion(hlo, R"( -// CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[32,16], f32[32,16]) { -// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) -// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) -// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]), dimensions={1,0} -// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[32,16]{1,0} transpose([[param_0_1_0]]), dimensions={1,0} -// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) tuple([[c_1_2]], [[c1_1_3]]) + CheckMultiOutputFusion(hlo, R"( +// CHECK: %fused_computation (param_0.1: f32[1,16,32]) -> (f32[1,32,16], f32[1,32,16]) { +// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[1,16,32]{2,1,0} parameter(0) +// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[1,16,32]{2,1,0} sqrt([[param_0_1_0]]) +// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[s_1_1]]), dimensions={0,2,1} +// CHECK-NEXT: [[c1_1_3:%[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[param_0_1_0]]), dimensions={0,2,1} +// CHECK-NEXT: ROOT [[tuple_4:%[^ ]+]] = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) tuple([[c_1_2]], [[c1_1_3]]) // CHECK-NEXT: } -// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16]{1,0}, f32[32,16]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] +// CHECK: [[fusion_0:%[^ ]+]] = (f32[1,32,16]{2,1,0}, f32[1,32,16]{2,1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] )"); } @@ -1827,27 +1827,27 @@ TEST_F(TransposeMultiOutputFusionTest, CopyAndTranspose) { HloModule module fused_computation { - param_0.1 = f32[16,32]{1,0} parameter(0) - s.1 = f32[16,32]{1,0} sqrt(param_0.1) - ROOT c.1 = f32[16,32]{0,1} copy(s.1) + param_0.1 = f32[1,16,32]{2,1,0} parameter(0) + s.1 = f32[1,16,32]{2,1,0} sqrt(param_0.1) + ROOT c.1 = f32[1,16,32]{1,2,0} copy(s.1) } ENTRY main { - p = f32[16,32]{1,0} parameter(0) - fusion = f32[16,32]{0,1} fusion(p), kind=kInput, calls=fused_computation - c1 = f32[32,16]{1,0} transpose(p), dimensions={1,0} - ROOT t = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple(fusion, c1) + p = f32[1,16,32]{2,1,0} parameter(0) + fusion = f32[1,16,32]{1,2,0} fusion(p), kind=kInput, calls=fused_computation + c1 = f32[1,32,16]{2,1,0} transpose(p), dimensions={0,2,1} + ROOT t = (f32[1,16,32]{1,2,0}, f32[1,32,16]{2,1,0}) tuple(fusion, c1) } )"; - CheckGpuMultiOutputFusion(hlo, R"( - // CHECK: %fused_computation ({{[^ ]+}} f32[16,32]) -> (f32[16,32], f32[32,16]) { - // CHECK-NEXT: [[param_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) - // CHECK-NEXT: [[s_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0]]) - // CHECK-NEXT: [[copy:%[^ ]+]] = f32[16,32]{0,1} copy([[s_1]]) - // CHECK-NEXT: [[transpose:[^ ]+]] = f32[32,16]{1,0} transpose([[param_0]]), dimensions={1,0} - // CHECK-NEXT: ROOT {{[^ ]+}} = (f32[16,32]{0,1}, f32[32,16]{1,0}) tuple([[copy]], [[transpose]]) - // CHECK: %fusion = (f32[16,32]{0,1}, f32[32,16]{1,0}) fusion(%{{.*}}), kind=kInput, calls=%fused_computation + CheckMultiOutputFusion(hlo, R"( + // CHECK: %fused_computation ({{[^ ]+}} f32[1,16,32]) -> (f32[1,16,32], f32[1,32,16]) { + // CHECK-NEXT: [[param_0:%[^ ]+]] = f32[1,16,32]{2,1,0} parameter(0) + // CHECK-NEXT: [[s_1:%[^ ]+]] = f32[1,16,32]{2,1,0} sqrt([[param_0]]) + // CHECK-NEXT: [[copy:%[^ ]+]] = f32[1,16,32]{1,2,0} copy([[s_1]]) + // CHECK-NEXT: [[transpose:[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[param_0]]), dimensions={0,2,1} + // CHECK-NEXT: ROOT {{[^ ]+}} = (f32[1,16,32]{1,2,0}, f32[1,32,16]{2,1,0}) tuple([[copy]], [[transpose]]) + // CHECK: %fusion = (f32[1,16,32]{1,2,0}, f32[1,32,16]{2,1,0}) fusion(%{{.*}}), kind=kInput, calls=%fused_computation )"); } @@ -1869,7 +1869,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation (param_0.1: f16[16,32]) -> (f32[16,32], f16[16,32]) { // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f16[16,32]{1,0} parameter(0) // CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} convert([[param_0_1_0]]) @@ -1906,7 +1906,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, std::nullopt); + CheckMultiOutputFusion(hlo, std::nullopt); } // Do not group incompatible transposes. @@ -1939,7 +1939,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, std::nullopt); + CheckMultiOutputFusion(hlo, std::nullopt); } // A variation of the test above, where no CSE was run. @@ -1973,7 +1973,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, std::nullopt); + CheckMultiOutputFusion(hlo, std::nullopt); } TEST_F(TransposeMultiOutputFusionTest, CopyAndInput) { @@ -1994,7 +1994,7 @@ ENTRY main { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation (param_0.1: f32[16,32]) -> (f32[16,32], f32[16,32]) { // CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) // CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) @@ -2011,30 +2011,30 @@ TEST_F(TransposeMultiOutputFusionTest, TransposeAndInputEpilogueFusion) { HloModule module fused_computation { - param_0.1 = f32[16,32]{1,0} parameter(0) - s.1 = f32[16,32]{1,0} sqrt(param_0.1) - t.1 = f32[32,16]{1,0} transpose(s.1), dimensions={1,0} + param_0.1 = f32[1,16,32]{2,1,0} parameter(0) + s.1 = f32[1,16,32]{2,1,0} sqrt(param_0.1) + t.1 = f32[1,32,16]{2,1,0} transpose(s.1), dimensions={0,2,1} ROOT out = f32[32,16,1]{2,1,0} bitcast(t.1) } ENTRY main { - p = f32[16,32]{1,0} parameter(0) + p = f32[1,16,32]{2,1,0} parameter(0) fusion = f32[32,16,1]{2,1,0} fusion(p), kind=kInput, calls=fused_computation - c1 = exponential(p) - ROOT t = tuple(fusion, c1) + c1 = f32[1,16,32]{2,1,0} exponential(p) + ROOT t = (f32[32,16,1]{2,1,0}, f32[1,16,32]{2,1,0}) tuple(fusion, c1) } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation -// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[16,32]{1,0} parameter(0) -// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[16,32]{1,0} sqrt([[param_0_1_0]]) -// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[32,16]{1,0} transpose([[s_1_1]]) +// CHECK-NEXT: [[param_0_1_0:%[^ ]+]] = f32[1,16,32]{2,1,0} parameter(0) +// CHECK-NEXT: [[s_1_1:%[^ ]+]] = f32[1,16,32]{2,1,0} sqrt([[param_0_1_0]]) +// CHECK-NEXT: [[c_1_2:%[^ ]+]] = f32[1,32,16]{2,1,0} transpose([[s_1_1]]) // CHECK-NEXT: [[out_3:%[^ ]+]] = f32[32,16,1]{2,1,0} bitcast([[c_1_2]]) -// CHECK-NEXT: [[c1_1_4:%[^ ]+]] = f32[16,32]{1,0} exponential([[param_0_1_0]]) -// CHECK-NEXT: ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) tuple([[out_3]], [[c1_1_4]]) +// CHECK-NEXT: [[c1_1_4:%[^ ]+]] = f32[1,16,32]{2,1,0} exponential([[param_0_1_0]]) +// CHECK-NEXT: ROOT [[tuple_5:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[1,16,32]{2,1,0}) tuple([[out_3]], [[c1_1_4]]) // CHECK-NEXT: } -// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[16,32]{1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] +// CHECK: [[fusion_0:%[^ ]+]] = (f32[32,16,1]{2,1,0}, f32[1,16,32]{2,1,0}) fusion([[p_1:%[^ ]+]]), kind=kInput, calls=[[fused_computation_2:%[^ ]+]] )"); } @@ -2071,7 +2071,7 @@ ENTRY computation { )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_elementwise // CHECK-NEXT: [[p_1_0:%[^ ]+]] = f32[200]{0} parameter(0) // CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[200]{0} sqrt([[p_1_0]]) @@ -2115,7 +2115,7 @@ ENTRY computation { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_elementwise (p.1: f32[10,20]) -> (f32[10,20], f32[]) { // CHECK-NEXT: [[p_1_0:%[^ ]+]] = f32[10,20]{1,0} parameter(0) // CHECK-NEXT: [[r_1_1:%[^ ]+]] = f32[10,20]{1,0} sqrt([[p_1_0]]) @@ -2177,7 +2177,7 @@ ENTRY computation { } )"; - CheckGpuMultiOutputFusion(hlo, R"( + CheckMultiOutputFusion(hlo, R"( // CHECK: %fused_computation.1 (param_0.8: f32[], param_1.10: f32[], param_2.7: f16[100,200]) -> (f16[100,200], f32[]) { // CHECK-NEXT: [[one_3_0:%[^ ]+]] = f32[] constant(1) // CHECK-NEXT: [[one_b_3_1:%[^ ]+]] = f32[100,200]{1,0} broadcast([[one_3_0]]), dimensions={} diff --git a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.cc b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.cc new file mode 100644 index 00000000000000..0ddc33e3d21c46 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.cc @@ -0,0 +1,33 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/pgle_accuracy_checker.h" + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "tsl/platform/errors.h" + +namespace xla::gpu { + +absl::StatusOr PGLEAccuracyChecker::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + TF_RETURN_IF_ERROR(pgle_estimator_.CheckAccuracy(*module)); + return false; +} + +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.h b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.h new file mode 100644 index 00000000000000..35bfd100f36a33 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker.h @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_GPU_TRANSFORMS_PGLE_ACCURACY_CHECKER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_PGLE_ACCURACY_CHECKER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" +#include "xla/service/profile_guided_latency_estimator.h" + +namespace xla::gpu { + +// This pass checks the accuracy of the input feedback-driven optimization (FDO) +// profile. If any non-NOP instruction from the given HloModule is not present +// in the profile this pass fails. +class PGLEAccuracyChecker : public HloModulePass { + public: + explicit PGLEAccuracyChecker(ProfileGuidedLatencyEstimator& pgle_estimator) + : pgle_estimator_(pgle_estimator) {} + absl::string_view name() const override { return "pgle-accuracy-checker"; } + + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + ProfileGuidedLatencyEstimator& pgle_estimator_; +}; + +} // namespace xla::gpu + +#endif // XLA_SERVICE_GPU_TRANSFORMS_PGLE_ACCURACY_CHECKER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc new file mode 100644 index 00000000000000..3f2d1ab6426fd0 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/pgle_accuracy_checker_test.cc @@ -0,0 +1,159 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/pgle_accuracy_checker.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/gpu_latency_hiding_scheduler.h" +#include "xla/service/latency_hiding_scheduler.h" +#include "xla/service/profile_guided_latency_estimator.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/protobuf.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" + +namespace xla::gpu { +namespace { + +using PGLEAccuracyCheckerTest = HloTestBase; +using ::tensorflow::profiler::ProfiledInstructionsProto; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::StatusIs; + +// Constructs PGLE estimator for a given `profile`. +std::unique_ptr GetProfileGuidedLatencyEstimator( + ProfiledInstructionsProto& profile) { + auto gpu_latency_estimator = + std::make_unique(/*pointer_size=*/8); + SchedulerConfig config; + auto aggregator = std::make_unique(); + return std::make_unique( + config, std::move(gpu_latency_estimator), profile, std::move(aggregator)); +} + +TEST_F(PGLEAccuracyCheckerTest, + ReturnsOkAndNoIRChangeIfAllInstructionsAreFoundInTheProfile) { + const absl::string_view kHloString = R"( + HloModule m + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32,32] parameter(1) + p2 = f32[32,32] parameter(2) + p3 = f32[32] parameter(3) + + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT _ = (f32[32],f32[32],f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Profile string, cost does not matter. + const std::string kProfileString = R"pb( + costs { name: "dot0" cost_us: 1.0 } + costs { name: "dot1" cost_us: 1.0 } + costs { name: "add0" cost_us: 1.0 } + costs { name: "ar-start" cost_us: 1.0 } + costs { name: "ar-start1" cost_us: 1.0 } + )pb"; + + ProfiledInstructionsProto profile; + ASSERT_TRUE(TextFormat::ParseFromString(kProfileString, &profile)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + *module->mutable_config().mutable_fdo_profile() = kProfileString; + + auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile); + PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + pgle_accuracy_checker.Run(module.get())); + EXPECT_FALSE(changed); +} + +TEST_F(PGLEAccuracyCheckerTest, + ReturnsInvalidArgumentIfThereAreMissingInstructionsFromTheProfile) { + const absl::string_view kHloString = R"( + HloModule m + + apply_op { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT apply_op = f32[] add(x, y) + } + + ENTRY ar { + p0 = f32[32] parameter(0) + p1 = f32[32,32] parameter(1) + p2 = f32[32,32] parameter(2) + p3 = f32[32] parameter(3) + + dot0 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + dot1 = f32[32,32]{1,0} custom-call(p1, p2), custom_call_target="__cublas$gemm" + add0 = f32[32,32] add(dot0, dot1) + + ar-start = f32[32] all-reduce-start(p0), to_apply=apply_op + ar-done = f32[32] all-reduce-done(ar-start) + + ar-start1 = f32[32] all-reduce-start(p3), to_apply=apply_op + ar-done1 = f32[32] all-reduce-done(ar-start1) + + ROOT _ = (f32[32],f32[32],f32[32,32]) tuple(ar-done, ar-done1, add0) + })"; + + // Profile string, cost does not matter. + // We're missing `dot1` and `ar-start` from the profile. + const std::string kProfileString = R"pb( + costs { name: "dot0" cost_us: 1.0 } + costs { name: "add0" cost_us: 1.0 } + costs { name: "ar-start1" cost_us: 1.0 } + )pb"; + + ProfiledInstructionsProto profile; + ASSERT_TRUE(TextFormat::ParseFromString(kProfileString, &profile)); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + *module->mutable_config().mutable_fdo_profile() = kProfileString; + + auto pgle_estimator = GetProfileGuidedLatencyEstimator(profile); + PGLEAccuracyChecker pgle_accuracy_checker(*pgle_estimator); + EXPECT_THAT(pgle_accuracy_checker.Run(module.get()), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +} // namespace +} // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc index d0e841c4f9ebc1..493d1671e0e7fc 100644 --- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.cc @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/pipelined_p2p_rewriter.h" +#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h similarity index 94% rename from third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h index 88b6bb662f2ed7..d2aca8ca17064c 100644 --- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ -#define XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_ #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_pass_interface.h" @@ -130,4 +130,4 @@ class PipelinedP2PRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_PIPELINED_P2P_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_PIPELINED_P2P_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc index a0d58306cfa93b..287603c6d0de93 100644 --- a/third_party/xla/xla/service/gpu/pipelined_p2p_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/pipelined_p2p_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/pipelined_p2p_rewriter.h" +#include "xla/service/gpu/transforms/pipelined_p2p_rewriter.h" #include #include "absl/strings/string_view.h" @@ -228,8 +228,8 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) { CHECK: %send-data = add(%c, %s) CHECK: %after-all = after-all() CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} - CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} - CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} + CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}} + CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}} CHECK: ROOT %tuple = tuple(%new-count, %recv, %send) CHECK: } @@ -245,8 +245,8 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined1) { CHECK: %f0 = constant(0) CHECK: %init = broadcast(%f0), dimensions={} CHECK: %after-all.1 = after-all() - CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} - CHECK{LITERAL}: %send.1 = send(%init, %after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}, {3,4}}"} + CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}} + CHECK{LITERAL}: %send.1 = send(%init, %after-all.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}, {3,4}}} CHECK: %while-init = tuple(%c0, %recv.1, %send.1) CHECK: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body, CHECK-SAME{LITERAL}: backend_config={"known_trip_count":{"n":"25"}} @@ -616,11 +616,11 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) { CHECK: %after-all = after-all() CHECK: %send-done.p.clone = send-done(%get-tuple-element.1), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0"} CHECK: %send-done.1.p.clone = send-done(%get-tuple-element.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1"} - CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} - CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} + CHECK{LITERAL}: %recv = recv(%after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}} + CHECK{LITERAL}: %send = send(%send-data, %after-all), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}} CHECK: %after-all.1 = after-all() - CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} - CHECK{LITERAL}: %send.1 = send(%send-data, %after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} + CHECK{LITERAL}: %recv.1 = recv(%after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}} + CHECK{LITERAL}: %send.1 = send(%send-data, %after-all.1), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}} CHECK: ROOT %tuple = tuple(%new-count, %recv, %send, %recv.1, %send.1) CHECK: } @@ -636,11 +636,11 @@ TEST_F(PipelinedP2pRewriterTest, SendRecvPipelined2) { CHECK: %f0 = constant(0) CHECK: %init = broadcast(%f0), dimensions={} CHECK: %after-all.2 = after-all() - CHECK{LITERAL}: %recv.2 = recv(%after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} - CHECK{LITERAL}: %send.2 = send(%init, %after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs="{{3,0}}"} + CHECK{LITERAL}: %recv.2 = recv(%after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}} + CHECK{LITERAL}: %send.2 = send(%init, %after-all.2), channel_id=1, frontend_attributes={_xla_send_recv_pipeline="0",_xla_send_recv_source_target_pairs={{3,0}}} CHECK: %after-all.3 = after-all() - CHECK{LITERAL}: %recv.3 = recv(%after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} - CHECK{LITERAL}: %send.3 = send(%init, %after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs="{{0,1}, {1,2}, {2,3}}"} + CHECK{LITERAL}: %recv.3 = recv(%after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}} + CHECK{LITERAL}: %send.3 = send(%init, %after-all.3), channel_id=2, frontend_attributes={_xla_send_recv_pipeline="1",_xla_send_recv_source_target_pairs={{0,1}, {1,2}, {2,3}}} CHECK: %while-init = tuple(%c0, %recv.2, %send.2, %recv.3, %send.3) CHECK{LITERAL}: %while-result.p.clone = while(%while-init), condition=%while-cond, body=%while-body, backend_config={"known_trip_count":{"n":"25"}} CHECK: %get-tuple-element.4 = get-tuple-element(%while-result.p.clone), index=1 diff --git a/third_party/xla/xla/service/gpu/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc similarity index 81% rename from third_party/xla/xla/service/gpu/priority_fusion.cc rename to third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index d57a83c20ee8d1..bae58de45849a1 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/priority_fusion.h" +#include "xla/service/gpu/transforms/priority_fusion.h" #include #include @@ -31,6 +31,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "absl/log/check.h" #include "absl/meta/type_traits.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" @@ -46,14 +47,19 @@ limitations under the License. #include "xla/service/fusion_queue.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/fusion_process_dump.pb.h" +#include "xla/service/gpu/fusions/triton/triton_support.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/hlo_traversal.h" +#include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/fusion_analysis_cache.h" #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h" +#include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/gpu/model/gpu_performance_model.h" #include "xla/service/gpu/model/gpu_performance_model_base.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" +#include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/hlo_graph_dumper.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" @@ -64,6 +70,7 @@ limitations under the License. #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" namespace xla { @@ -114,6 +121,19 @@ bool IsFusible(const HloInstruction& instr) { } } +// Returns a GpuBackendConfig proto for a Triton fusion with the given +// BlockLevelParameters. +GpuBackendConfig GetTritonGpuBackendConfig( + const BlockLevelParameters& block_level_parameters) { + GpuBackendConfig gpu_backend_config; + gpu_backend_config.mutable_fusion_backend_config()->set_kind( + std::string(kTritonFusionKind)); + *gpu_backend_config.mutable_fusion_backend_config() + ->mutable_block_level_fusion_config() = + block_level_parameters.ToBlockLevelFusionConfig(); + return gpu_backend_config; +} + // An implementation of FusionQueue that determines whether to fuse instructions // according to a cost model, and chooses the next fusion candidate according to // dynamically updated priorities. The elements in the queue are producer nodes @@ -121,23 +141,26 @@ bool IsFusible(const HloInstruction& instr) { // performance when fusing it to all of its fusible users. We greedily pick the // max-benefit producer to fuse, and update the estimated benefits of the fused // nodes and their operands. -class GpuPriorityFusionQueue { +class PriorityFusionQueue { using Priority = int64_t; using CanFuseCallback = std::function; public: - GpuPriorityFusionQueue( - HloComputation* computation, - const GpuHloCostAnalysis::Options& cost_analysis_options, - const se::DeviceDescription* device_info, - FusionProcessDumpProto* fusion_process_dump, - tsl::thread::ThreadPool* thread_pool, mlir::MLIRContext* mlir_context, - HloFusionAnalysisCache& fusion_analysis_cache, - bool triton_softmax_priority_fusion_enabled) + PriorityFusionQueue(HloComputation* computation, + const GpuHloCostAnalysis::Options& cost_analysis_options, + const se::DeviceDescription* device_info, + FusionProcessDumpProto* fusion_process_dump, + tsl::thread::ThreadPool* thread_pool, + mlir::MLIRContext* mlir_context, + HloFusionAnalysisCache& fusion_analysis_cache, + bool triton_softmax_priority_fusion_enabled) : computation_(computation), device_info_(device_info), cost_analysis_(cost_analysis_options, *device_info), + gpu_indexing_performance_model_(device_info, &fusion_analysis_cache, + cost_analysis_options.shape_size, + mlir_context), fusion_process_dump_(fusion_process_dump), thread_pool_(thread_pool), mlir_context_(mlir_context), @@ -155,6 +178,7 @@ class GpuPriorityFusionQueue { // Initializes the priority queue. std::vector instructions; for (auto* instruction : computation->MakeInstructionPostOrder()) { + TF_CHECK_OK(UpdatePerformanceModelCache(instruction)); if (instruction->opcode() == HloOpcode::kParameter || instruction->user_count() == 0 || !instruction->IsFusible() || instruction->opcode() == HloOpcode::kTuple || @@ -163,7 +187,6 @@ class GpuPriorityFusionQueue { } instructions.push_back(instruction); } - ComputeAndSetPriorities(instructions); } @@ -247,18 +270,49 @@ class GpuPriorityFusionQueue { return !current_consumers_.empty(); } + absl::Status UpdatePerformanceModelCache(HloInstruction* producer) { + bool is_triton_fusion = IsGenericTritonFusion(*producer); + if (!IsFusible(*producer) && !is_triton_fusion) { + return absl::OkStatus(); + } + + if (gpu_performance_model_cache_.Get(*producer)) { + return absl::OkStatus(); + } + + EstimateRunTimeData runtime_data; + if (is_triton_fusion) { + TF_ASSIGN_OR_RETURN( + runtime_data, + gpu_indexing_performance_model_.EstimateRunTimeForTriton(producer)); + } else { + auto config = GpuPerformanceModelOptions::PriorityFusion( + &fusion_analysis_cache_, &gpu_performance_model_cache_); + runtime_data = GpuPerformanceModel::EstimateRunTimeForInstruction( + producer, *device_info_, &cost_analysis_, config); + } + + gpu_performance_model_cache_.Set(*producer, runtime_data); + + return absl::OkStatus(); + } + // Update priorities of all affected ops. - void UpdatePriorities() { + absl::Status UpdatePriorities() { // Revisit costs of all updated ops. It's important to update cost analysis // before recalculating priorities. for (auto instruction : to_update_priority_) { - TF_CHECK_OK(cost_analysis_.RevisitInstruction(instruction)); + TF_RETURN_IF_ERROR(cost_analysis_.RevisitInstruction(instruction)); + } + for (auto producer : to_update_priority_) { + TF_RETURN_IF_ERROR(UpdatePerformanceModelCache(producer)); } ComputeAndSetPriorities(std::vector{ to_update_priority_.begin(), to_update_priority_.end()}); to_update_priority_.clear(); + return absl::OkStatus(); } // Prepares producer and consumer instruction to be fused. Invalidates caches @@ -271,8 +325,6 @@ class GpuPriorityFusionQueue { consumer->name(), "| inside PriorityFusion"), *consumer, producer); } - - InvalidateCaches(producer); InvalidateCaches(consumer); } @@ -287,6 +339,14 @@ class GpuPriorityFusionQueue { } } + block_level_parameters_cache_.erase(instruction); + for (const HloInstruction* operand : instruction->operands()) { + auto it = block_level_parameters_cache_.find(operand); + if (it != block_level_parameters_cache_.end()) { + it->second.erase(instruction); + } + } + gpu_performance_model_cache_.Invalidate(*instruction); fusion_analysis_cache_.Invalidate(*instruction); fusion_info_cache_.Invalidate(instruction); @@ -327,6 +387,7 @@ class GpuPriorityFusionQueue { // calculations on 'fusion.operands' below, before it is finally removed // in 'RemoveInstruction'. if (original_producer->user_count() == 0) { + InvalidateCaches(original_producer); original_producer->DetachFromOperandsAndUsers(); } @@ -361,6 +422,17 @@ class GpuPriorityFusionQueue { reverse_map_.erase(reverse_it); } + // Returns a map from consumer to BlockLevelParameters. This is used to + // determine if a producer-consumer fusion is a Triton fusion. + absl::flat_hash_map + GetBlockLevelParametersMap(const HloInstruction* producer) { + auto it = block_level_parameters_cache_.find(producer); + if (it == block_level_parameters_cache_.end()) { + return {}; + } + return it->second; + } + HloInstruction* current_producer() { return current_producer_; } const std::vector& current_consumers() { @@ -417,6 +489,24 @@ class GpuPriorityFusionQueue { run_times.time_fused); } + FusionDecision IsTritonSupported(const HloInstruction& instruction) { + if (instruction.opcode() != HloOpcode::kFusion) { + return IsTritonSupportedInstruction( + instruction, device_info_->gpu_compute_capability()); + } + + for (const HloInstruction* instruction : + instruction.fused_instructions_computation()->instructions()) { + if (auto codegen_decision = IsTritonSupportedInstruction( + *instruction, device_info_->gpu_compute_capability()); + !codegen_decision) { + return codegen_decision; + } + } + + return {}; + } + FusionDecision CanFuseTriton(HloInstruction* producer, HloInstruction* consumer) { if (!triton_softmax_priority_fusion_enabled_) { @@ -427,24 +517,52 @@ class GpuPriorityFusionQueue { if (!IsFusible(*consumer)) { return "the consumer is not fusible"; } + + if (auto fusion_decision = IsTritonSupported(*consumer); + !fusion_decision) { + return fusion_decision; + } } else { if (!IsFusible(*producer)) { return "the producer is not fusible"; } + + if (auto fusion_decision = IsTritonSupported(*producer); + !fusion_decision) { + return fusion_decision; + } } auto fusion = HloFusionAdaptor::ForProducerConsumer(producer, consumer); - SymbolicTileAnalysisOrError symbolic_tile_analysis_or = - SymbolicTileAnalysis::AnalyzeFusion(*fusion, mlir_context_); + absl::StatusOr tiled_run_time_data_or_error = + gpu_indexing_performance_model_.TryFindBestTilingForFusion(*fusion); + + if (!tiled_run_time_data_or_error.ok()) { + return FusionDecision{ + absl::StrCat("TiledRunTimeDataOrError return status: ", + tiled_run_time_data_or_error.status().message())}; + } if (const auto* fusion_decision = - std::get_if(&symbolic_tile_analysis_or)) { + std::get_if(&*tiled_run_time_data_or_error)) { return { absl::StrCat("Fusion can not be tiled with SymbolicTileAnalysis: ", fusion_decision->Explain())}; } + TiledRunTimeData tiled_run_time_data = + std::get(*std::move(tiled_run_time_data_or_error)); + + gpu_performance_model_cache_.Set( + *producer, *consumer, tiled_run_time_data.runtime_data.exec_time); + + { + absl::MutexLock lock(&block_level_parameters_cache_mutex_); + block_level_parameters_cache_[producer][consumer] = + tiled_run_time_data.block_level_parameters; + } + return {}; } @@ -550,7 +668,6 @@ class GpuPriorityFusionQueue { return it->second; } } - auto fusion_decision = CanFuse(producer, consumer); // The lock is required, because writing to a flat_hash_map is not @@ -595,9 +712,12 @@ class GpuPriorityFusionQueue { const se::DeviceDescription* device_info_; - // Reference to cost model that defines priorities in the queue. + // Cost Analysis that is used to estimate the cost of a fusion. GpuHloCostAnalysis cost_analysis_; + // Performance model that is used to estimate the run time of a fusion. + GpuPerformanceModelWithIndexingAnalysis gpu_indexing_performance_model_; + // The priority queue of producers, implemented as an ordered map, where a // key is a pair: the first element is the priority and the second element is // the unique ID of the instruction to break ties. @@ -639,6 +759,14 @@ class GpuPriorityFusionQueue { can_fuse_cache_; absl::Mutex can_fuse_cache_mutex_; + // Caches block level parameters for a (producer, consumer) pair. A cache + // entry is invalidated if producer or consumer is modified. + absl::flat_hash_map< + const HloInstruction*, + absl::flat_hash_map> + block_level_parameters_cache_; + absl::Mutex block_level_parameters_cache_mutex_; + GpuPerformanceModelCache gpu_performance_model_cache_; // Cache for `FusionFitsInBudget` to avoid recomputing expensive properties @@ -652,8 +780,7 @@ class GpuPriorityFusionQueue { } // namespace -/*static*/ bool GpuPriorityFusion::IsExpensive( - const HloInstruction& instruction) { +/*static*/ bool PriorityFusion::IsExpensive(const HloInstruction& instruction) { // Some floating-point math ops are cheap on the GPU. switch (instruction.opcode()) { case HloOpcode::kDivide: @@ -686,15 +813,15 @@ bool IsSmallConstant(const HloInstruction* instr) { ShapeUtil::ElementsIn(instr->shape()) <= 1; } -bool GpuPriorityFusion::ConsumeFuel(HloInstruction* producer, - HloInstruction* consumer) { +bool PriorityFusion::ConsumeFuel(HloInstruction* producer, + HloInstruction* consumer) { return xla::ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] { return absl::StrFormat("Not fusing producer %s with consumer %s", producer->name(), consumer->name()); }); }; -absl::StatusOr GpuPriorityFusion::Run( +absl::StatusOr PriorityFusion::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool dump_enabled = @@ -740,7 +867,7 @@ absl::StatusOr GpuPriorityFusion::Run( for (auto* computation : fusible_computations) { CHECK(!computation->IsFusionComputation()); - auto fusion_queue = std::make_unique( + auto fusion_queue = std::make_unique( computation, cost_analysis_options_, &device_info_, fusion_process_dump_.get(), thread_pool_, &mlir_context_, fusion_analysis_cache_, triton_softmax_priority_fusion_enabled); @@ -748,6 +875,10 @@ absl::StatusOr GpuPriorityFusion::Run( while (fusion_queue->DequeueNextProducer()) { auto producer = fusion_queue->current_producer(); + absl::flat_hash_map + block_level_parameters_map = + fusion_queue->GetBlockLevelParametersMap(producer); + for (auto* consumer : fusion_queue->current_consumers()) { // Don't fuse into single bitcasts. We ignore them in the check // CanFuseWithAllNonBitcastUsers(), so we need to check it here. @@ -764,6 +895,12 @@ absl::StatusOr GpuPriorityFusion::Run( fusion_queue->OnFusingInstruction(fusion_instruction, producer, consumer); + auto backend_config_it = block_level_parameters_map.find(consumer); + if (backend_config_it != block_level_parameters_map.end()) { + TF_RETURN_IF_ERROR(fusion_instruction->set_backend_config( + GetTritonGpuBackendConfig(backend_config_it->second))); + } + changed = true; } @@ -773,7 +910,7 @@ absl::StatusOr GpuPriorityFusion::Run( TF_RETURN_IF_ERROR(computation->RemoveInstruction(producer)); } - fusion_queue->UpdatePriorities(); + TF_RETURN_IF_ERROR(fusion_queue->UpdatePriorities()); } // Fuse all constants. @@ -815,8 +952,8 @@ absl::StatusOr GpuPriorityFusion::Run( return changed; } -FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, - int64_t operand_index) { +FusionDecision PriorityFusion::ShouldFuse(HloInstruction* consumer, + int64_t operand_index) { // This method is called in `InstructionFusion::Run` right before fusion, but // it will always return true. Fusion decision are fully controlled by the // PriorityQueue. If the queue returns a producer that shouldn't be fused, @@ -824,7 +961,7 @@ FusionDecision GpuPriorityFusion::ShouldFuse(HloInstruction* consumer, return {}; } -HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( +HloInstruction::FusionKind PriorityFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { // Derive kInput/kLoop fusion kinds from fusion analysis. This shouldn't // matter but some passes downstream still query these instead of fusion @@ -846,15 +983,10 @@ HloInstruction::FusionKind GpuPriorityFusion::ChooseKind( } } -HloInstruction* GpuPriorityFusion::FuseInstruction( +HloInstruction* PriorityFusion::FuseInstruction( HloInstruction* fusion_instruction, HloInstruction* producer) { HloInstruction* result = fusion_instruction; if (producer->opcode() == HloOpcode::kFusion) { - if (IsGenericTritonFusion(*producer)) { - TF_CHECK_OK(fusion_instruction->set_backend_config( - *producer->backend_config())); - } - fusion_instruction->MergeFusionInstruction(producer); } else { result = InstructionFusion::FuseInstruction(fusion_instruction, producer); @@ -862,7 +994,7 @@ HloInstruction* GpuPriorityFusion::FuseInstruction( return result; } -std::unique_ptr GpuPriorityFusion::GetFusionQueue( +std::unique_ptr PriorityFusion::GetFusionQueue( HloComputation* computation) { return nullptr; } diff --git a/third_party/xla/xla/service/gpu/priority_fusion.h b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h similarity index 87% rename from third_party/xla/xla/service/gpu/priority_fusion.h rename to third_party/xla/xla/service/gpu/transforms/priority_fusion.h index 999eb78ceca245..fce2be535e23a3 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_PRIORITY_FUSION_H_ -#define XLA_SERVICE_GPU_PRIORITY_FUSION_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_ #include @@ -41,12 +41,12 @@ limitations under the License. namespace xla { namespace gpu { -class GpuPriorityFusion : public InstructionFusion { +class PriorityFusion : public InstructionFusion { public: - GpuPriorityFusion(tsl::thread::ThreadPool* thread_pool, - const se::DeviceDescription& device, - GpuHloCostAnalysis::Options cost_analysis_options) - : InstructionFusion(GpuPriorityFusion::IsExpensive), + PriorityFusion(tsl::thread::ThreadPool* thread_pool, + const se::DeviceDescription& device, + GpuHloCostAnalysis::Options cost_analysis_options) + : InstructionFusion(PriorityFusion::IsExpensive), thread_pool_(thread_pool), device_info_(device), cost_analysis_options_(std::move(cost_analysis_options)), @@ -97,4 +97,4 @@ class GpuPriorityFusion : public InstructionFusion { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_PRIORITY_FUSION_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_PRIORITY_FUSION_H_ diff --git a/third_party/xla/xla/service/gpu/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc similarity index 84% rename from third_party/xla/xla/service/gpu/priority_fusion_test.cc rename to third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc index 4f71a51b869b4f..b552fd9ade5366 100644 --- a/third_party/xla/xla/service/gpu/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/priority_fusion.h" +#include "xla/service/gpu/transforms/priority_fusion.h" #include @@ -74,20 +74,28 @@ class PriorityFusionTest : public HloTestBase { auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo(); auto analysis = HloFusionAnalysis::Create( - Cast(computation->FusionInstruction()), - &device_info); + *computation->FusionInstruction(), device_info); kinds.push_back(analysis.GetEmitterFusionKind()); } return kinds; } - GpuPriorityFusion priority_fusion_{ + PriorityFusion priority_fusion_{ /*thread_pool=*/nullptr, TestGpuDeviceInfo::RTXA6000DeviceInfo(), GpuHloCostAnalysis::Options{ShapeSizeBytesFunction(), /*per_second_rates=*/{}, /*count_multiple_input_accesses=*/true}}; }; +class PriorityFusionWithTritonEnabledTest : public PriorityFusionTest { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = PriorityFusionTest::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_enable_triton_softmax_priority_fusion(true); + return debug_options; + } +}; + TEST_F(PriorityFusionTest, FuseWithSharedArgument) { auto module = ParseAndReturnVerifiedModule(R"( HloModule test_module @@ -333,9 +341,10 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { broadcast.18310.clone.1 = f32[2048,24576]{1,0} broadcast(convert.30039.clone.1), dimensions={} multiply.12550.clone.1 = f32[2048,24576]{1,0} multiply(tanh.798.clone.1, broadcast.18310.clone.1) convert.29370.clone.1 = bf16[2048,24576]{1,0} convert(multiply.12550.clone.1) - bitcast.22330 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1) - transpose.6582 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22330), dimensions={0,3,2,1} - convert.33705 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6582) + bitcast.1 = bf16[2048,2048,12]{2,1,0} bitcast(convert.29370.clone.1) + transpose.6582 = bf16[12,2048,2048]{2,1,0} transpose(bitcast.1), dimensions={2,1,0} + bitcast = bf16[1,12,2048,2048]{3,2,1,0} bitcast(transpose.6582) + convert.33705 = f32[1,12,2048,2048]{3,2,1,0} convert(bitcast) constant_10212 = f32[] constant(-2.38197633e+38) broadcast.22828 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10212), dimensions={} select.589 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22829, convert.33705, broadcast.22828) @@ -347,9 +356,9 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { bitcast.11069 = pred[2048,2048]{1,0} bitcast(predarg) broadcast.22825 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3} - bitcast.22331 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1) - transpose.6580 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22331), dimensions={0,3,2,1} - convert.33703 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6580) + transpose.6580 = bf16[12,2048,2048]{2,1,0} transpose(bitcast.1), dimensions={2,1,0} + bitcast.2 = bf16[1,12,2048,2048]{3,2,1,0} bitcast(transpose.6580) + convert.33703 = f32[1,12,2048,2048]{3,2,1,0} convert(bitcast.2) constant_10213 = f32[] constant(-2.38197633e+38) broadcast.22824 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_10213), dimensions={} select.587 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22825, convert.33703, broadcast.22824) @@ -362,9 +371,9 @@ TEST_F(PriorityFusionTest, DoNotFuseTransposeIntoReduce) { constant_468 = f32[] constant(-2.38197633e+38) broadcast.22833 = pred[1,12,2048,2048]{3,2,1,0} broadcast(bitcast.11069), dimensions={2,3} - bitcast.22332 = bf16[1,2048,2048,12]{3,2,1,0} bitcast(convert.29370.clone.1) - transpose.6584 = bf16[1,12,2048,2048]{3,2,1,0} transpose(bitcast.22332), dimensions={0,3,2,1} - convert.33707 = f32[1,12,2048,2048]{3,2,1,0} convert(transpose.6584) + transpose.6584 = bf16[12,2048,2048]{2,1,0} transpose(bitcast.1), dimensions={2,1,0} + bitcast.3 = bf16[1,12,2048,2048]{3,2,1,0} bitcast(transpose.6584) + convert.33707 = f32[1,12,2048,2048]{3,2,1,0} convert(bitcast.3) broadcast.22832 = f32[1,12,2048,2048]{3,2,1,0} broadcast(constant_468), dimensions={} select.591 = f32[1,12,2048,2048]{3,2,1,0} select(broadcast.22833, convert.33707, broadcast.22832) broadcast.22821 = f32[1,12,2048,2048]{3,2,1,0} broadcast(reduce.1614), dimensions={1,2} @@ -856,14 +865,13 @@ TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) { ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2 } )"); + auto& debug_options = module->mutable_config().mutable_debug_options(); + debug_options.set_xla_gpu_mlir_emitter_level(3); EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); } -TEST_F(PriorityFusionTest, CanMergeTritonFusionWithBothProducerAndConsumer) { -#ifndef GOOGLE_CUDA - GTEST_SKIP() << "Triton fusion only enable for CUDA devices."; -#endif - +TEST_F(PriorityFusionWithTritonEnabledTest, + CanMergeTritonFusionWithBothProducerAndConsumer) { const std::string kHloText = R"( HloModule t add { @@ -896,13 +904,10 @@ ENTRY main { param_0 = f32[125]{0} parameter(0) param_1 = f32[125,127]{1,0} parameter(1) producer_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=producer_computation - triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton"}} + triton_softmax = f32[125,127]{1,0} fusion(producer_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1","127"],"num_warps":"1"}}} ROOT consumer_fusion = f32[125,127]{1,0} fusion(param_1, triton_softmax), kind=kLoop, calls=consumer_computation })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); - auto debug_options = module->config().debug_options(); - debug_options.set_xla_gpu_enable_triton_softmax_priority_fusion(true); - module->mutable_config().set_debug_options(debug_options); EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); EXPECT_TRUE(verifier().Run(module.get()).status().ok()); @@ -910,7 +915,140 @@ ENTRY main { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, GmockMatch(m::Fusion(m::Parameter(), m::Parameter()))); EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kCustom); - EXPECT_TRUE(IsGenericTritonFusion(*root)); + ASSERT_TRUE(IsGenericTritonFusion(*root)); + + EXPECT_TRUE(root->backend_config() + ->fusion_backend_config() + .has_block_level_fusion_config()); + EXPECT_EQ(root->backend_config() + ->fusion_backend_config() + .block_level_fusion_config() + .output_tile_sizes_size(), + 2); +} + +TEST_F(PriorityFusionWithTritonEnabledTest, + FuseTritonProducerWithTwoConsumers) { + const std::string kHloText = R"( +HloModule t +add { + Arg_0 = f32[] parameter(0) + Arg_1 = f32[] parameter(1) + ROOT add = f32[] add(Arg_0, Arg_1) +} + +producer_computation { + parameter_0 = f32[125]{0} parameter(0) + ROOT broadcast = f32[125,127] broadcast(parameter_0), dimensions={0} +} + +consumer_computation.1 { + parameter_0 = f32[125,127] parameter(0) + ROOT log = f32[125,127] log(parameter_0) +} + +consumer_computation.2 { + parameter_0 = f32[125,127] parameter(0) + ROOT exp = f32[125,127] exponential(parameter_0) +} + +ENTRY main { + param_0 = f32[125]{0} parameter(0) + producer_fusion = f32[125,127] fusion(param_0), kind=kCustom, calls=producer_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1","127"],"num_warps":"1"}}} + consumer_fusion.1 = f32[125,127] fusion(producer_fusion), kind=kLoop, calls=consumer_computation.1 + consumer_fusion.2 = f32[125,127] fusion(producer_fusion), kind=kLoop, calls=consumer_computation.2 + ROOT tuple = (f32[125,127], f32[125,127]) tuple(consumer_fusion.1, consumer_fusion.2) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + EXPECT_TRUE(priority_fusion_.Run(module.get()).value()); + EXPECT_TRUE(verifier().Run(module.get()).status().ok()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + HloInstruction *fusion1, *fusion2; + EXPECT_THAT(root, GmockMatch(m::Tuple(m::Fusion(&fusion1, m::Parameter()), + m::Fusion(&fusion2, m::Parameter())))); + EXPECT_TRUE(IsGenericTritonFusion(*fusion1)); + TF_ASSERT_OK_AND_ASSIGN(auto backend_config1, + fusion1->backend_config()); + EXPECT_TRUE( + backend_config1.fusion_backend_config().has_block_level_fusion_config()); + EXPECT_EQ(backend_config1.fusion_backend_config() + .block_level_fusion_config() + .output_tile_sizes_size(), + 2); + + EXPECT_TRUE(IsGenericTritonFusion(*fusion2)); + TF_ASSERT_OK_AND_ASSIGN(auto backend_config2, + fusion2->backend_config()); + EXPECT_TRUE( + backend_config2.fusion_backend_config().has_block_level_fusion_config()); + EXPECT_EQ(backend_config2.fusion_backend_config() + .block_level_fusion_config() + .output_tile_sizes_size(), + 2); +} + +TEST_F(PriorityFusionWithTritonEnabledTest, + TritonProducerNotSupported_DoNotFuse) { + const std::string kHloText = R"( +HloModule t + +producer_computation { + parameter_0 = c64[] parameter(0) + broadcast = c64[125,127] broadcast(parameter_0), dimensions={} + ROOT real = f32[125,127] real(broadcast) +} + +triton_computation { + parameter_0 = f32[125,127] parameter(0) + parameter_1 = f32[125,127] parameter(1) + ROOT add = f32[125,127] add(parameter_0, parameter_1) +} + +ENTRY main { + param_0 = c64[] parameter(0) + param_1 = f32[125,127] parameter(1) + producer_fusion = f32[125,127] fusion(param_0), kind=kLoop, calls=producer_computation + ROOT triton_fusion = f32[125,127] fusion(producer_fusion, param_1), kind=kCustom, calls=triton_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1","127"],"num_warps":"1"}}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + // Triton does not support c64, so producer_fusion and triton_fusion and will + // not be fused. + EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); +} + +TEST_F(PriorityFusionWithTritonEnabledTest, + TritonConsumerNotSupported_DoNotFuse) { + const std::string kHloText = R"( +HloModule t + +triton_computation { + parameter_0 = f32[] parameter(0) + ROOT boardcast = f32[125,127] broadcast(parameter_0), dimensions={} +} + +consumer_computation { + parameter_0 = c64[] parameter(0) + parameter_1 = f32[125,127] parameter(1) + broadcast = c64[125,127] broadcast(parameter_0), dimensions={} + real = f32[125,127] real(broadcast) + ROOT add = f32[125,127] add(real, parameter_1) +} + +ENTRY main { + param_0 = f32[] parameter(1) + param_1 = c64[] parameter(0) + triton_fusion = f32[125,127] fusion(param_0), kind=kCustom, calls=triton_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["1","127"],"num_warps":"1"}}} + ROOT consumer_fusion = f32[125,127] fusion(param_1, triton_fusion), kind=kLoop, calls=consumer_computation +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText)); + + // Triton does not support c64, so triton_fusion and consumer_fusion will not + // be fused. + EXPECT_FALSE(priority_fusion_.Run(module.get()).value()); } TEST_F(PriorityFusionTest, DoNotFuseInsideReducer) { diff --git a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc rename to third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc index 7f1f800d2bc3de..d33c849168151e 100644 --- a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_reduce_scatter_creator.h" +#include "xla/service/gpu/transforms/reduce_scatter_creator.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h similarity index 87% rename from third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h rename to third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h index fcecb460747cc7..4e74394052595a 100644 --- a/third_party/xla/xla/service/gpu/gpu_reduce_scatter_creator.h +++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_ -#define XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -40,4 +40,4 @@ class ReduceScatterCreator : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_REDUCE_SCATTER_CREATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCE_SCATTER_CREATOR_H_ diff --git a/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc index b1d2734d9b0e49..39a2c72a10a213 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_reduce_scatter_creator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduce_scatter_creator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_reduce_scatter_creator.h" +#include "xla/service/gpu/transforms/reduce_scatter_creator.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc similarity index 98% rename from third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc index ac5419cf28d872..8c2929c0787f54 100644 --- a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_degenerate_dim_remover.h" +#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h similarity index 88% rename from third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h rename to third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h index 03d6819081d5da..1630aecff00e76 100644 --- a/third_party/xla/xla/service/gpu/reduction_degenerate_dim_remover.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ -#define XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -53,4 +53,4 @@ class ReductionDegenerateDimRemover : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REDUCTION_DEGENERATE_DIM_REMOVER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DEGENERATE_DIM_REMOVER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc index bb6eb634db78a5..7a9b7fa3fdbe0e 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_degenerate_dim_remover_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_degenerate_dim_remover.h" +#include "xla/service/gpu/transforms/reduction_degenerate_dim_remover.h" #include diff --git a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc similarity index 98% rename from third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc index 8ab4fcf648a255..ca4fba4fac9403 100644 --- a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_dimension_grouper.h" +#include "xla/service/gpu/transforms/reduction_dimension_grouper.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h similarity index 80% rename from third_party/xla/xla/service/gpu/reduction_dimension_grouper.h rename to third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h index 8ee4efd0cfd261..d179dcd6c78415 100644 --- a/third_party/xla/xla/service/gpu/reduction_dimension_grouper.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ -#define XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -32,12 +32,12 @@ namespace gpu { // // For example, // -// f[] out = reduce(f[10,20,30] input, dimensions={0,1,2}) +// out = f32[] reduce(f32[10,20,30] input, dimensions={0,1,2}) // // becomes: // -// f[600] tmp = f[600] bitcast(f[10,20,30] input) -// f[] out = reduce(f[600] tmp, dimensions={0}) +// tmp = f32[6000] bitcast(f32[10,20,30] input) +// out = f32[] reduce(f32[6000] tmp, dimensions={0}) // class ReductionDimensionGrouper : public HloModulePass { public: @@ -53,4 +53,4 @@ class ReductionDimensionGrouper : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REDUCTION_DIMENSION_GROUPER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_DIMENSION_GROUPER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc index fa149a13b940c0..afbbbec01d3c27 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_dimension_grouper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_dimension_grouper_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_dimension_grouper.h" +#include "xla/service/gpu/transforms/reduction_dimension_grouper.h" #include diff --git a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc similarity index 99% rename from third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc index a91fdf7e387b7a..fd45f8b34ec55b 100644 --- a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_layout_normalizer.h" +#include "xla/service/gpu/transforms/reduction_layout_normalizer.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h similarity index 89% rename from third_party/xla/xla/service/gpu/reduction_layout_normalizer.h rename to third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h index 7d2d207773e057..f6e2d7c200dd67 100644 --- a/third_party/xla/xla/service/gpu/reduction_layout_normalizer.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ -#define XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -51,4 +51,4 @@ class ReductionLayoutNormalizer : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REDUCTION_LAYOUT_NORMALIZER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_LAYOUT_NORMALIZER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc index 817d9c73c95b16..46f5e9320eadfc 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_layout_normalizer_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_layout_normalizer_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_layout_normalizer.h" +#include "xla/service/gpu/transforms/reduction_layout_normalizer.h" #include diff --git a/third_party/xla/xla/service/gpu/reduction_splitter.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc similarity index 98% rename from third_party/xla/xla/service/gpu/reduction_splitter.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc index cd37319a47de30..dce9288888a8a5 100644 --- a/third_party/xla/xla/service/gpu/reduction_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_splitter.h" +#include "xla/service/gpu/transforms/reduction_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/reduction_splitter.h b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h similarity index 92% rename from third_party/xla/xla/service/gpu/reduction_splitter.h rename to third_party/xla/xla/service/gpu/transforms/reduction_splitter.h index 7e7652500e6d3a..87520d3d7063b1 100644 --- a/third_party/xla/xla/service/gpu/reduction_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ -#define XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -56,4 +56,4 @@ class ReductionSplitter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_REDUCTION_SPLITTER_H_ diff --git a/third_party/xla/xla/service/gpu/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/reduction_splitter_test.cc rename to third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc index 13a5210fee2ee6..4b9f6fb130ed0f 100644 --- a/third_party/xla/xla/service/gpu/reduction_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/reduction_splitter.h" +#include "xla/service/gpu/transforms/reduction_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/rename_fusions.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc similarity index 98% rename from third_party/xla/xla/service/gpu/rename_fusions.cc rename to third_party/xla/xla/service/gpu/transforms/rename_fusions.cc index a2a3048a05655e..9ab62f68664ebd 100644 --- a/third_party/xla/xla/service/gpu/rename_fusions.cc +++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/rename_fusions.h" +#include "xla/service/gpu/transforms/rename_fusions.h" #include #include diff --git a/third_party/xla/xla/service/gpu/rename_fusions.h b/third_party/xla/xla/service/gpu/transforms/rename_fusions.h similarity index 90% rename from third_party/xla/xla/service/gpu/rename_fusions.h rename to third_party/xla/xla/service/gpu/transforms/rename_fusions.h index c3065a4dbd1df5..5abcd6169cc9d1 100644 --- a/third_party/xla/xla/service/gpu/rename_fusions.h +++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_RENAME_FUSIONS_H_ -#define XLA_SERVICE_GPU_RENAME_FUSIONS_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -44,4 +44,4 @@ class RenameFusions : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_RENAME_FUSIONS_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_RENAME_FUSIONS_H_ diff --git a/third_party/xla/xla/service/gpu/rename_fusions_test.cc b/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/rename_fusions_test.cc rename to third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc index 60c97cf2ff9438..47470859f84d2e 100644 --- a/third_party/xla/xla/service/gpu/rename_fusions_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/rename_fusions_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/rename_fusions.h" +#include "xla/service/gpu/transforms/rename_fusions.h" #include diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc similarity index 96% rename from third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc rename to third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc index 771e8cbed8a9a0..3841f4a1551f77 100644 --- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.cc +++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sanitize_constant_names.h" +#include "xla/service/gpu/transforms/sanitize_constant_names.h" #include @@ -29,7 +29,7 @@ namespace xla { namespace gpu { -absl::StatusOr GpuSanitizeConstantNames::Run( +absl::StatusOr SanitizeConstantNames::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { bool changed = false; diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h similarity index 84% rename from third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h rename to third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h index 08701a4fe3432d..f743137f764ffb 100644 --- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names.h +++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ -#define XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -28,7 +28,7 @@ namespace gpu { // Sanitizes HLO instruction names for the GPU backend. Currently, it only // replaces . and - in the HLO constant instruction names with _ to please the // LLVM PTX backend. -class GpuSanitizeConstantNames : public HloModulePass { +class SanitizeConstantNames : public HloModulePass { public: absl::string_view name() const override { return "sanitize-constant-names"; } @@ -41,4 +41,4 @@ class GpuSanitizeConstantNames : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_SANITIZE_CONSTANT_NAMES_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SANITIZE_CONSTANT_NAMES_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc similarity index 91% rename from third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc rename to third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc index 17f45dc100f684..8e9779003af6f5 100644 --- a/third_party/xla/xla/service/gpu/gpu_sanitize_constant_names_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/sanitize_constant_names_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sanitize_constant_names.h" +#include "xla/service/gpu/transforms/sanitize_constant_names.h" #include #include @@ -44,7 +44,7 @@ TEST_F(SanitizeConstantNamesTest, InstructionNameWithHyphenSanitized) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value()); HloInstruction *root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->name(), "equal_to"); } @@ -59,7 +59,7 @@ TEST_F(SanitizeConstantNamesTest, InstructionNameWithDotSanitized) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value()); HloInstruction *root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->name(), "equal_to"); } @@ -74,7 +74,7 @@ TEST_F(SanitizeConstantNamesTest, NewInstructionNameRegisteredWithModule) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value()); HloInstruction *root = module->entry_computation()->root_instruction(); EXPECT_EQ(root->name(), "equal_to"); @@ -99,7 +99,7 @@ TEST_F(SanitizeConstantNamesTest, BufferSanitizedNameCollisionResolved) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - EXPECT_TRUE(GpuSanitizeConstantNames().Run(module.get()).value()); + EXPECT_TRUE(SanitizeConstantNames().Run(module.get()).value()); EXPECT_THAT(FindInstruction(module.get(), "equal_to_1"), GmockMatch(m::Constant())); EXPECT_THAT(FindInstruction(module.get(), "equal_to_2"), diff --git a/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc b/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc similarity index 95% rename from third_party/xla/xla/service/gpu/gpu_scatter_expander.cc rename to third_party/xla/xla/service/gpu/transforms/scatter_expander.cc index b03b340cb8bbd9..26eb2107087a0d 100644 --- a/third_party/xla/xla/service/gpu/gpu_scatter_expander.cc +++ b/third_party/xla/xla/service/gpu/transforms/scatter_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_scatter_expander.h" +#include "xla/service/gpu/transforms/scatter_expander.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" diff --git a/third_party/xla/xla/service/gpu/gpu_scatter_expander.h b/third_party/xla/xla/service/gpu/transforms/scatter_expander.h similarity index 83% rename from third_party/xla/xla/service/gpu/gpu_scatter_expander.h rename to third_party/xla/xla/service/gpu/transforms/scatter_expander.h index 100350cb67ac01..f86b93235b2b5b 100644 --- a/third_party/xla/xla/service/gpu/gpu_scatter_expander.h +++ b/third_party/xla/xla/service/gpu/transforms/scatter_expander.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_ -#define XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_ #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -26,7 +26,7 @@ namespace xla { class GpuScatterExpander : public ScatterExpander { public: // Although we pass kEliminateAllScatters, we override this behavior in - // InstruuctionMatchesPattern and select only some scatters to expand. + // InstructionMatchesPattern and select only some scatters to expand. GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {} absl::string_view name() const override { return "gpu_scatter_expander"; } @@ -37,4 +37,4 @@ class GpuScatterExpander : public ScatterExpander { } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_SCATTER_EXPANDER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCATTER_EXPANDER_H_ diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc similarity index 99% rename from third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc rename to third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc index 9672bf259a328c..d9c1debacc5e27 100644 --- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/scatter_slice_simplifier.h" +#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" #include #include diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h similarity index 92% rename from third_party/xla/xla/service/gpu/scatter_slice_simplifier.h rename to third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h index 349837747466b6..96f39b5fbed1a6 100644 --- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier.h +++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_ -#define XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -55,4 +55,4 @@ class ScatterSliceSimplifier : public HloModulePass { } // namespace xla -#endif // XLA_SERVICE_GPU_SCATTER_SLICE_SIMPLIFIER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCATTER_SLICE_SIMPLIFIER_H_ diff --git a/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc rename to third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc index 281a4f0576e0c7..8f1c93c1ec31d0 100644 --- a/third_party/xla/xla/service/gpu/scatter_slice_simplifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/scatter_slice_simplifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/scatter_slice_simplifier.h" +#include "xla/service/gpu/transforms/scatter_slice_simplifier.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc similarity index 98% rename from third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc rename to third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc index a0af798118669d..9929b355345b4d 100644 --- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.cc +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_schedule_postprocessing.h" +#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include @@ -132,7 +132,7 @@ absl::StatusOr ProcessComputation( } // anonymous namespace -absl::StatusOr GpuSchedulePostprocessing::Run( +absl::StatusOr SchedulePostprocessing::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { if (!module->has_schedule()) return false; diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h similarity index 83% rename from third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h rename to third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h index d8eda81f257803..899098dfcce68f 100644 --- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing.h +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ -#define XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -34,11 +34,9 @@ namespace gpu { // attribute value untouch for the operations with is_sync=true and for P2P // operations, assumming the runtime won't use those values. // -class GpuSchedulePostprocessing : public HloModulePass { +class SchedulePostprocessing : public HloModulePass { public: - absl::string_view name() const override { - return "gpu-schedule-postprocessing"; - } + absl::string_view name() const override { return "schedule-postprocessing"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -49,4 +47,4 @@ class GpuSchedulePostprocessing : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_SCHEDULE_POSTPROCESSING_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULE_POSTPROCESSING_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc similarity index 91% rename from third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc rename to third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc index 9d4956bdd5b4db..0c9c6e675e1fa7 100644 --- a/third_party/xla/xla/service/gpu/gpu_schedule_postprocessing_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/schedule_postprocessing_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_schedule_postprocessing.h" +#include "xla/service/gpu/transforms/schedule_postprocessing.h" #include @@ -32,9 +32,9 @@ namespace xla { namespace gpu { namespace { -using GpuSchedulePostprocessingTest = HloTestBase; +using SchedulePostprocessingTest = HloTestBase; -TEST_F(GpuSchedulePostprocessingTest, SynchronousOpsNotChanged) { +TEST_F(SchedulePostprocessingTest, SynchronousOpsNotChanged) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -47,12 +47,12 @@ TEST_F(GpuSchedulePostprocessingTest, SynchronousOpsNotChanged) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_FALSE(changed); } -TEST_F(GpuSchedulePostprocessingTest, P2POpsNotChanged) { +TEST_F(SchedulePostprocessingTest, P2POpsNotChanged) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -71,12 +71,12 @@ TEST_F(GpuSchedulePostprocessingTest, P2POpsNotChanged) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_FALSE(changed); } -TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { +TEST_F(SchedulePostprocessingTest, AsynchronousOpsChanged) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -89,7 +89,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_TRUE(changed); @@ -101,7 +101,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsChanged) { EXPECT_TRUE(collective_backend_config.no_parallel_custom_call()); } -TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { +TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -115,7 +115,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_FALSE(changed); @@ -127,7 +127,7 @@ TEST_F(GpuSchedulePostprocessingTest, AsynchronousOpsWithParallelCustomcall) { EXPECT_FALSE(collective_backend_config.no_parallel_custom_call()); } -TEST_F(GpuSchedulePostprocessingTest, +TEST_F(SchedulePostprocessingTest, AsynchronousOpsWithParallelNestedCustomcall) { constexpr absl::string_view kHloString = R"( HloModule module, is_scheduled=true @@ -146,7 +146,7 @@ TEST_F(GpuSchedulePostprocessingTest, )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule((kHloString))); - GpuSchedulePostprocessing pass; + SchedulePostprocessing pass; TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); EXPECT_FALSE(changed); diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc similarity index 83% rename from third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc rename to third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc index fbf1b2c5c58eb2..d7962130a2eeb8 100644 --- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/scheduling_instruction_annotator.h" +#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include @@ -23,6 +23,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "tsl/platform/statusor.h" namespace xla::gpu { @@ -37,7 +38,13 @@ absl::StatusOr AnnotateSchedulingInstructionNames( if (!inst->metadata().scheduling_name().empty()) { continue; } - inst->set_metadata_scheduling_name(std::string(inst->name())); + // We skip constants as we might have to sanitize them in order to satisfy + // LLVM backend. I.e. we allow `GpuSanitizeConstantNames` pass to run post + // scheduling. + if (inst->opcode() == HloOpcode::kConstant) { + continue; + } + inst->set_metadata_scheduling_name(inst->name()); changed = true; } return changed; diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h similarity index 87% rename from third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h rename to third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h index 3f9b769d3b85f0..03c21bbf09b784 100644 --- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator.h +++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ -#define XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -41,4 +41,4 @@ class SchedulingInstructionAnnotator : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SCHEDULING_INSTRUCTION_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc similarity index 73% rename from third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc rename to third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc index 146607f790da52..abe8d50a63c09b 100644 --- a/third_party/xla/xla/service/gpu/scheduling_instruction_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/scheduling_instruction_annotator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/scheduling_instruction_annotator.h" +#include "xla/service/gpu/transforms/scheduling_instruction_annotator.h" #include @@ -72,6 +72,40 @@ TEST_F(SchedulingInstructionAnnotatorTest, EXPECT_TRUE(filecheck_matches); } +TEST_F(SchedulingInstructionAnnotatorTest, SkipsAnnotatingConstants) { + constexpr absl::string_view kHloString = R"( + HloModule module, is_scheduled=true + + ENTRY entry { + p0 = f32[1] parameter(0) + c1 = f32[1] constant(42) + ROOT add0 = f32[1] add(p0, c1) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + SchedulingInstructionAnnotator pass; + TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get())); + + ASSERT_TRUE(changed); + constexpr absl::string_view kExpected = R"( +// CHECK: %[[P0:.+]] = {{.*}} parameter(0) +// CHECK-SAME: scheduling_name="[[P0]]" +// CHECK-NEXT: %[[C1:.+]] = f32[1] +// CHECK-NOT: scheduling_name +// CHECK-SAME: constant({42}) +// CHECK: %[[ADD0:.+]] = {{.*}} add(%[[P0]], %[[C1]]) +// CHECK-SAME: scheduling_name="[[ADD0]]" + )"; + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck( + module->ToString(HloPrintOptions().set_print_operand_shape(false)), + kExpected)); + EXPECT_TRUE(filecheck_matches); +} + TEST_F(SchedulingInstructionAnnotatorTest, DoesNotAnnotateAllInstructionsWithTheirRespectiveNames) { constexpr absl::string_view kHloString = R"( diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc similarity index 99% rename from third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc rename to third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc index c6bd79636b924c..fe43b285c834dd 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/softmax_rewriter_triton.h" +#include "xla/service/gpu/transforms/softmax_rewriter_triton.h" #include #include @@ -47,6 +47,7 @@ limitations under the License. #include "xla/service/gpu/model/gpu_indexing_performance_model.h" #include "xla/service/gpu/model/symbolic_tile_analysis.h" #include "xla/service/gpu/model/tiled_hlo_computation.h" +#include "xla/service/gpu/model/triton_emitter_constraints.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -457,7 +458,8 @@ absl::StatusOr CanSymbolicTileAnalysisTileDiamondChain( mlir::MLIRContext context; SymbolicTileAnalysisOrError symbolic_tile_analysis_or_error = SymbolicTileAnalysis::AnalyzeComputation( - *softmax_fusion->called_computation(), &context); + *softmax_fusion->called_computation(), &context, + TritonEmitterConstraints::GetBuilder()); bool can_tile = std::holds_alternative( symbolic_tile_analysis_or_error); diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h similarity index 94% rename from third_party/xla/xla/service/gpu/softmax_rewriter_triton.h rename to third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h index 9da8cc54daf400..36f780f43cd1e8 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton.h +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ -#define XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_ #include #include @@ -98,4 +98,4 @@ class SoftmaxRewriterTriton : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_SOFTMAX_REWRITER_TRITON_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SOFTMAX_REWRITER_TRITON_H_ diff --git a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc rename to third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc index 8488031e19afdc..1b3139c9d40132 100644 --- a/third_party/xla/xla/service/gpu/softmax_rewriter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/softmax_rewriter_triton.h" +#include "xla/service/gpu/transforms/softmax_rewriter_triton.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc similarity index 96% rename from third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc index 217387c2548f60..b299db8d19316a 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sort_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" #include #include @@ -203,8 +203,7 @@ bool IsCubCompatibleSort(HloSortInstruction* sort_op) { VLOG(2) << "Sort dimension should be the minor one"; return false; } - if (Product(operand_shape.dimensions()) < - GpuSortRewriter::SortSizeThreshold()) { + if (Product(operand_shape.dimensions()) < SortRewriter::SortSizeThreshold()) { VLOG(2) << "Tensor shape size is too small to see an improvement"; return false; } @@ -239,7 +238,7 @@ HloInstruction* UnpackResultPair(HloSortInstruction* sort_op, } // namespace // Rewrites a single sort instruction with a custom call. -absl::StatusOr GpuSortRewriter::RunOnInstruction( +absl::StatusOr SortRewriter::RunOnInstruction( HloSortInstruction* sort_op) { // Get the sort tensor index and direction. SortComputationAnalysis sort_config = AnalyzeSortOp(*sort_op).value(); @@ -307,7 +306,7 @@ absl::StatusOr GpuSortRewriter::RunOnInstruction( } // Rewrites the sorts in the given computation into calls to CUB. -absl::StatusOr GpuSortRewriter::RunOnComputation( +absl::StatusOr SortRewriter::RunOnComputation( HloComputation* computation) { std::vector sort_ops; for (auto* inst : computation->instructions()) { @@ -325,17 +324,17 @@ absl::StatusOr GpuSortRewriter::RunOnComputation( } // Replace compatible sort operations with custom calls. -absl::StatusOr GpuSortRewriter::Run( +absl::StatusOr SortRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), before:\n" + module->ToString()); + XLA_VLOG_LINES(2, "SortRewriter::Run(), before:\n" + module->ToString()); bool changed = false; for (HloComputation* computation : module->MakeNonfusionComputations(execution_threads)) { TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); changed |= result; } - XLA_VLOG_LINES(2, "GpuSortRewriter::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(2, "SortRewriter::Run(), after:\n" + module->ToString()); return changed; } diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h similarity index 88% rename from third_party/xla/xla/service/gpu/gpu_sort_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/sort_rewriter.h index 51dba3c95d9efa..406df7a0472a27 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_ -#define XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -31,9 +31,9 @@ namespace gpu { // Only a subset of shapes is supported - either a single tensor with a simple // compare function or a pair of tensors where keys are unsigned integers. -class GpuSortRewriter : public HloModulePass { +class SortRewriter : public HloModulePass { public: - absl::string_view name() const override { return "gpu-sort-rewriter"; } + absl::string_view name() const override { return "sort-rewriter"; } // CUB radix sort is slower than XLA sort on small shapes, so do not rewrite // tensors with sizes below this limit. @@ -60,4 +60,4 @@ class GpuSortRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_GPU_SORT_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_SORT_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc similarity index 85% rename from third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc rename to third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc index abacbc1111bfdb..e9bf60cdb4c9b7 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_stub.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_stub.cc @@ -13,30 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sort_rewriter.h" - #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" #include "tsl/platform/statusor.h" namespace xla { namespace gpu { -absl::StatusOr GpuSortRewriter::RunOnInstruction( +absl::StatusOr SortRewriter::RunOnInstruction( HloSortInstruction* sort_op) { return false; } -absl::StatusOr GpuSortRewriter::RunOnComputation( +absl::StatusOr SortRewriter::RunOnComputation( HloComputation* computation) { return false; } -absl::StatusOr GpuSortRewriter::Run( +absl::StatusOr SortRewriter::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { return false; diff --git a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc similarity index 91% rename from third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc index 69cdb92e39ed77..853de5b50ba6c6 100644 --- a/third_party/xla/xla/service/gpu/gpu_sort_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/sort_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_sort_rewriter.h" +#include "xla/service/gpu/transforms/sort_rewriter.h" #include @@ -35,18 +35,18 @@ namespace { namespace m = ::xla::match; -class GpuSortRewriterTest : public HloTestBase { +class SortRewriterTest : public HloTestBase { public: void SetUp() override { HloTestBase::SetUp(); - GpuSortRewriter::SetSortSizeThresholdForTestingOnly(1000); + SortRewriter::SetSortSizeThresholdForTestingOnly(1000); } bool RunModuleAndPass(HloModule* module) { auto cloned = module->Clone(); - bool changed = GpuSortRewriter().Run(module).value(); + bool changed = SortRewriter().Run(module).value(); if (changed) { - // Here we run an end to end test to make sure that GpuSortRewriter does + // Here we run an end to end test to make sure that SortRewriter does // not introduce an incorrect rewrite. To do this, we need to clone the // original module because the interpreter cannot process the already // optimized module. @@ -62,7 +62,7 @@ class GpuSortRewriterTest : public HloTestBase { }; // Basic sort: ascending. -TEST_F(GpuSortRewriterTest, SortKeysLessThan) { +TEST_F(SortRewriterTest, SortKeysLessThan) { constexpr char kHlo[] = R"( HloModule TestModule @@ -88,7 +88,7 @@ ENTRY %main { } // Basic sort: descending. -TEST_F(GpuSortRewriterTest, SortKeysGreaterThan) { +TEST_F(SortRewriterTest, SortKeysGreaterThan) { constexpr char kHlo[] = R"( HloModule TestModule @@ -114,7 +114,7 @@ ENTRY %main { } // Comparer swaps the parameter order -> direction is reversed. -TEST_F(GpuSortRewriterTest, SortKeysGreaterThanSwapped) { +TEST_F(SortRewriterTest, SortKeysGreaterThanSwapped) { constexpr char kHlo[] = R"( HloModule TestModule @@ -140,7 +140,7 @@ ENTRY %main { } // Sort a pair of tensors, keys go first. -TEST_F(GpuSortRewriterTest, SortPairs) { +TEST_F(SortRewriterTest, SortPairs) { constexpr char kHlo[] = R"( HloModule TestModule @@ -167,7 +167,7 @@ ENTRY %main { } // Sort a pair of tensors, keys go last. -TEST_F(GpuSortRewriterTest, SortPairsSwapped) { +TEST_F(SortRewriterTest, SortPairsSwapped) { constexpr char kHlo[] = R"( HloModule TestModule @@ -194,7 +194,7 @@ ENTRY %main { } // CUB sort doesn't support more than two tensors. -TEST_F(GpuSortRewriterTest, NoRewriteManyTensors) { +TEST_F(SortRewriterTest, NoRewriteManyTensors) { constexpr char kHlo[] = R"( HloModule TestModule @@ -221,7 +221,7 @@ ENTRY %main { } // Only 1D shapes are supported. -TEST_F(GpuSortRewriterTest, NoRewriteNonMinorSortDimension) { +TEST_F(SortRewriterTest, NoRewriteNonMinorSortDimension) { constexpr char kHlo[] = R"( HloModule TestModule @@ -241,7 +241,7 @@ ENTRY %main { } // Kernels are compiled for a subset of types. -TEST_F(GpuSortRewriterTest, NoRewriteUnsupportedType) { +TEST_F(SortRewriterTest, NoRewriteUnsupportedType) { constexpr char kHlo[] = R"( HloModule TestModule @@ -261,7 +261,7 @@ ENTRY %main { } // Comparer must be a simple function. -TEST_F(GpuSortRewriterTest, NoRewriteComplexComparer) { +TEST_F(SortRewriterTest, NoRewriteComplexComparer) { constexpr char kHlo[] = R"( HloModule TestModule @@ -282,7 +282,7 @@ ENTRY %main { } // Comparer must use adjacent input values. -TEST_F(GpuSortRewriterTest, NoRewriteMixedKeysValues) { +TEST_F(SortRewriterTest, NoRewriteMixedKeysValues) { constexpr char kHlo[] = R"( HloModule TestModule @@ -306,7 +306,7 @@ ENTRY %main { } // Small shapes do not see improvement from CUB sort. -TEST_F(GpuSortRewriterTest, NoRewriteSmallSize) { +TEST_F(SortRewriterTest, NoRewriteSmallSize) { constexpr char kHlo[] = R"( HloModule TestModule @@ -326,7 +326,7 @@ ENTRY %main { } // Basic sort: with batch dimension. -TEST_F(GpuSortRewriterTest, SortWithBatchDim) { +TEST_F(SortRewriterTest, SortWithBatchDim) { constexpr char kHlo[] = R"( HloModule TestModule @@ -352,7 +352,7 @@ ENTRY %main { } // Basic sort: with multiple batch dimensions. -TEST_F(GpuSortRewriterTest, SortWithMultipleBatchDims) { +TEST_F(SortRewriterTest, SortWithMultipleBatchDims) { constexpr char kHlo[] = R"( HloModule TestModule @@ -379,7 +379,7 @@ ENTRY %main { // Sort a pair of tensors (values, indices generated by iota) with a complex // compare. -TEST_F(GpuSortRewriterTest, SortPairsIotaComparerSimple) { +TEST_F(SortRewriterTest, SortPairsIotaComparerSimple) { constexpr char kHlo[] = R"( HloModule TestModule @@ -412,7 +412,7 @@ ENTRY %main { // Sort a pair of tensors (values, indices generated by iota) with a complex // compare computation that matches the output of the StableSortExpander pass. -TEST_F(GpuSortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) { +TEST_F(SortRewriterTest, SortPairsIotaComparerLikeStableSortExpander) { constexpr char kHlo[] = R"( HloModule TestModule @@ -444,8 +444,8 @@ ENTRY %main { m::GetTupleElement(m::CustomCall(), 1)))); } -TEST_F(GpuSortRewriterTest, SortSizeThresholdIsSet) { - EXPECT_EQ(GpuSortRewriter::SortSizeThreshold(), 1000); +TEST_F(SortRewriterTest, SortSizeThresholdIsSet) { + EXPECT_EQ(SortRewriter::SortSizeThreshold(), 1000); } } // namespace diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc similarity index 95% rename from third_party/xla/xla/service/gpu/stream_attribute_annotator.cc rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc index 7e54ea5aded6a0..68805b1ddc3c0c 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/stream_attribute_annotator.h" +#include "xla/service/gpu/transforms/stream_attribute_annotator.h" #include #include @@ -120,6 +120,12 @@ absl::StatusOr WrapIntoFusionAndAnnotateStreamAttributes( fusion_instruction->fused_instructions_computation(), absl::StrCat("wrapped_", wrapped_opcode, "_computation")); if (module->has_schedule()) { + // Update the scheduling names of the fusion and its root instruction + // to match their newly assigned instruction names during creation. + fusion_instruction->set_metadata_scheduling_name( + fusion_instruction->name()); + HloInstruction* root = fusion_instruction->fused_expression_root(); + root->set_metadata_scheduling_name(root->name()); module->schedule().replace_instruction(computation, instruction, fusion_instruction); } diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h similarity index 91% rename from third_party/xla/xla/service/gpu/stream_attribute_annotator.h rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h index 8a0284adee390e..81816f88dabba2 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator.h +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ -#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -57,4 +57,4 @@ class StreamAttributeAnnotator : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ANNOTATOR_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ANNOTATOR_H_ diff --git a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc similarity index 78% rename from third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc index 17d9b2f1e212d7..c7d2ca59cff0e9 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/stream_attribute_annotator.h" +#include "xla/service/gpu/transforms/stream_attribute_annotator.h" #include #include @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -211,21 +212,24 @@ TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) { TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) { constexpr absl::string_view kHloString = R"( - HloModule ModuleWithAsyncDynamicUpdateSlice + HloModule ModuleWithAsyncDynamicUpdateSlice, is_scheduled=true ENTRY entry (param_0: f32[256,128,128], param_1: f32[1,128,128]) -> f32[256,128,128] { - param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0) - param_1 = f32[1,128,128]{2,1,0} parameter(1) - izero = s32[] constant(0) + param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0), metadata={scheduling_name="param_0"} + param_1 = f32[1,128,128]{2,1,0} parameter(1), metadata={scheduling_name="param_1"} + izero = s32[] constant(0), metadata={scheduling_name="izero"} dynamic-update-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, f32[1,128,128]{2,1,0}, s32[], s32[], s32[]), f32[256,128,128]{2,1,0:S(5)}, u32[]) - dynamic-update-slice-start(param_0, param_1, izero, izero, izero) + dynamic-update-slice-start(param_0, param_1, izero, izero, izero), + metadata={scheduling_name="dynamic-update-slice-start.2"} ROOT dynamic-update-slice-done.2 = f32[256,128,128]{2,1,0:S(5)} - dynamic-update-slice-done(dynamic-update-slice-start.2) + dynamic-update-slice-done(dynamic-update-slice-start.2), + metadata={scheduling_name="dynamic-update-slice-done.2"} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); + EXPECT_TRUE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN(bool changed, StreamAttributeAnnotator().Run(module.get())); @@ -245,25 +249,51 @@ TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) { TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, fusion->backend_config()); EXPECT_EQ(gpu_config.operation_queue_id(), 1); + // Check if the schedule name the same as the instruction name + for (const auto* comp : module->computations()) { + for (const auto* instruction : comp->instructions()) { + if (!instruction->metadata().scheduling_name().empty()) { + EXPECT_EQ(instruction->name(), + instruction->metadata().scheduling_name()); + } + } + } + constexpr absl::string_view kExpectedSchedulingName = R"( +// CHECK: %wrapped_dynamic-update-slice_computation +// CHECK: ROOT %[[DYNAMIC_UPDATE_SLICE:.+]] = f32[256,128,128]{2,1,0:S(5)} dynamic-update-slice( +// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_UPDATE_SLICE]]"} +// CHECK: %[[DYNAMIC_UPDATE_SLICE_START:.+]] = {{.*}} fusion-start( +// CHECK-SAME: calls=%wrapped_dynamic-update-slice_computation +// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_UPDATE_SLICE_START]]"} + )"; + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck( + module->ToString(HloPrintOptions().set_print_operand_shape(false)), + kExpectedSchedulingName)); + EXPECT_TRUE(filecheck_matches); } TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) { constexpr absl::string_view kHloString = R"( - HloModule ModuleWithAsyncDynamicSlice + HloModule ModuleWithAsyncDynamicSlice, is_scheduled=true ENTRY entry (param_0: f32[256,128,128]) -> f32[1,128,128] { - param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0) - izero = s32[] constant(0) + param_0 = f32[256,128,128]{2,1,0:S(5)} parameter(0), metadata={scheduling_name="param_0"} + izero = s32[] constant(0), metadata={scheduling_name="izero"} dynamic-slice-start.2 = ((f32[256,128,128]{2,1,0:S(5)}, s32[], s32[], s32[]), f32[1,128,128]{2,1,0}, u32[]) - dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128} + dynamic-slice-start(param_0, izero, izero, izero), dynamic_slice_sizes={1,128,128}, + metadata={scheduling_name="dynamic-slice-start.2"} ROOT dynamic-slice-done.2 = f32[1,128,128]{2,1,0} - dynamic-slice-done(dynamic-slice-start.2) + dynamic-slice-done(dynamic-slice-start.2), + metadata={scheduling_name="dynamic-slice-done.2"} } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); + EXPECT_TRUE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN(bool changed, StreamAttributeAnnotator().Run(module.get())); EXPECT_TRUE(changed); @@ -282,6 +312,29 @@ TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) { TF_ASSERT_OK_AND_ASSIGN(GpuBackendConfig gpu_config, fusion->backend_config()); EXPECT_EQ(gpu_config.operation_queue_id(), 1); + // Check if the schedule name the same as the instruction name + for (const auto* comp : module->computations()) { + for (const auto* instruction : comp->instructions()) { + if (!instruction->metadata().scheduling_name().empty()) { + EXPECT_EQ(instruction->name(), + instruction->metadata().scheduling_name()); + } + } + } + constexpr absl::string_view kExpectedSchedulingName = R"( +// CHECK: %wrapped_dynamic-slice_computation +// CHECK: ROOT %[[DYNAMIC_SLICE:.+]] = f32[1,128,128]{2,1,0} dynamic-slice( +// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_SLICE]]"} +// CHECK: %[[DYNAMIC_SLICE_START:.+]] = {{.*}} fusion-start( +// CHECK-SAME: calls=%wrapped_dynamic-slice_computation +// CHECK-SAME: metadata={scheduling_name="[[DYNAMIC_SLICE_START]]"} + )"; + TF_ASSERT_OK_AND_ASSIGN( + bool filecheck_matches, + RunFileCheck( + module->ToString(HloPrintOptions().set_print_operand_shape(false)), + kExpectedSchedulingName)); + EXPECT_TRUE(filecheck_matches); } } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc similarity index 97% rename from third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc index 822c6473dba483..be0eb6fc7ac5e0 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/stream_attribute_async_wrapper.h" +#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/string_view.h" diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h similarity index 88% rename from third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h index 95fe7bba66508e..157b57913b6b71 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper.h +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ -#define XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -46,4 +46,4 @@ class StreamAttributeAsyncWrapper : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_STREAM_ATTRIBUTE_ASYNC_WRAPPER_H_ diff --git a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc similarity index 97% rename from third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc rename to third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc index 8b3dcb23eac7bc..32ed4c50c57ca1 100644 --- a/third_party/xla/xla/service/gpu/stream_attribute_async_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_async_wrapper_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/stream_attribute_async_wrapper.h" +#include "xla/service/gpu/transforms/stream_attribute_async_wrapper.h" #include diff --git a/third_party/xla/xla/service/gpu/topk_specializer.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc similarity index 98% rename from third_party/xla/xla/service/gpu/topk_specializer.cc rename to third_party/xla/xla/service/gpu/transforms/topk_specializer.cc index bd01a076cc1711..1cc6206ee8908a 100644 --- a/third_party/xla/xla/service/gpu/topk_specializer.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/topk_specializer.h" +#include "xla/service/gpu/transforms/topk_specializer.h" #include diff --git a/third_party/xla/xla/service/gpu/topk_specializer.h b/third_party/xla/xla/service/gpu/transforms/topk_specializer.h similarity index 88% rename from third_party/xla/xla/service/gpu/topk_specializer.h rename to third_party/xla/xla/service/gpu/transforms/topk_specializer.h index 5b57f57b77bba7..e3ec5658f497cd 100644 --- a/third_party/xla/xla/service/gpu/topk_specializer.h +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_ -#define XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -38,4 +38,4 @@ class TopkSpecializer : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_TOPK_SPECIALIZER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPECIALIZER_H_ diff --git a/third_party/xla/xla/service/gpu/topk_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/topk_test.cc rename to third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc index 43e25b8543cc61..96d7e49bade1c4 100644 --- a/third_party/xla/xla/service/gpu/topk_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_specializer_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "xla/service/gpu/transforms/topk_specializer.h" + #include #include @@ -33,7 +35,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/topk_specializer.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/platform_util.h" #include "xla/service/topk_rewriter.h" diff --git a/third_party/xla/xla/service/gpu/topk_splitter.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/topk_splitter.cc rename to third_party/xla/xla/service/gpu/transforms/topk_splitter.cc index d20116dd22dd7c..41ba13500c4182 100644 --- a/third_party/xla/xla/service/gpu/topk_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/topk_splitter.h" +#include "xla/service/gpu/transforms/topk_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/topk_splitter.h b/third_party/xla/xla/service/gpu/transforms/topk_splitter.h similarity index 91% rename from third_party/xla/xla/service/gpu/topk_splitter.h rename to third_party/xla/xla/service/gpu/transforms/topk_splitter.h index 8fee2dc4975dbd..c6fe4290d7e225 100644 --- a/third_party/xla/xla/service/gpu/topk_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TOPK_SPLITTER_H_ -#define XLA_SERVICE_GPU_TOPK_SPLITTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_ #include @@ -49,4 +49,4 @@ class TopKSplitter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_TOPK_SPLITTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TOPK_SPLITTER_H_ diff --git a/third_party/xla/xla/service/gpu/topk_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc similarity index 99% rename from third_party/xla/xla/service/gpu/topk_splitter_test.cc rename to third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc index 834185f990956c..8236c26d4056ae 100644 --- a/third_party/xla/xla/service/gpu/topk_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/topk_splitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/topk_splitter.h" +#include "xla/service/gpu/transforms/topk_splitter.h" #include diff --git a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc new file mode 100644 index 00000000000000..d81d3be88b4273 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.cc @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/transpose_dimension_grouper.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/permutation_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace gpu { + +class TransposeDimensionGroupVisitor : public DfsHloRewriteVisitor { + public: + absl::Status HandleTranspose(HloInstruction *transpose) override { + VLOG(4) << "Input: " << transpose->ToString(); + absl::InlinedVector permutation; + auto normalized_dims = ShapeUtil::GetNormalizedLogicalTransposeShape( + transpose->shape(), transpose->dimensions(), permutation); + if (!normalized_dims.has_value() || + normalized_dims == transpose->shape().dimensions()) { + return absl::OkStatus(); + } + auto normalized_operand_dims = + ComposePermutations(*normalized_dims, InversePermutation(permutation)); + Shape grouped_operand_shape = ShapeUtil::MakeShapeWithDescendingLayout( + transpose->shape().element_type(), normalized_operand_dims); + auto new_operand = transpose->AddInstruction(HloInstruction::CreateBitcast( + grouped_operand_shape, transpose->mutable_operand(0))); + Shape grouped_shape = ShapeUtil::MakeShapeWithDescendingLayout( + transpose->shape().element_type(), *normalized_dims); + auto new_transpose = + transpose->AddInstruction(HloInstruction::CreateTranspose( + grouped_shape, new_operand, permutation)); + VLOG(5) << "Generated new transpose: " << new_transpose->ToString(); + return ReplaceWithNewInstruction( + transpose, + HloInstruction::CreateBitcast(transpose->shape(), new_transpose)); + } +}; + +absl::StatusOr TransposeDimensionGrouper::Run( + HloModule *module, + const absl::flat_hash_set &execution_threads) { + TF_ASSIGN_OR_RETURN( + bool changed, + TransposeDimensionGroupVisitor().RunOnModule(module, execution_threads)); + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.h b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.h new file mode 100644 index 00000000000000..c07ada3c39d7a7 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper.h @@ -0,0 +1,57 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRANSPOSE_DIMENSION_GROUPER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TRANSPOSE_DIMENSION_GROUPER_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Groups dimensions that are adjacent (logically and physically) in the +// transpose operand and the transpose output. +// +// Precondition: LayoutNormalization has been run (physical proximity and +// logical proximity become the same). +// +// For example, +// +// out = f32[30,10,20] transpose(f32[10,20,30] input, dimensions={2,0,1}) +// +// becomes: +// +// tmp = f32[200,30] bitcast(f32[10,20,30] input) +// transpose = f32[30,200] transpose(f32[200,30] tmp, dimensions={1,0}) +// out = f32[30,0,20] bitcast(f32[30,200] transpose) +// +class TransposeDimensionGrouper : public HloModulePass { + public: + absl::string_view name() const override { + return "transpose-dimension-grouper"; + } + using HloPassInterface::Run; + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // XLA_SERVICE_GPU_TRANSFORMS_TRANSPOSE_DIMENSION_GROUPER_H_ diff --git a/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper_test.cc b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper_test.cc new file mode 100644 index 00000000000000..bbcf3dbe68dcf8 --- /dev/null +++ b/third_party/xla/xla/service/gpu/transforms/transpose_dimension_grouper_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/gpu/transforms/transpose_dimension_grouper.h" + +#include + +#include "absl/strings/string_view.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/test.h" + +namespace xla { + +namespace { + +class TransposeDimensionGrouperTest : public HloTestBase { + public: + void CheckDimensionGrouper(absl::string_view hlo, + std::optional expected) { + RunAndFilecheckHloRewrite(hlo, gpu::TransposeDimensionGrouper{}, expected); + } +}; + +TEST_F(TransposeDimensionGrouperTest, TransposeWithGrouping) { + const char* hlo = R"( +HloModule TransposeWithGrouping + +ENTRY main { + input = f32[100,1,10,32,2]{4,3,2,1,0} parameter(0) + ROOT out = f32[10,1,32,100,2]{4,3,2,1,0} transpose(input), dimensions={2,1,3,0,4} +} +)"; + + CheckDimensionGrouper(hlo, + R"( +// CHECK: [[input_0:%[^ ]+]] = f32[100,1,10,32,2]{4,3,2,1,0} parameter(0) +// CHECK: [[bitcast_1:%[^ ]+]] = f32[100,320,2]{2,1,0} bitcast([[input_0]]) +// CHECK: [[transpose:%[^ ]+]] = f32[320,100,2]{2,1,0} transpose([[bitcast_1]]), dimensions={1,0,2} +// CHECK: ROOT {{.*}} = f32[10,1,32,100,2]{4,3,2,1,0} bitcast([[transpose]]) + )"); +} + +// TODO(b/328656780): Do not normalize to 3D once the emitter supports any +// number of dimensions. +TEST_F(TransposeDimensionGrouperTest, Normalize2DTo3D) { + const char* hlo = R"( +HloModule TransposeWithGrouping + +ENTRY main { + input = f32[50,20,30]{2,1,0} parameter(0) + ROOT out = f32[20,30,50]{2,1,0} transpose(input), dimensions={1,2,0} +} +)"; + + CheckDimensionGrouper(hlo, + R"( +// CHECK: [[input_0:%[^ ]+]] = f32[50,20,30]{2,1,0} parameter(0) +// CHECK: [[bitcast_1:%[^ ]+]] = f32[1,50,600]{2,1,0} bitcast([[input_0]]) +// CHECK: [[transpose:%[^ ]+]] = f32[1,600,50]{2,1,0} transpose([[bitcast_1]]), dimensions={0,2,1} +// CHECK: ROOT {{.*}} = f32[20,30,50]{2,1,0} bitcast([[transpose]]) + )"); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc similarity index 99% rename from third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc index b54d006947f9c1..fb023fc8cc693f 100644 --- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/tree_reduction_rewriter.h" +#include "xla/service/gpu/transforms/tree_reduction_rewriter.h" #include #include @@ -374,7 +374,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { se::GpuComputeCapability gpu_version_; }; -absl::StatusOr GpuTreeReductionRewriter::Run( +absl::StatusOr TreeReductionRewriter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { VLOG(5) << "Rewriter input: " << module->ToString(); diff --git a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h similarity index 86% rename from third_party/xla/xla/service/gpu/tree_reduction_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h index 5f6edf8ac33e4e..7f57d211a8acbd 100644 --- a/third_party/xla/xla/service/gpu/tree_reduction_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h @@ -12,9 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_ -#define XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_ - +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -74,15 +73,13 @@ namespace gpu { // f32[A, Q, C] inner_reduce = reduce(reshaped, dimensions={2}) // f32[A, C] outer_reduce = reduce(inner_reduce, dimensions={1}) // -class GpuTreeReductionRewriter : public HloModulePass { +class TreeReductionRewriter : public HloModulePass { public: - explicit GpuTreeReductionRewriter(se::GpuComputeCapability gpu_version) + explicit TreeReductionRewriter(se::GpuComputeCapability gpu_version) : gpu_version_(gpu_version) {} - ~GpuTreeReductionRewriter() override = default; - absl::string_view name() const override { - return "gpu-tree-reduction-rewriter"; - } + ~TreeReductionRewriter() override = default; + absl::string_view name() const override { return "tree-reduction-rewriter"; } using HloPassInterface::Run; absl::StatusOr Run( @@ -96,4 +93,4 @@ class GpuTreeReductionRewriter : public HloModulePass { } // end namespace gpu } // end namespace xla -#endif // XLA_SERVICE_GPU_TREE_REDUCTION_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TREE_REDUCTION_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc rename to third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc index c8357b479597c4..91f4481a202885 100644 --- a/third_party/xla/xla/service/gpu/tests/tree_reduction_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/tree_reduction_rewriter.h" +#include "xla/service/gpu/transforms/tree_reduction_rewriter.h" #include @@ -34,11 +34,11 @@ class TreeReductionRewriterTest : public HloTestBase { RunAndFilecheckHloRewrite( hlo, #if TENSORFLOW_USE_ROCM - gpu::GpuTreeReductionRewriter{se::RocmComputeCapability { + gpu::TreeReductionRewriter{se::RocmComputeCapability { "908" }}, #else - gpu::GpuTreeReductionRewriter{se::CudaComputeCapability{8, 1}}, + gpu::TreeReductionRewriter{se::CudaComputeCapability{8, 1}}, #endif expected); #elif TENSORFLOW_USE_ROCM diff --git a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc similarity index 97% rename from third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc rename to third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc index 2dcd36569b7073..e81bdae50a25bf 100644 --- a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/triangular_solve_rewriter.h" +#include "xla/service/gpu/transforms/triangular_solve_rewriter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h similarity index 91% rename from third_party/xla/xla/service/gpu/triangular_solve_rewriter.h rename to third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h index 6d4b1c14188a08..c52e0ffb545a3e 100644 --- a/third_party/xla/xla/service/gpu/triangular_solve_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/triangular_solve_rewriter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_ -#define XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -57,4 +57,4 @@ class TriangularSolveRewriter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_TRIANGULAR_SOLVE_REWRITER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TRIANGULAR_SOLVE_REWRITER_H_ diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc similarity index 97% rename from third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc rename to third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc index 75c43feadd605c..10ae640f3659b1 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/triton_fusion_numerics_verifier.h" +#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h" #include #include @@ -30,8 +30,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/executable.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/ir_emission_utils.h" diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h similarity index 89% rename from third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h rename to third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h index 6d74f46c1cfda0..e3dc6ebe5dd9f7 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier.h +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_ -#define XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_ #include "absl/container/flat_hash_set.h" #include "absl/functional/any_invocable.h" @@ -23,8 +23,8 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_pass_interface.h" #include "xla/service/shaped_buffer.h" @@ -71,4 +71,4 @@ absl::Status ForAllTritonFusions( } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_TRITON_FUSION_NUMERICS_VERIFIER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_NUMERICS_VERIFIER_H_ diff --git a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc similarity index 96% rename from third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc rename to third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc index 1d35d1927b2a58..0382577d0d0fb9 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_numerics_verifier_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/triton_fusion_numerics_verifier.h" +#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h" #include #include @@ -26,13 +26,13 @@ limitations under the License. #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/primitive_util.h" -#include "xla/service/gpu/autotuner_compile_util.h" -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_compile_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/platform.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla::gpu { namespace { diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc similarity index 98% rename from third_party/xla/xla/service/gpu/variadic_op_splitter.cc rename to third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc index f1371575b7d625..0712040a7d1029 100644 --- a/third_party/xla/xla/service/gpu/variadic_op_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/variadic_op_splitter.h" +#include "xla/service/gpu/transforms/variadic_op_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter.h b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h similarity index 88% rename from third_party/xla/xla/service/gpu/variadic_op_splitter.h rename to third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h index 4449ce2a0bdcda..304afa1d80a605 100644 --- a/third_party/xla/xla/service/gpu/variadic_op_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ -#define XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_ #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" @@ -40,4 +40,4 @@ class VariadicOpSplitter : public HloModulePass { } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_VARIADIC_OP_SPLITTER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_VARIADIC_OP_SPLITTER_H_ diff --git a/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc similarity index 98% rename from third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc rename to third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc index 6d7b72eebe0ba3..1d726136a3a8ee 100644 --- a/third_party/xla/xla/service/gpu/variadic_op_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/variadic_op_splitter_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/variadic_op_splitter.h" +#include "xla/service/gpu/transforms/variadic_op_splitter.h" #include #include diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc similarity index 81% rename from third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc rename to third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc index 8f5e26124f24a4..04d5905c652467 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.cc @@ -13,9 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_windowed_einsum_handler.h" +#include "xla/service/gpu/transforms/windowed_einsum_handler.h" #include +#include #include #include "absl/container/flat_hash_set.h" @@ -27,6 +28,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/hlo_creation_utils.h" @@ -48,15 +50,15 @@ namespace m = match; // and type conversions of FP8 operands into the bodies of their while loops, // i.e. rewrites // -// inputs --> dequant --> while loop {dynamic-slice/collective-permute/dot} +// inputs --> dequant --> while loop {collective-permute/dot/etc} // // into // -// inputs --> while loop {dequant --> dynamic-slice/collective-permute/dot}. -absl::Status ShiftDequantizationF8(const HloComputation* comp, - const std::array& gte) { - HloInstruction* while_instr = comp->WhileCallInstruction(); - if (!while_instr) { +// inputs --> while loop {dequant --> collective-permute/dot/etc}. +absl::Status ShiftDequantizationF8(HloComputation* while_body) { + HloInstruction* while_instr = while_body->WhileCallInstruction(); + // The input of the while loop will be modified and must have no other users. + if (!while_instr || while_instr->operand(0)->user_count() != 1) { return absl::OkStatus(); } @@ -105,39 +107,42 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, return absl::OkStatus(); } - // Identify the dot and collective-permute or dynamic-slice instructions in - // the all-gather or reduce-scatter patterns in while's body. - HloComputation* while_body = while_instr->while_body(); + // Identify the dot, get-tuple-element and collective-permute or dynamic-slice + // instructions in the all-gather or reduce-scatter patterns in while's body. HloComputation* while_condition = while_instr->while_condition(); HloInstruction* while_root = while_body->root_instruction(); - std::array dots, dyn_slices{nullptr, nullptr}, + std::array dots, gtes, dyn_slices{nullptr, nullptr}, coll_perms{nullptr, nullptr}; - if (Match( - while_root, - m::Tuple(m::CollectivePermute( - &coll_perms[1], m::CollectivePermute( - &coll_perms[0], m::Op().Is(gte[0]))), - m::Op().Is(gte[1]), - m::DynamicUpdateSlice( - m::DynamicUpdateSlice().WithOperand( - 1, m::Dot(&dots[0], m::Op().Is(gte[0]), - m::Op().Is(gte[1]))), - m::Dot(&dots[1], m::Op(), m::Op().Is(gte[1])), m::Op(), - m::Op(), m::Op()), - m::Op(), m::Op()))) { + if (Match(while_root, + m::Tuple(m::CollectivePermute( + &coll_perms[1], + m::CollectivePermute( + &coll_perms[0], + m::GetTupleElement(>es[0], m::Parameter(), 0))), + m::GetTupleElement(>es[1], m::Parameter(), 1), + m::DynamicUpdateSlice( + m::DynamicUpdateSlice().WithOperand( + 1, m::Dot(&dots[0], m::Op(), m::Op())), + m::Dot(&dots[1], m::Op(), m::Op()), m::Op(), m::Op(), + m::Op()), + m::Op(), m::Op())) && + dots[0]->operand(0) == gtes[0] && dots[0]->operand(1) == gtes[1] && + dots[1]->operand(1) == gtes[1]) { VLOG(5) << "Identified all-gather windowed einsum pattern."; } else if (Match( while_root, - m::Tuple(m::Op().Is(gte[0]), m::Op().Is(gte[1]), + m::Tuple(m::GetTupleElement(>es[0], m::Parameter(), 0), + m::GetTupleElement(>es[1], m::Parameter(), 1), m::AddAnyOrder( m::Dot(&dots[0], m::DynamicSlice(&dyn_slices[0]), - m::Op().Is(gte[1])), + m::Op()), m::Op()), m::CollectivePermute(m::AddAnyOrder( m::Dot(&dots[1], m::DynamicSlice(&dyn_slices[1]), - m::Op().Is(gte[1])), + m::Op()), m::Op())), - m::Op()))) { + m::Op())) && + dots[0]->operand(1) == gtes[1] && dots[1]->operand(1) == gtes[1]) { VLOG(5) << "Identified reduce-scatter windowed einsum pattern."; } else { VLOG(5) << "Unable to identify valid windowed einsum pattern."; @@ -165,14 +170,14 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, } // In the while body, replace the existing get-tuple-element instructions - // retrieving BF16/FP16/FP32 dot operands with dequantized get-tuple-element + // retrieving BF16/FP16/FP32 dot operands with get-tuple-element // instructions retrieving FP8 dot operands from the input tuple. HloInstruction* body_param = while_body->parameter_instruction(0); for (int k = 0; k < 2; ++k) { TF_ASSIGN_OR_RETURN(HloInstruction * operand_f8, MakeGetTupleElementHlo(body_param, k)); - if (while_root->operand(k) == gte[k]) { + if (while_root->operand(k) == gtes[k]) { TF_RETURN_IF_ERROR( while_root->ReplaceOperandWithDifferentShape(k, operand_f8)); ShapeUtil::UpdateTupleShape(operand_f8->shape(), k, @@ -191,7 +196,7 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, // Dequantize the operands of the dots and dynamic-slices. HloInstruction* operand_f32 = - MakeConvertToHlo(operand_f8, gte[k]->shape().element_type()); + MakeConvertToHlo(operand_f8, gtes[k]->shape().element_type()); HloInstruction* broadcast_scale = MakeBroadcastHlo(operand_scale, {}, operand_f32->shape()); TF_ASSIGN_OR_RETURN( @@ -203,10 +208,10 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, // operands. The order of dequantization and dynamic-slices will be // exchanged in gemm_rewriter.cc. for (int l = 0; l < 2; ++l) { - if (dots[l]->operand(k) == gte[k]) { + if (dots[l]->operand(k) == gtes[k]) { TF_RETURN_IF_ERROR(dots[l]->ReplaceOperandWith(k, operand_scaled)); } - if (dyn_slices[l] && dyn_slices[l]->operand(0) == gte[k]) { + if (dyn_slices[l] && dyn_slices[l]->operand(0) == gtes[k]) { TF_RETURN_IF_ERROR( dyn_slices[l]->ReplaceOperandWith(0, operand_scaled)); } @@ -216,7 +221,7 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, // dots[1], which prevents it from being exchanged with dequantization in // gemm_rewriter.cc. Instead, directly insert the dequantization before // dots[1] here. - if (coll_perms[0] && coll_perms[0]->operand(0) == gte[k]) { + if (coll_perms[0] && coll_perms[0]->operand(0) == gtes[k]) { std::array coll_perms_f8{nullptr, nullptr}; // Change the type of both collective-permutes to FP8. coll_perms_f8[0] = @@ -228,7 +233,7 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, // Insert the dequantization between coll_perms[0] and dots[1]. HloInstruction* coll_perm0_f32 = - MakeConvertToHlo(coll_perms_f8[0], gte[k]->shape().element_type()); + MakeConvertToHlo(coll_perms_f8[0], gtes[k]->shape().element_type()); TF_ASSIGN_OR_RETURN(HloInstruction * x_scaled, MakeBinaryHlo(binaries[k]->opcode(), coll_perm0_f32, broadcast_scale)); @@ -243,17 +248,19 @@ absl::Status ShiftDequantizationF8(const HloComputation* comp, } // Update the shape of the while call in the parent computation. + HloInstruction* new_while_instr = while_instr->AddInstruction( + while_instr->CloneWithNewShape(while_root->shape())); TF_RETURN_IF_ERROR( - while_instr->ReplaceAllUsesWithDifferentShape(while_instr->AddInstruction( - while_instr->CloneWithNewShape(while_root->shape())))); + while_instr->ReplaceAllUsesWithDifferentShape(new_while_instr)); + while_instr->while_body()->SetWhileCallInstruction(new_while_instr); TF_RETURN_IF_ERROR(while_instr->parent()->RemoveInstruction(while_instr)); if (coll_perms[0]) { TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[1])); TF_RETURN_IF_ERROR(while_body->RemoveInstruction(coll_perms[0])); } - TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[0])); - TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gte[1])); + TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gtes[0])); + TF_RETURN_IF_ERROR(while_body->RemoveInstruction(gtes[1])); VLOG(5) << "FP8 dequantization moved into while loop."; return absl::OkStatus(); @@ -302,22 +309,11 @@ absl::StatusOr HandleRsWindowedEinsumLoop(HloComputation* comp, return changed; } for (auto inst : comp->MakeInstructionPostOrder()) { - HloInstruction* matched_dot; - std::array gte; // The dot we'd like to parallelize is consuming the second loop input // as RHS. - if (Match(inst, - m::Dot(&matched_dot, - m::DynamicSlice().WithOperand( - 0, m::GetTupleElement(>e[0], m::Parameter(), 0)), - m::GetTupleElement(>e[1], m::Parameter(), 1)))) { - // If present, move the dequantization of FP8 operands of the dot into the - // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom - // Call. - TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte)); - + if (Match(inst, m::Dot())) { // Dispatch the dot to additional compute stream. - TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id)); + TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id)); ++stream_id; changed = true; } @@ -332,6 +328,10 @@ absl::StatusOr HandleRsWindowedEinsumLoop(HloComputation* comp, changed = true; } } + // If present, move the dequantization of FP8 operands of the dot into the + // while loop to allow e.g. gemm_rewriter.cc to fuse the dequantization and + // dot into an FP8 GEMM. + TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp)); return changed; } @@ -345,23 +345,15 @@ absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* comp, return changed; } for (auto inst : comp->MakeInstructionPostOrder()) { - HloInstruction* matched_dot; - std::array gte; // The dot we'd like to parallelize is consuming the second loop input // as RHS and first loop input as LHS. - if (Match(inst, m::Dot(&matched_dot, - m::GetTupleElement(>e[0], m::Parameter(), 0), - m::GetTupleElement(>e[1], m::Parameter(), 1)))) { - // If present, move the dequantization of FP8 operands of the dot into the - // while loop to allow gemm_rewriter.cc to rewrite into an FP8 Custom - // Call. - TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp, gte)); - + if (Match(inst, m::Dot(m::GetTupleElement(m::Parameter(), 0), + m::GetTupleElement(m::Parameter(), 1)))) { // Dispatch the dot to additional compute stream. - TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(matched_dot, stream_id)); + TF_RETURN_IF_ERROR(UpdateDotAndConsumerConfig(inst, stream_id)); ++stream_id; TF_RETURN_IF_ERROR( - SetForceDelayForInstruction(matched_dot, /*force_delay=*/true)); + SetForceDelayForInstruction(inst, /*force_delay=*/true)); changed = true; } @@ -375,6 +367,11 @@ absl::StatusOr HandleAgWindowedEinsumLoop(HloComputation* comp, changed = true; } } + // If present, move the dequantization of FP8 operands of the dot into the + // while loop to allow e.g. gemm_rewriter.cc to fuse the dequantization and + // dot into an FP8 GEMM. + TF_RETURN_IF_ERROR(ShiftDequantizationF8(comp)); + return changed; } @@ -382,12 +379,11 @@ static int64_t GetAgActivationCacheIndex(const HloInstruction* while_loop) { const HloInstruction* loop_tuple = while_loop->operand(0); const Shape& tuple_shape = loop_tuple->shape(); CHECK(tuple_shape.IsTuple()); - return tuple_shape.tuple_shapes_size(); + return tuple_shape.tuple_shapes_size() - 1; } absl::Status ProcessWindowedEinsumLoopForActivationCaching( - GpuWindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop, - HloInstruction* ag_with_shared_operand) { + WindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop) { HloInstruction* loop = ag_loop.loop; // Transform the while body to cache the allgathered result in the // output buffer to be consumed by the dot @@ -406,41 +402,10 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( // The full buffer that we will use to cache the accumulated activation // is the last operand in the output tuple. int64_t full_cache_buffer_index = GetAgActivationCacheIndex(loop); - std::vector new_input_shapes(input_shape.tuple_shapes().begin(), - input_shape.tuple_shapes().end()); - new_input_shapes.push_back(ag_with_shared_operand->shape()); - // Update body input shape - Shape new_input_shape = ShapeUtil::MakeTupleShape(new_input_shapes); - *input_tuple->mutable_shape() = new_input_shape; HloInstruction* full_buffer_output_gte = while_body->AddInstruction(HloInstruction::CreateGetTupleElement( - ag_with_shared_operand->shape(), input_tuple, - full_cache_buffer_index)); - - // Update condition input shape - HloComputation* cond_comp = loop->while_condition(); - HloInstruction* cond_input_tuple = cond_comp->parameter_instruction(0); - *cond_input_tuple->mutable_shape() = new_input_shape; - - // Update input to the while instruction in parent computation - HloInstruction* original_while_input = loop->mutable_operand(0); - HloComputation* parent_comp = loop->parent(); - std::vector new_operands( - original_while_input->operands().begin(), - original_while_input->operands().end()); - new_operands.push_back( - parent_comp->AddInstruction(HloInstruction::CreateBroadcast( - ag_with_shared_operand->shape(), - parent_comp->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::Zero(new_input_shapes[0].element_type()))), - {}))); - HloInstruction* new_while_input = - parent_comp->AddInstruction(HloInstruction::CreateTuple(new_operands)); - TF_RETURN_IF_ERROR( - loop->ReplaceOperandWithDifferentShape(0, new_while_input)); - TF_RETURN_IF_ERROR(parent_comp->ReplaceInstructionWithDifferentShape( - original_while_input, new_while_input)); - *loop->mutable_shape() = new_input_shape; + ShapeUtil::GetTupleElementShape(input_shape, full_cache_buffer_index), + input_tuple, full_cache_buffer_index)); HloInstruction* new_full_buffer_output = nullptr; // Find the DUS in the loop body and re-use the slice indices @@ -550,6 +515,7 @@ absl::Status ProcessWindowedEinsumLoopForActivationCaching( HloInstruction::CreateTuple(original_operands)); TF_RETURN_IF_ERROR( while_body->ReplaceInstructionWithDifferentShape(root, new_output_tuple)); + return absl::OkStatus(); } @@ -579,8 +545,7 @@ struct MatchedGemmA2aResult { class WindowedEinsumVisitor : public DfsHloRewriteVisitor { public: explicit WindowedEinsumVisitor( - std::vector& - all_ag_loops) + std::vector& all_ag_loops) : all_ag_loops_(all_ag_loops) {} absl::StatusOr MatchA2aGemmWithIntermediateReshapes( HloInstruction* dot, HloInstruction** lhs, HloInstruction** rhs) { @@ -673,65 +638,145 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { absl::Status HandleDot(HloInstruction* dot) override { CHECK_EQ(dot->opcode(), HloOpcode::kDot); HloComputation* comp = dot->parent(); - // Rewrites a allgather-dot pattern that shares the same operand - // with a windowed einsum loop to consume the output of the loop - // and remove the all-gather. - // Now that we have processed all loops, we can check if there are any - // allgather-dot pattern that we can optimize. We'd want to transform: + // Rewrites an allgather-dot pattern that shares the same operand with a + // windowed einsum loop to consume the output of the loop and remove the + // all-gather. Now that we have processed all loops, we can check if there + // are any allgather-dot pattern that we can optimize. We'd want to + // transform: // input // / | - // / | - // AG windowed loop - // / - // / - // dot + // dequantize | + // (optional) | + // / | + // AG windowed loop + // / + // / + // dot // to: - // input + // input // | // | - // windowed loop + // windowed loop // | + // dequantize + // (FP8) // | // dot // The windowed einsum loop will also be rewritten to output the full input // to be consumed by the dot. This is advantageous since the chained dot can // fully utilize all the resources on the GPU while comm is hidden by the - // first collective matmul loop. - for (GpuWindowedEinsumHandler::WindowedEinsumAgLoops ag_loop : + // first collective matmul loop. When the data type is FP8, input is + // dequantized, i.e. type converted and scaled, ahead of the all-gather. The + // dequantization is moved in WindowedEinsumVisitor between the windowed + // loop and the dot. + for (WindowedEinsumHandler::WindowedEinsumAgLoops& ag_loop : all_ag_loops_) { + HloComputation* comp = dot->parent(); HloInstruction* loop = ag_loop.loop; - HloInstruction* ag_operand = nullptr; - - if (Match(dot, m::Dot(m::AllGather(&ag_operand), m::Op())) || - Match(dot, m::Dot(m::Op(), m::AllGather(&ag_operand)))) { - HloInstruction* windowed_lhs = - loop->mutable_operand(0)->mutable_operand(0); - HloInstruction* ag_with_shared_operand = nullptr; - if (ag_operand && ag_operand->mutable_operand(0) == windowed_lhs) { - ag_with_shared_operand = ag_operand; + + HloInstruction* windowed_lhs = + loop->mutable_operand(0)->mutable_operand(0); + + // In the FP8 case, the all-gather operates on the dequantized + // windowed_lhs. The dequantization is shifted to the output of the while + // loop below. + HloInstruction *all_gather, *binary, *scale = nullptr; + auto all_gather_optionally_dequantized = m::AnyOf( + m::AllGather(&all_gather, + m::Divide(&binary, m::Convert(m::Op().Is(windowed_lhs)), + m::Broadcast(m::Op(&scale)))), + m::AllGather( + &all_gather, + m::MultiplyAnyOrder(&binary, m::Convert(m::Op().Is(windowed_lhs)), + m::Broadcast(m::Op(&scale)))), + m::AllGather(&all_gather, m::Op().Is(windowed_lhs))); + + if (!Match(dot, m::Dot(all_gather_optionally_dequantized, m::Op())) && + !Match(dot, m::Dot(m::Op(), all_gather_optionally_dequantized))) { + continue; + } + + if (scale) { + // When the loop contains an FP8 GEMM, a scalar scaling factor must be + // captured. + if (!ShapeUtil::IsScalar(scale->shape())) { + continue; } - if (!ag_with_shared_operand) { + // The element type of windowed_lhs must be a supported FP8 type. + if (windowed_lhs->shape().element_type() != F8E4M3FN && + windowed_lhs->shape().element_type() != F8E5M2) { continue; } + // The scaling multiplication or division must be in BF16, FP16 or FP32. + if (binary->shape().element_type() != BF16 && + binary->shape().element_type() != F16 && + binary->shape().element_type() != F32) { + continue; + } + } + + if (!ag_loop.consumed) { + // Add a broadcasted zero of the same type as windowed_lhs. This caches + // the accumulated activation inside the loop. + Literal zero_literal = + LiteralUtil::Zero(windowed_lhs->shape().element_type()); + HloInstruction* zero = comp->AddInstruction( + HloInstruction::CreateConstant(std::move(zero_literal))); + Shape zero_bcast_shape = ShapeUtil::ChangeElementType( + all_gather->shape(), windowed_lhs->shape().element_type()); + HloInstruction* zero_bcast = + MakeBroadcastHlo(zero, {}, zero_bcast_shape); + loop->mutable_operand(0)->AppendOperand(zero_bcast); + ShapeUtil::AppendShapeToTuple( + zero_bcast->shape(), loop->mutable_operand(0)->mutable_shape()); + + // Update the parameter tuples of while's body and condition + // computations. + for (HloComputation* while_comp : + {loop->while_body(), loop->while_condition()}) { + while_comp->ReplaceParameter( + 0, HloInstruction::CreateParameter( + 0, loop->mutable_operand(0)->shape(), + while_comp->parameter_instruction(0)->name())); + } + + // Update the shape of the while loop in the parent computation. + *loop->mutable_shape() = loop->operand(0)->shape(); + VLOG(5) << "Found all-gather that shares the same operand with a " "windowed einsum loop : " << loop->ToString(); - if (!ag_loop.consumed) { - TF_RETURN_IF_ERROR(ProcessWindowedEinsumLoopForActivationCaching( - ag_loop, ag_with_shared_operand)); - ag_loop.consumed = true; - } - int64_t cache_output_index = dot->operand_index(ag_with_shared_operand); - HloComputation* comp = dot->parent(); - HloInstruction* new_gte = - comp->AddInstruction(HloInstruction::CreateGetTupleElement( - loop, GetAgActivationCacheIndex(loop) - 1)); TF_RETURN_IF_ERROR( - dot->ReplaceOperandWith(cache_output_index, new_gte)); - TF_RETURN_IF_ERROR(comp->RemoveInstruction(ag_with_shared_operand)); + ProcessWindowedEinsumLoopForActivationCaching(ag_loop)); + ag_loop.consumed = true; + } + + int64_t cache_output_index = dot->operand_index(all_gather); + HloInstruction* new_gte = + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + loop, GetAgActivationCacheIndex(loop))); + + HloInstruction* new_gte_scaled; + + if (scale) { + // In the FP8 case, insert the dequantization of windowed_lhs between + // the while loop and the dot. + HloInstruction* new_convert = + MakeConvertToHlo(new_gte, binary->shape().element_type()); + HloInstruction* bcast_scale = + MakeBroadcastHlo(scale, {}, new_convert->shape()); + TF_ASSIGN_OR_RETURN( + new_gte_scaled, + MakeBinaryHlo(binary->opcode(), new_convert, bcast_scale)); + } + + TF_RETURN_IF_ERROR(dot->ReplaceOperandWith( + cache_output_index, scale ? new_gte_scaled : new_gte)); + if (all_gather->user_count() == 0) { + TF_RETURN_IF_ERROR(comp->RemoveInstruction(all_gather)); } } // Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms @@ -1106,16 +1151,16 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor { } private: - std::vector& all_ag_loops_; + std::vector& all_ag_loops_; }; } // namespace -absl::StatusOr GpuWindowedEinsumHandler::Run( +absl::StatusOr WindowedEinsumHandler::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { XLA_VLOG_LINES( - 5, "GpuWindowedEinsumHandler::Run(), before:\n" + module->ToString()); + 5, "WindowedEinsumHandler::Run(), before:\n" + module->ToString()); bool changed = false; int64_t stream_id = hlo_query::NextChannelId(*module); @@ -1128,13 +1173,12 @@ absl::StatusOr GpuWindowedEinsumHandler::Run( changed = comp_result; } else if (comp->name().find(kWindowedEinsumAgLoopName) == 0) { VLOG(5) << "Processing computation: " << comp->name(); - TF_ASSIGN_OR_RETURN(bool comp_result, - HandleAgWindowedEinsumLoop(comp, stream_id)); + TF_ASSIGN_OR_RETURN(changed, HandleAgWindowedEinsumLoop(comp, stream_id)); all_ag_loops_.push_back( WindowedEinsumAgLoops(comp->WhileCallInstruction())); - changed = comp_result; } } + for (HloComputation* comp : module->MakeNonfusionComputations(execution_threads)) { WindowedEinsumVisitor visitor(all_ag_loops_); @@ -1142,8 +1186,8 @@ absl::StatusOr GpuWindowedEinsumHandler::Run( changed |= visitor.changed(); } - XLA_VLOG_LINES( - 5, "GpuWindowedEinsumHandler::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(5, + "WindowedEinsumHandler::Run(), after:\n" + module->ToString()); return changed; } diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h similarity index 86% rename from third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h rename to third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h index b511920f7f24b0..bcc7680e1b7ef5 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler.h +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ -#define XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ +#ifndef XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_ +#define XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_ #include @@ -35,11 +35,9 @@ namespace xla::gpu { // optimize it on GPU by annotating independent gemms with // stream ids in the backend config. By running them in different // streams, we can practically achieve overlap between gemms too. -class GpuWindowedEinsumHandler : public HloModulePass { +class WindowedEinsumHandler : public HloModulePass { public: - absl::string_view name() const override { - return "gpu-windowed-einsum-handler"; - } + absl::string_view name() const override { return "windowed-einsum-handler"; } struct WindowedEinsumAgLoops { explicit WindowedEinsumAgLoops(HloInstruction* loop) : loop(loop) {} @@ -63,4 +61,4 @@ class GpuWindowedEinsumHandler : public HloModulePass { } // namespace xla::gpu -#endif // XLA_SERVICE_GPU_GPU_WINDOWED_EINSUM_HANDLER_H_ +#endif // XLA_SERVICE_GPU_TRANSFORMS_WINDOWED_EINSUM_HANDLER_H_ diff --git a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc similarity index 85% rename from third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc rename to third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc index 6f23319980e90c..151b5b41b5b866 100644 --- a/third_party/xla/xla/service/gpu/gpu_windowed_einsum_handler_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/windowed_einsum_handler_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/service/gpu/gpu_windowed_einsum_handler.h" +#include "xla/service/gpu/transforms/windowed_einsum_handler.h" #include #include @@ -34,7 +34,7 @@ namespace { namespace m = ::xla::match; -using GpuWindowedEinsumHanlderTest = HloTestBase; +using WindowedEinsumHandlerTest = HloTestBase; HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) { for (auto inst : comp->instructions()) { @@ -45,7 +45,7 @@ HloInstruction* FindInstructionByName(HloComputation* comp, std::string name) { return nullptr; } -TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, AgLoopsHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,512,24576]{2,1,0}, bf16[24576,24576]{1,0})->bf16[2048,24576]{1,0}}, num_partitions=4 @@ -102,7 +102,7 @@ ENTRY test_main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -121,7 +121,7 @@ ENTRY test_main { cp1->backend_config()->force_earliest_schedule()); } -TEST_F(GpuWindowedEinsumHanlderTest, RsLoopsHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, RsLoopsHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[24576,24576]{1,0}, bf16[512,24576]{1,0}, bf16[2048,24576]{1,0})->bf16[512,24576]{1,0}}, num_partitions=4 @@ -180,7 +180,7 @@ ENTRY main.9_spmd { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -198,7 +198,7 @@ ENTRY main.9_spmd { cp1->backend_config()->force_earliest_schedule()); } -TEST_F(GpuWindowedEinsumHanlderTest, AgLoopsMultipleConsumersAreChained) { +TEST_F(WindowedEinsumHandlerTest, AgLoopsMultipleConsumersAreChained) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[2,512,24576]{2,1,0}, bf16[24576,24576]{1,0}, bf16[24576,24576]{1,0})->bf16[2,2048,24576]{2,1,0}}, num_partitions=4 @@ -259,7 +259,7 @@ ENTRY main.12_spmd { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -286,7 +286,7 @@ ENTRY main.12_spmd { m::Op(), m::Op(), m::Op(), m::Op()), m::Op(), m::Op(), m::Op(), m::Op())))); } -TEST_F(GpuWindowedEinsumHanlderTest, A2aGemmHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, A2aGemmHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,8192]{3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=8 @@ -350,7 +350,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, @@ -358,7 +358,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, EXPECT_TRUE(filecheck_matched); } -TEST_F(GpuWindowedEinsumHanlderTest, GemmA2aHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, GemmA2aHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,4,2048,32768]{3,2,1,0})->bf16[1,4,2048,8192]{3,2,1,0}}, num_partitions=4 @@ -422,7 +422,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched, @@ -430,7 +430,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1, EXPECT_TRUE(filecheck_matched); } -TEST_F(GpuWindowedEinsumHanlderTest, A2aTransposeLoopsHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, A2aTransposeLoopsHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,8192,32768]{2,1,0}, bf16[1,1,8192,4,1,2048]{5,4,3,2,1,0})->bf16[1,4,2048,32768]{3,2,1,0}}, num_partitions=4 @@ -504,7 +504,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -513,7 +513,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2, EXPECT_TRUE(filecheck_matched); } -TEST_F(GpuWindowedEinsumHanlderTest, GemmA2aTransposeLoopsHaveStreamIds) { +TEST_F(WindowedEinsumHandlerTest, GemmA2aTransposeLoopsHaveStreamIds) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[1,4,2048,32768]{3,2,1,0}, bf16[1,32768,8192]{2,1,0})->bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0}}, num_partitions=4 @@ -588,7 +588,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -597,7 +597,7 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204 EXPECT_TRUE(filecheck_matched); } -TEST_F(GpuWindowedEinsumHanlderTest, AllGatherF8) { +TEST_F(WindowedEinsumHandlerTest, AllGatherF8) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4 @@ -660,7 +660,7 @@ ENTRY test_main { } )"; - RunAndFilecheckHloRewrite(kHloString, GpuWindowedEinsumHandler(), + RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), R"( ; CHECK-LABEL: windowed_dot_general_body_ag ; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0) @@ -716,7 +716,7 @@ ENTRY test_main { )"); } -TEST_F(GpuWindowedEinsumHanlderTest, ReduceScatterF8) { +TEST_F(WindowedEinsumHandlerTest, ReduceScatterF8) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f8e4m3fn[2,2048,24576]{2,1,0}, f32[], f32[])->f32[2,512,24576]{2,1,0}}, num_partitions=4 @@ -780,7 +780,7 @@ ENTRY main.9_spmd { } )"; - RunAndFilecheckHloRewrite(kHloString, GpuWindowedEinsumHandler(), + RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), R"( ; CHECK-LABEL: windowed_dot_general_body_rs ; CHECK-NEXT: [[P0:%[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) parameter(0) @@ -814,13 +814,13 @@ ENTRY main.9_spmd { ; CHECK-DAG: lhs_contracting_dims={2}, ; CHECK-DAG: rhs_contracting_dims={0}, ; CHECK-DAG: backend_config={ -; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID:[0-9]+]]", +; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID0:[1-9][0-9]*]]", ; CHECK-DAG: "wait_on_operation_queues":[], ; CHECK-DAG: "force_earliest_schedule":false} ; CHECK-NEXT: [[ADD3:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[CP0]], [[DOT0]]), ; CHECK-DAG: backend_config={" ; CHECK-DAG: operation_queue_id":"0", -; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID]]"], +; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID0]]"], ; CHECK-DAG: "force_earliest_schedule":false} ; CHECK-NEXT: [[GTE6:[^ ]+]] = f32[2,512,24576]{2,1,0} get-tuple-element([[P0]]), index=3 ; CHECK-NEXT: [[ADD4:%[^ ]+]] = u32[] add([[GTE4]], [[PID]]) @@ -830,14 +830,137 @@ ENTRY main.9_spmd { ; CHECK-NEXT: [[DSLICE3:%[^ ]+]] = f32[2,512,24576]{2,1,0} dynamic-slice([[MUL0]], [[C0]], [[RESHAPE1]], [[C0]]), dynamic_slice_sizes={2,512,24576} ; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,512,24576]{2,1,0} dot([[DSLICE3]], [[MUL1]]), ; CHECK-DAG: lhs_contracting_dims={2}, -; CHECK-DAG: rhs_contracting_dims={0} +; CHECK-DAG: rhs_contracting_dims={0}, +; CHECK-DAG: backend_config={ +; CHECK-DAG: "operation_queue_id":"[[OPQUEUEID1:[1-9][0-9]*]]", +; CHECK-DAG: "wait_on_operation_queues":[], +; CHECK-DAG: "force_earliest_schedule":false} ; CHECK-NEXT: [[ADD5:%[^ ]+]] = f32[2,512,24576]{2,1,0} add([[GTE6]], [[DOT1]]) +; CHECK-DAG: backend_config={" +; CHECK-DAG: operation_queue_id":"0", +; CHECK-DAG: "wait_on_operation_queues":["[[OPQUEUEID1]]"], +; CHECK-DAG: "force_earliest_schedule":false} ; CHECK-NEXT: [[CP1:[^ ]+]] = f32[2,512,24576]{2,1,0} collective-permute([[ADD5]]), channel_id=10 ; CHECK-NEXT: ROOT [[OUT:[^ ]+]] = (f8e4m3fn[2,2048,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,512,24576]{2,1,0}, f32[2,512,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[]) tuple([[GTE0]], [[GTE1]], [[ADD3]], [[CP1]], [[ADD0]], /*index=5*/[[GTE3]], [[GTE5]]) )"); } -TEST_F(GpuWindowedEinsumHanlderTest, +TEST_F(WindowedEinsumHandlerTest, AllGatherMultipleConsumersF8) { + constexpr absl::string_view kHloString = R"( +HloModule all_gather_multiple_consumers_f8, entry_computation_layout={(f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f8e4m3fn[24576,24576]{1,0}, f8e4m3fn[24576,24576]{1,0}, f32[], f32[], f32[], f32[])->f32[2,2048,24576]{2,1,0}}, num_partitions=4 +windowed_dot_general_body_ag { + input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) + lhs = f32[2,512,24576]{2,1,0} get-tuple-element(input), index=0 + permuted_lhs0 = f32[2,512,24576]{2,1,0} collective-permute(lhs), channel_id=2, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + permuted_lhs1 = f32[2,512,24576]{2,1,0} collective-permute(permuted_lhs0), channel_id=3, source_target_pairs={{0,3},{1,0},{2,1},{3,2}} + rhs = f32[24576,24576]{1,0} get-tuple-element(input), index=1 + partial_dot_output = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=2 + dot0 = f32[2,512,24576]{2,1,0} dot(lhs, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c0 = s32[] constant(0) + dot_update_slice_offsets = s32[4]{0} constant({0, 512, 1024, 1536}) + loop_counter = u32[] get-tuple-element(input), index=4 + partition_id = u32[] partition-id() + loop_counter_plus_partition_id = u32[] add(loop_counter, partition_id) + c4 = u32[] constant(4) + dot_update_slice_offsets_index0 = u32[] remainder(loop_counter_plus_partition_id, c4) + dot_update_slice_offset0 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index0), dynamic_slice_sizes={1} + dot_update_slice_offset_scalar0 = s32[] reshape(dot_update_slice_offset0) + updated_dot_output0 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(partial_dot_output, dot0, c0, dot_update_slice_offset_scalar0, c0) + dot1 = f32[2,512,24576]{2,1,0} dot(permuted_lhs0, rhs), lhs_contracting_dims={2}, rhs_contracting_dims={0} + c1 = u32[] constant(1) + loop_counter_plus_one = u32[] add(loop_counter, c1) + loop_counter_plus_partition_id_plus_one = u32[] add(loop_counter_plus_one, partition_id) + dot_update_slice_offsets_index1 = u32[] remainder(loop_counter_plus_partition_id_plus_one, c4) + dot_update_slice_offset1 = s32[1]{0} dynamic-slice(dot_update_slice_offsets, dot_update_slice_offsets_index1), dynamic_slice_sizes={1} + dot_update_slice_offset1_scalar = s32[] reshape(dot_update_slice_offset1) + updated_dot_output1 = f32[2,2048,24576]{2,1,0} dynamic-update-slice(updated_dot_output0, dot1, c0, dot_update_slice_offset1_scalar, c0) + pass_through = f32[2,2048,24576]{2,1,0} get-tuple-element(input), index=3 + next_loop_counter = u32[] add(loop_counter_plus_one, c1) + ROOT tuple = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(permuted_lhs1, rhs, updated_dot_output1, pass_through, next_loop_counter) +} // windowed_dot_general_body_ag + +windowed_dot_general_cond_ag { + input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) parameter(0) + loop_counter = u32[] get-tuple-element(input), index=4 + loop_limit = u32[] constant(4) + ROOT compare = pred[] compare(loop_counter, loop_limit), direction=LT +} + +ENTRY main { + lhs = f8e4m3fn[2,512,24576]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + rhs0 = f8e4m3fn[24576,24576]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + c0_f32 = f32[] constant(0) + c0_f32_bcast = f32[2,2048,24576]{2,1,0} broadcast(c0_f32), dimensions={} + c0_u32 = u32[] constant(0) + // Dequantization of LHS and RHS: + scale_lhs = f32[] parameter(4) + scale_lhs_bcast = f32[2,512,24576]{2,1,0} broadcast(scale_lhs), dimensions={} + lhs_f32 = f32[2,512,24576]{2,1,0} convert(lhs) + lhs_scaled = f32[2,512,24576]{2,1,0} multiply(lhs_f32, scale_lhs_bcast) + scale_rhs0 = f32[] parameter(5) + scale_rhs0_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs0), dimensions={} + rhs0_f32 = f32[24576,24576]{1,0} convert(rhs0) + rhs0_scaled = f32[24576,24576]{1,0} multiply(rhs0_f32, scale_rhs0_bcast) + // While loop of all-gather windowed einsum: + while_input = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) tuple(lhs_scaled, rhs0_scaled, c0_f32_bcast, c0_f32_bcast, c0_u32) + while = (f32[2,512,24576]{2,1,0}, f32[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[]) while(while_input), condition=windowed_dot_general_cond_ag, body=windowed_dot_general_body_ag + // Additional all-gather FP8 dot operating on a dequantized RHS and the LHS also consumed by the windowed einsum. + all-gather1 = f32[2,2048,24576]{2,1,0} all-gather(lhs_scaled), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true + rhs1 = f8e4m3fn[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} + scale_rhs1 = f32[] parameter(6) + scale_rhs1_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs1), dimensions={} + rhs1_f32 = f32[24576,24576]{1,0} convert(rhs1) + rhs1_scaled = f32[24576,24576]{1,0} multiply(rhs1_f32, scale_rhs1_bcast) + dot1 = f32[2,2048,24576]{2,1,0} dot(all-gather1, rhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + // Another all-gather FP8 dot operating on a dequantized RHS and the LHS also consumed by the windowed einsum. + all-gather2 = f32[2,2048,24576]{2,1,0} all-gather(lhs_scaled), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={1}, use_global_device_ids=true + rhs2 = f8e4m3fn[24576,24576]{1,0} parameter(3), sharding={devices=[1,4]<=[4]} + scale_rhs2 = f32[] parameter(7) + scale_rhs2_bcast = f32[24576,24576]{1,0} broadcast(scale_rhs2), dimensions={} + rhs2_f32 = f32[24576,24576]{1,0} convert(rhs2) + rhs2_scaled = f32[24576,24576]{1,0} multiply(rhs2_f32, scale_rhs2_bcast) + dot2 = f32[2,2048,24576]{2,1,0} dot(all-gather2, rhs2_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT product = f32[2,2048,24576]{2,1,0} multiply(dot1, dot2) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloString)); + + RunAndFilecheckHloRewrite(kHloString, WindowedEinsumHandler(), + R"( +; CHECK-LABEL: %main +; CHECK: [[WHILE0:%[^ ]+]] = (f8e4m3fn[2,512,24576]{2,1,0}, f8e4m3fn[24576,24576]{1,0}, f32[2,2048,24576]{2,1,0}, f32[2,2048,24576]{2,1,0}, u32[], /*index=5*/f32[], f32[], f8e4m3fn[2,2048,24576]{2,1,0}) while([[TUPLE0:%[^ ]+]]), +; CHECK-DAG: condition=%windowed_dot_general_cond_ag, +; CHECK-DAG: body=%windowed_dot_general_body_ag +; CHECK: [[LHS1:%[^ ]+]] = f8e4m3fn[2,2048,24576]{2,1,0} get-tuple-element([[WHILE0]]), index=7 +; CHECK-NEXT: [[LHS1_F32:%[^ ]+]] = f32[2,2048,24576]{2,1,0} convert([[LHS1]]) +; CHECK-NEXT: [[SCALE_LHS1_BCAST:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[SCALE_LHS1:%[^ ]+]]), dimensions={} +; CHECK-NEXT: [[LHS1_SCALED:%[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[LHS1_F32]], [[SCALE_LHS1_BCAST]]) +; CHECK-NEXT: [[RHS1:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} +; CHECK-NEXT: [[RHS1_F32:%[^ ]+]] = f32[24576,24576]{1,0} convert([[RHS1]]) +; CHECK: [[SCALE_RHS1_BCAST:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[SCALE_RHS1:%[^ ]+]]), dimensions={} +; CHECK-NEXT: [[RHS1_SCALED:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[RHS1_F32]], [[SCALE_RHS1_BCAST]]) +; CHECK-NEXT: [[DOT1:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dot([[LHS1_SCALED]], [[RHS1_SCALED]]), +; CHECK-DAG: lhs_contracting_dims={2}, +; CHECK-DAG: rhs_contracting_dims={0} +; CHECK: [[LHS2:%[^ ]+]] = f8e4m3fn[2,2048,24576]{2,1,0} get-tuple-element([[WHILE0]]), index=7 +; CHECK-NEXT: [[LHS2_F32:%[^ ]+]] = f32[2,2048,24576]{2,1,0} convert([[LHS2]]) +; CHECK-NEXT: [[SCALE_LHS2_BCAST:%[^ ]+]] = f32[2,2048,24576]{2,1,0} broadcast([[SCALE_LHS2:%[^ ]+]]), dimensions={} +; CHECK-NEXT: [[LHS2_SCALED:%[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[LHS2_F32]], [[SCALE_LHS2_BCAST]]) +; CHECK-NEXT: [[RHS2:%[^ ]+]] = f8e4m3fn[24576,24576]{1,0} parameter(3), sharding={devices=[1,4]<=[4]} +; CHECK-NEXT: [[RHS2_F32:%[^ ]+]] = f32[24576,24576]{1,0} convert([[RHS2]]) +; CHECK-NEXT: [[SCALE_RHS2:%[^ ]+]] = f32[] parameter(7) +; CHECK-NEXT: [[SCALE_RHS2_BCAST:%[^ ]+]] = f32[24576,24576]{1,0} broadcast([[SCALE_RHS2]]), dimensions={} +; CHECK-NEXT: [[RHS2_SCALED:%[^ ]+]] = f32[24576,24576]{1,0} multiply([[RHS2_F32]], [[SCALE_RHS2_BCAST]]) +; CHECK-NEXT: [[DOT2:%[^ ]+]] = f32[2,2048,24576]{2,1,0} dot([[LHS2_SCALED]], [[RHS2_SCALED]]), +; CHECK-DAG: lhs_contracting_dims={2}, +; CHECK-DAG: rhs_contracting_dims={0} +; CHECK-NEXT: ROOT [[OUT:[^ ]+]] = f32[2,2048,24576]{2,1,0} multiply([[DOT1]], [[DOT2]]) +)"); +} + +TEST_F(WindowedEinsumHandlerTest, AgLoopsMultipleConsumersAreChainedWithShardedContratingDim) { constexpr absl::string_view kHloString = R"( HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(bf16[16,2048,512]{2,1,0}, bf16[4096,6288]{1,0}, bf16[16,2048,6288]{2,1,0})->bf16[4096,6288]{1,0}}, num_partitions=8 @@ -900,7 +1023,7 @@ ENTRY main.12_spmd { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - GpuWindowedEinsumHandler gpu_handler; + WindowedEinsumHandler gpu_handler; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get())); EXPECT_TRUE(changed); @@ -914,5 +1037,6 @@ ENTRY main.12_spmd { EXPECT_EQ(inst->operand(0)->tuple_index(), 5); EXPECT_EQ(inst->operand(0)->operand(0), ag_loop); } + } // namespace } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc index c911bf1462e353..c5cc0adc2879ff 100644 --- a/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/triton_fusion_analysis_test.cc @@ -23,7 +23,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/gemm_fusion.h" +#include "xla/service/gpu/transforms/gemm_fusion.h" #include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" diff --git a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc index 6a9a66539854dc..42ab73a15a0412 100644 --- a/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc +++ b/third_party/xla/xla/service/gpu/triton_tiling_propagation.cc @@ -42,6 +42,7 @@ limitations under the License. #include "xla/layout.h" #include "xla/permutation_util.h" #include "xla/service/gpu/fusions/triton/triton_support.h" +#include "xla/service/gpu/fusions/triton/triton_support_legacy.h" #include "xla/service/instruction_fusion.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -1011,15 +1012,10 @@ GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible( if (hlo.opcode() == HloOpcode::kPad) { return "Pads are not fused yet."; } - for (const HloInstruction* operand : hlo.operands()) { - if (!legacy_triton::IsTritonSupportedDataType( - operand->shape().element_type(), gpu_version)) { - return "Unsupported input data type."; - } - } - if (!legacy_triton::IsTritonSupportedDataType(hlo.shape().element_type(), - gpu_version)) { - return "Unsupported output data type."; + if (auto decision = + legacy_triton::IsTritonSupportedInstruction(hlo, gpu_version); + !decision.CanFuse()) { + return decision; } DimOrdersAndReqsOrError result_or_error = GetPropagatedDimOrdersAndRequirements(hlo, src_dim_order, diff --git a/third_party/xla/xla/service/gpu_compilation_environment_test.cc b/third_party/xla/xla/service/gpu_compilation_environment_test.cc index 072f66b147e287..e684f7c68e4995 100644 --- a/third_party/xla/xla/service/gpu_compilation_environment_test.cc +++ b/third_party/xla/xla/service/gpu_compilation_environment_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include #include "xla/parse_flags_from_env.h" #include "xla/service/compilation_environments.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" diff --git a/third_party/xla/xla/service/heap_simulator/BUILD b/third_party/xla/xla/service/heap_simulator/BUILD index 9af9bf0a8dcb00..bada0fdf1d597b 100644 --- a/third_party/xla/xla/service/heap_simulator/BUILD +++ b/third_party/xla/xla/service/heap_simulator/BUILD @@ -42,26 +42,27 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_live_range", "//xla/service:buffer_value", - "//xla/service:buffer_value_containers", "//xla/service:hlo_alias_analysis", "//xla/service:hlo_buffer", "//xla/service:hlo_dataflow_analysis", - "//xla/service:hlo_ordering", "//xla/service:hlo_proto_cc", "//xla/service:hlo_value", + "//xla/service:logical_buffer", "//xla/service:time_utils", - "//xla/service:tuple_points_to_analysis", - "//xla/service/memory_space_assignment:repacking", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) @@ -71,21 +72,25 @@ xla_cc_test( deps = [ ":allocation_block", ":heap_simulator", - "//xla:literal", - "//xla:status_macros", + "//xla:comparison_util", + "//xla:literal_util", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/service:buffer_value", - "//xla/service:hlo_ordering", + "//xla/service:hlo_alias_analysis", + "//xla/service:hlo_dataflow_analysis", + "//xla/service:hlo_module_config", "//xla/service:hlo_parser", "//xla/service:hlo_value", - "//xla/service:tuple_points_to_analysis", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc index a45528e6c1ac2f..a499fcf119c424 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.cc @@ -22,6 +22,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -34,23 +35,133 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/map_util.h" +#include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/hlo_dataflow_analysis.h" #include "xla/service/hlo_value.h" -#include "xla/service/memory_space_assignment/repacking.h" +#include "xla/service/logical_buffer.h" #include "xla/service/time_utils.h" #include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" namespace xla { +namespace { + +constexpr int64_t kMaxMemoryMapDimensionSize = 100; + +struct AsciiMemoryMapParameters { + int64_t memory_block_size = 1; + int64_t end_of_last_occupied_chunk = -1; +}; + +// Given a set of BufferIntervalTreeNodes, returns the best memory block size(to +// visually represent all chunks in a compact fashion) and the maximum chunk end +// of all occupied chunks. The best memory block size is the greatest common +// divisor of all chunk offsets and chunk ends. These are parameters required to +// construct a compact memory map. +AsciiMemoryMapParameters GetAsciiMemoryMapParameters( + std::vector& nodes) { + CHECK(!nodes.empty()); + int64_t min_chunk_offset = std::numeric_limits::max(); + int64_t end_of_last_occupied_chunk = -1; + int64_t memory_block_size = nodes.front()->chunk.offset; + for (const BufferIntervalTreeNode* node : nodes) { + min_chunk_offset = std::min(min_chunk_offset, node->chunk.offset); + end_of_last_occupied_chunk = + std::max(end_of_last_occupied_chunk, node->chunk.chunk_end()); + memory_block_size = std::gcd(memory_block_size, node->chunk.offset); + memory_block_size = std::gcd(memory_block_size, node->chunk.chunk_end()); + } + VLOG(3) << " min_chunk_offset: " << min_chunk_offset + << " end_of_last_occupied_chunk: " << end_of_last_occupied_chunk + << " memory_block_size: " << memory_block_size; + return {memory_block_size, end_of_last_occupied_chunk}; +} + +// Returns a memory map for the given time interval [start, end]. +// The memory map is a 2D array of size [n, m], where n is the number of memory +// blocks and m is the number of time steps. Each row represents a memory block +// and each column represents a time step. The value at (i, j) indicates whether +// there is a buffer occupying the entire memory block at time j. +std::vector> GetMemoryMap( + int64_t start, int64_t end, int64_t memory_block_size, + int64_t num_memory_blocks, + std::vector& nodes) { + int64_t total_time = end - start + 1; + std::vector> memory_map( + num_memory_blocks, std::vector(total_time, false)); + for (const BufferIntervalTreeNode* node : nodes) { + for (int64_t i = node->chunk.offset / memory_block_size; + i < node->chunk.chunk_end() / memory_block_size; ++i) { + for (int64_t j = std::max(node->start - start, int64_t{0}); + j <= std::min(node->end - start, end - start); ++j) { + memory_map[i][j] = true; + } + } + } + return memory_map; +} + +// Given a list of BufferIntervalTreeNodes, returns a string representation of +// the nodes. +std::string BufferIntervalTreeNodesToString( + absl::Span nodes) { + std::string output; + for (const BufferIntervalTreeNode* node : nodes) { + absl::StrAppend(&output, node->ToString(), "\n"); + } + return output; +} + +// Returns a string representation of the memory map of occupied memory blocks +// for the given time interval [start, end]. +std::string MemoryMapToString(int64_t start, int64_t end, + int64_t memory_block_size, int64_t group_size, + std::vector>& memory_map) { + int64_t num_memory_blocks = memory_map.size(); + int64_t total_time = memory_map.front().size(); + std::string output = "\n"; + absl::StrAppend(&output, "Memory map for time: [", start, ",", end, + "], memory_block_size: ", memory_block_size, + ", group_size: ", group_size, "\n\n"); + for (int64_t i = num_memory_blocks - 1; i >= 0; --i) { + for (int64_t j = 0; j < total_time; ++j) { + if (group_size && j % group_size == 0) { + absl::StrAppend(&output, " "); + } + absl::StrAppend(&output, memory_map[i][j] ? "#" : "."); + } + absl::StrAppend(&output, " ", std::to_string((i + 1) * memory_block_size), + "\n"); + } + for (int64_t j = start; j <= end; ++j) { + if (group_size && j % group_size == 0) { + absl::StrAppend(&output, " "); + } + absl::StrAppend(&output, std::to_string(j % 10)); + } + absl::StrAppend(&output, "\n\n"); + return output; +} + +} // namespace using absl::flat_hash_map; using absl::flat_hash_set; @@ -73,6 +184,11 @@ std::string HeapSimulator::Chunk::ToString() const { return absl::StrCat("[", offset, ",", chunk_end(), ")"); } +std::string BufferIntervalTreeNode::ToString() const { + return absl::StrCat("start: ", start, " end: ", end, + " chunk: ", chunk.ToString()); +} + bool HeapSimulator::Chunk::OverlapsWith(Chunk other_chunk) const { CHECK_NE(size, 0); CHECK_NE(other_chunk.size, 0); @@ -113,14 +229,12 @@ absl::StatusOr HeapSimulator::MinimumMemoryForModule( absl::StatusOr HeapSimulator::MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map* - memory_by_computation) { + const LogicalBuffer::SizeFunction& size_function) { TF_ASSIGN_OR_RETURN( HeapSimulator::Result result, HeapSimulator::Run(std::make_unique>(), computation, sequence, alias_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Options())); return result.heap_size; } @@ -161,11 +275,9 @@ absl::StatusOr> HeapSimulator::Run( const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, - const BufferValue::SizeFunction& size_fn, const Options& options, - const absl::flat_hash_map* - memory_by_computation) { + const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*schedule=*/nullptr, memory_by_computation); + /*schedule=*/nullptr); HloSchedule schedule(computation.parent()); schedule.set_sequence(&computation, instruction_sequence); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_live_range, @@ -185,7 +297,7 @@ absl::StatusOr> HeapSimulator::Run( const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, - /*schedule=*/schedule, nullptr); + /*schedule=*/schedule); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_live_range, HloLiveRange::Run(*schedule, alias_analysis, &computation)); @@ -386,19 +498,16 @@ absl::Status HeapSimulator::RunComputation( return absl::OkStatus(); } -HeapSimulator::HeapSimulator( - std::unique_ptr> algorithm, - const BufferValue::SizeFunction& size_fn, const Options& options, - const HloSchedule* schedule, - const absl::flat_hash_map* - memory_by_computation) +HeapSimulator::HeapSimulator(std::unique_ptr> algorithm, + const BufferValue::SizeFunction& size_fn, + const Options& options, + const HloSchedule* schedule) : no_fragmentation_stats_( std::make_unique>()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), - schedule_(schedule), - memory_by_computation_(memory_by_computation) { + schedule_(schedule) { debug_trace_.set_whole_module_simulation(schedule_ != nullptr); } @@ -523,21 +632,10 @@ void NoFragmentationStatsHeap::Alloc(const BufferType* buffer, template void NoFragmentationStatsHeap::AccountForSubcomputationMemory( - const HloInstruction* instruction, int64_t alloc_size_by_instruction, - const absl::flat_hash_map& - memory_by_computation) { + const HloInstruction* instruction, int64_t alloc_size_by_instruction) { // We only count the memory usage of the largest subcomputation, instead of // adding them all, because subcomputations won't execute in parallel. int64_t max_subcomputation_bytes = 0; - for (const auto* c : instruction->called_computations()) { - auto it = memory_by_computation.find(c); - if (it != memory_by_computation.end()) { - int64_t subcomputation_bytes = it->second; - if (subcomputation_bytes > max_subcomputation_bytes) { - max_subcomputation_bytes = subcomputation_bytes; - } - } - } if (max_subcomputation_bytes > 0 && (instruction->opcode() == HloOpcode::kWhile || instruction->opcode() == HloOpcode::kCall || @@ -848,6 +946,16 @@ bool BufferIntervalTree::Remove(int64_t start, int64_t end, std::vector BufferIntervalTree::ChunksOverlappingInTime( int64_t start, int64_t end) const { std::vector result; + for (const BufferIntervalTreeNode* node : + NodesOverlappingInTime(start, end)) { + result.push_back(node->chunk); + } + return result; +} + +std::vector +BufferIntervalTree::NodesOverlappingInTime(int64_t start, int64_t end) const { + std::vector result; if (root_ == nullptr) { return result; } @@ -863,7 +971,7 @@ std::vector BufferIntervalTree::ChunksOverlappingInTime( visiting_stack.push_back(top->left); } if (top->start <= end && top->end >= start) { - result.push_back(top->chunk); + result.push_back(top); } if (end < top->start) { continue; @@ -875,6 +983,51 @@ std::vector BufferIntervalTree::ChunksOverlappingInTime( return result; } +std::string BufferIntervalTree::NodesOverlappingInTimeToAsciiArt( + int64_t start, int64_t end, int64_t group_size) const { + std::vector nodes = + NodesOverlappingInTime(start, end); + if (nodes.empty()) { + return "No nodes overlapping in time. Memory is free!"; + } + auto [memory_block_size, end_of_last_occupied_chunk] = + GetAsciiMemoryMapParameters(nodes); + CHECK_GE(end_of_last_occupied_chunk, 0); + CHECK_NE(memory_block_size, 0); + int64_t total_time = end - start + 1; + int64_t num_memory_blocks = end_of_last_occupied_chunk / memory_block_size; + if (total_time > kMaxMemoryMapDimensionSize || + num_memory_blocks > kMaxMemoryMapDimensionSize) { + std::string output; + absl::StrAppend( + &output, + "\nCannot print memory usage to ASCII art. Printing nodes instead!\n\n", + BufferIntervalTreeNodesToString(nodes)); + return output; + } + std::vector> memory_map = + GetMemoryMap(start, end, memory_block_size, num_memory_blocks, nodes); + return MemoryMapToString(start, end, memory_block_size, group_size, + memory_map); +} + +std::vector BufferIntervalTree::MemoryUsedInInterval( + int64_t start, int64_t end) const { + int64_t total_time = end - start + 1; + CHECK_GE(total_time, 0); + std::vector nodes = + NodesOverlappingInTime(start, end); + std::vector memory_used_in_interval(total_time, 0); + for (const BufferIntervalTreeNode* node : nodes) { + int64_t node_start = std::max(node->start, start); + int64_t node_end = std::min(node->end, end); + for (int64_t time = node_start; time <= node_end; ++time) { + memory_used_in_interval[time - start] += node->chunk.size; + } + } + return memory_used_in_interval; +} + template std::string GlobalDecreasingSizeBestFitHeap::BufferInterval::ToString() const { diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator.h b/third_party/xla/xla/service/heap_simulator/heap_simulator.h index dfa62f018ae133..6d5f4558b6e6b4 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator.h +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator.h @@ -24,7 +24,6 @@ limitations under the License. #include #include #include -#include #include #include #include @@ -35,7 +34,9 @@ limitations under the License. #endif #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/functional/any_invocable.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_computation.h" @@ -43,16 +44,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_live_range.h" #include "xla/service/buffer_value.h" -#include "xla/service/buffer_value_containers.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_alias_analysis.h" -#include "xla/service/hlo_buffer.h" -#include "xla/service/hlo_dataflow_analysis.h" -#include "xla/service/hlo_ordering.h" #include "xla/service/hlo_value.h" -#include "xla/service/memory_space_assignment/repacking.h" -#include "xla/service/tuple_points_to_analysis.h" +#include "xla/service/logical_buffer.h" namespace xla { @@ -154,9 +150,7 @@ class HeapSimulator { static absl::StatusOr MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, const HloAliasAnalysis& alias_analysis, - const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map* - memory_by_computation = nullptr); + const LogicalBuffer::SizeFunction& size_function); static absl::StatusOr MinimumMemoryForComputation( const HloComputation& computation, const HloInstructionSequence& sequence, @@ -190,9 +184,7 @@ class HeapSimulator { const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, - const Options& options = Options(), - const absl::flat_hash_map* - memory_by_computation = nullptr); + const Options& options = Options()); // Same as above, but runs on with a schedule that covers all nested // computations. @@ -210,9 +202,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. HeapSimulator(std::unique_ptr> algorithm, const BufferValue::SizeFunction& size_fn, - const Options& options, const HloSchedule* schedule = nullptr, - const absl::flat_hash_map* - memory_by_computation = nullptr); + const Options& options, const HloSchedule* schedule = nullptr); ~HeapSimulator(); absl::Status RunComputation( @@ -250,13 +240,10 @@ class HeapSimulator { const std::unique_ptr> algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; - // schedule_ is set by buffer assignment, and memory_by_computation_ is - // set by hlo scheduling. Then, in RunComputation, we check both in order to - // handle subcomputations. It would be good to unify the handling of - // subcomputations, but it's not clear how. + // schedule_ is set by buffer assignment. Then, in RunComputation, we check + // both in order to handle subcomputations. It would be good to unify the + // handling of subcomputations, but it's not clear how. const HloSchedule* schedule_; - const absl::flat_hash_map* - memory_by_computation_; // Hold some sets for error-checking the sequence of Alloc and Free calls. absl::flat_hash_set allocated_buffers_; @@ -296,9 +283,7 @@ class HeapAlgorithm { virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, // The total number of bytes allocated by instruction. - int64_t alloc_size_by_instruction, - const absl::flat_hash_map& - memory_by_computation) {} + int64_t alloc_size_by_instruction) {} // Free de-allocates a previously allocated buffer. virtual void Free(const BufferType* buffer, int64_t size) = 0; @@ -334,9 +319,8 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferType* buffer, int64_t size) override; void AccountForSubcomputationMemory( - const HloInstruction* instruction, int64_t alloc_size_by_instruction, - const absl::flat_hash_map& - memory_by_computation) override; + const HloInstruction* instruction, + int64_t alloc_size_by_instruction) override; void Free(const BufferType* buffer, int64_t size) override; @@ -364,6 +348,8 @@ struct BufferIntervalTreeNode { BufferIntervalTreeNode* right; // parent BufferIntervalTreeNode* parent; + + std::string ToString() const; }; // An interval tree that can query buffers overlapping in time. @@ -383,7 +369,53 @@ class BufferIntervalTree { BufferIntervalTreeNode* GetRoot() { return root_; } + // Returns a compact 2D view of memory usage over time. + // X axis is time, Y axis is memory. + // + // Say there are 3 buffers in the heap: + // - Buffer 1: memory block [0, 16), time interval [15, 25] + // - Buffer 2: memory block [16, 48), time interval [15, 19] + // - Buffer 3: memory block [32, 64), time interval [20, 22] + // + // NodesOverlappingInTimeToAsciiArt(/*start=*/18, /*end=*/23, + // /*group_size=*/3) returns: + // + // Memory map for time: [18,23], memory_block_size: 16, group_size: 3 + // + // ..# ##. 64 + // ### ##. 48 + // ##. ... 32 + // ### ### 16 + // 890 123 + // + // Explanation: + // + // The functions decides a memory block size of 16 would be most compact to + // display all the buffers. + // '#' indicates used and '.' indicates free memory. + // + // ..# ##. 64 "64" indicates memory block [48,64) + // ### ##. 48 "48" indicates memory block [32,48) + // ##. ... 32 "32" indicates memory block [16,32) + // ### ### 16 "16" indicates memory block [0,16) + // 890 123 + // + // "890 123" indicate the last digits of time instants 18, 19, 20, 21, 22, 23. + // Only the last digit is shown for compactness. + // `group_size=3` inserts spaces after every 3 columns (time instants). + // All the memory blocks beyond 64 are free for time interval [18,23]. + std::string NodesOverlappingInTimeToAsciiArt(int64_t start, int64_t end, + int64_t group_size = 0) const; + + // Returns a vector of size `end - start + 1` where the element at index i is + // the memory used at the time instant `start + i`. Both `start` and `end` are + // inclusive. + std::vector MemoryUsedInInterval(int64_t start, int64_t end) const; + private: + std::vector NodesOverlappingInTime( + int64_t start, int64_t end) const; + BufferIntervalTreeNode* root_ = nullptr; std::list node_storage_; }; diff --git a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc index 480213f78e8f7b..878030b01e99b3 100644 --- a/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc +++ b/third_party/xla/xla/service/heap_simulator/heap_simulator_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include #include -#include #include #include #include @@ -26,27 +25,38 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" -#include "absl/strings/str_join.h" +#include "absl/container/inlined_vector.h" +#include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" -#include "xla/literal.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_schedule.h" +#include "xla/literal_util.h" #include "xla/service/buffer_value.h" #include "xla/service/heap_simulator/allocation_block.h" -#include "xla/service/hlo_ordering.h" +#include "xla/service/hlo_alias_analysis.h" +#include "xla/service/hlo_dataflow_analysis.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_value.h" -#include "xla/service/tuple_points_to_analysis.h" -#include "xla/status_macros.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { namespace { +using ::testing::ContainerEq; +using ::testing::HasSubstr; +using ::testing::StrEq; + class MinimumMemoryForSequenceTest : public HloTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { @@ -213,9 +223,6 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { auto size_fn = [](const BufferValue& buffer) { return ShapeUtil::ByteSizeOf(buffer.shape()); }; - absl::flat_hash_map memory_by_computation; - memory_by_computation[cond_computation] = 5; - memory_by_computation[body_computation] = 16; std::unique_ptr alias_analysis = HloAliasAnalysis::Run(module.get()).value(); @@ -224,7 +231,7 @@ TEST_F(MinimumMemoryForSequenceTest, SubcomputationAccounting) { // so we don't double count. EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( *entry_computation, schedule.sequence(entry_computation), - *alias_analysis, size_fn, &memory_by_computation) + *alias_analysis, size_fn) .value()); } @@ -2019,6 +2026,61 @@ TEST_F(IntervalTreeTest, ThreeLevelsRightLeftChunkDifferent) { ASSERT_EQ(tree.GetRoot(), nullptr); } +TEST_F(IntervalTreeTest, BufferIntervalTreeToAsciiArt) { + // Buffer 1: memory block [0, 16), time interval [15, 25] + // Buffer 2: memory block [16, 48), time interval [15, 19] + // Buffer 3: memory block [32, 64), time interval [20, 22] + BufferIntervalTree tree; + tree.Add(15, 25, HeapSimulator::Chunk::FromOffsetEnd(0, 16)); + tree.Add(15, 19, HeapSimulator::Chunk::FromOffsetEnd(16, 48)); + tree.Add(20, 22, HeapSimulator::Chunk::FromOffsetEnd(32, 64)); + std::string output = tree.NodesOverlappingInTimeToAsciiArt( + /*start=*/18, /*end=*/23, /*group_size=*/3); + EXPECT_THAT(output, HasSubstr("Memory map for time: [18,23], " + "memory_block_size: 16, group_size: 3")); + EXPECT_THAT(output, HasSubstr("..# ##. 64")); + EXPECT_THAT(output, HasSubstr("### ##. 48")); + EXPECT_THAT(output, HasSubstr("##. ... 32")); + EXPECT_THAT(output, HasSubstr("### ### 16")); + EXPECT_THAT(output, HasSubstr("890 123")); +} + +TEST_F(IntervalTreeTest, BufferIntervalTreeToAsciiArtTooLarge) { + BufferIntervalTree tree; + tree.Add(0, 4, HeapSimulator::Chunk::FromOffsetEnd(0, 128)); + tree.Add(5, 10, HeapSimulator::Chunk::FromOffsetEnd(1, 129)); + std::string output = tree.NodesOverlappingInTimeToAsciiArt( + /*start=*/0, /*end=*/10, /*group_size=*/3); + EXPECT_THAT( + output, + HasSubstr( + "Cannot print memory usage to ASCII art. Printing nodes instead!")); + EXPECT_THAT(output, HasSubstr("start: 0 end: 4 chunk: [0,128)")); + EXPECT_THAT(output, HasSubstr("start: 5 end: 10 chunk: [1,129)")); +} + +TEST_F(IntervalTreeTest, BufferIntervalTreeToAsciiArtFreeMemory) { + BufferIntervalTree tree; + tree.Add(5, 10, HeapSimulator::Chunk::FromOffsetEnd(0, 16)); + std::string output = tree.NodesOverlappingInTimeToAsciiArt( + /*start=*/0, /*end=*/4, /*group_size=*/10); + EXPECT_THAT(output, StrEq("No nodes overlapping in time. Memory is free!")); +} + +TEST_F(IntervalTreeTest, BufferIntervalTreeMemoryUsedInInterval) { + // Buffer 1: memory block [0, 16), time interval [15, 25] + // Buffer 2: memory block [16, 48), time interval [15, 19] + // Buffer 3: memory block [32, 64), time interval [20, 22] + BufferIntervalTree tree; + tree.Add(15, 25, HeapSimulator::Chunk::FromOffsetEnd(0, 16)); + tree.Add(15, 19, HeapSimulator::Chunk::FromOffsetEnd(16, 48)); + tree.Add(20, 22, HeapSimulator::Chunk::FromOffsetEnd(32, 64)); + std::vector memory_used_by_time = tree.MemoryUsedInInterval( + /*start=*/18, /*end=*/23); + std::vector expected_memory_used_by_time = {48, 48, 48, 48, 48, 16}; + EXPECT_THAT(memory_used_by_time, ContainerEq(expected_memory_used_by_time)); +} + class SlicedBufferIntervalTest : public ::testing::Test { public: using HeapTy = GlobalDecreasingSizeBestFitHeap; diff --git a/third_party/xla/xla/service/hlo.proto b/third_party/xla/xla/service/hlo.proto index 083fd7d2b3fac8..83c3b8545feb2d 100644 --- a/third_party/xla/xla/service/hlo.proto +++ b/third_party/xla/xla/service/hlo.proto @@ -112,7 +112,7 @@ enum CustomCallApiVersion { } // Serialization of HloInstruction. -// Next ID: 88 +// Next ID: 90 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -382,6 +382,12 @@ message HloInstructionProto { // Represents the list of devices that participate in a collective operation. xla.CollectiveDeviceListProto collective_device_list = 87; + + // For HLO value tracking. + xla.OriginalValueProto original_value = 88; + + // Specifies if a call instruction is a composite. + bool is_composite = 89; } // Serialization of HloComputation. @@ -576,6 +582,7 @@ message HloModuleProto { FUSION = 2; LAYOUT = 3; DOT = 4; + FLAGNET = 5; } // Information about the optimization profile that this module contains. diff --git a/third_party/xla/xla/service/hlo_alias_analysis_test.cc b/third_party/xla/xla/service/hlo_alias_analysis_test.cc index 36709bc0a8e79c..ea687b640a58ef 100644 --- a/third_party/xla/xla/service/hlo_alias_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_alias_analysis_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/hlo_computation_test.cc b/third_party/xla/xla/service/hlo_computation_test.cc index 16ce018e7da3e0..a7190b33f2088d 100644 --- a/third_party/xla/xla/service/hlo_computation_test.cc +++ b/third_party/xla/xla/service/hlo_computation_test.cc @@ -15,8 +15,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" +#include #include -#include #include #include #include @@ -24,19 +24,24 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_clone_context.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/utils/hlo_matchers.h" -#include "xla/literal.h" +#include "xla/literal_util.h" #include "xla/service/hlo_parser.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" +#include "xla/shape_tree.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "tsl/platform/status.h" #include "tsl/platform/statusor.h" namespace xla { @@ -940,5 +945,32 @@ TEST_F(HloComputationTest, CloneWrappedAsyncInstructionSameWrappedFunc) { cloned_done.get()->async_wrapped_computation()); } +TEST_F(HloComputationTest, CompositeCall) { + const char* const hlo_string = R"( + HloModule Module + + add (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + %constant = f32[] constant(2) + ROOT %z = f32[] add(f32[] %x, f32[] %constant) + } + + ENTRY %CallR0F32AddScalar.v2 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call = f32[] call(f32[] %constant.1), to_apply=add, is_composite=true, + frontend_attributes={ + composite.attributes={n = 1 : i32, tensor = dense<1> : tensor}, + composite.name="foo.bar", + composite.version="1" + } +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* composite_call = FindInstruction(module.get(), "call"); + EXPECT_EQ(composite_call->opcode(), HloOpcode::kCall); + EXPECT_TRUE(composite_call->is_composite()); + EXPECT_EQ(composite_call->frontend_attributes().map().size(), 3); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_cse.cc b/third_party/xla/xla/service/hlo_cse.cc index 2594fa392a5c1c..3162204e68c6dd 100644 --- a/third_party/xla/xla/service/hlo_cse.cc +++ b/third_party/xla/xla/service/hlo_cse.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -316,6 +317,32 @@ absl::StatusOr HloCSE::Run( } } } + if (auto fusion = computation->FusionInstruction()) { + if (fusion->IsMultiOutputFusion()) { + // Attach users to the representative instruction, thus making the + // duplicate fusion roots unused. HloDCE can then cleanup the unused + // fusion roots. + absl::flat_hash_map + root_to_unique_index; + int64_t root_index = 0; + HloInstruction* root = computation->root_instruction(); + for (const HloInstruction* hlo : root->operands()) { + if (root_to_unique_index.find(hlo) == root_to_unique_index.end()) { + root_to_unique_index[hlo] = root_to_unique_index[hlo] = root_index; + } + ++root_index; + } + if (root_to_unique_index.size() < root->operand_count()) { + for (HloInstruction* user : fusion->users()) { + if (user->opcode() == HloOpcode::kGetTupleElement) { + const HloInstruction* fusion_root = + root->operand(user->tuple_index()); + user->set_tuple_index(root_to_unique_index[fusion_root]); + } + } + } + } + } } return changed; } diff --git a/third_party/xla/xla/service/hlo_cse_test.cc b/third_party/xla/xla/service/hlo_cse_test.cc index 106eea0923b0be..f6378353b8d507 100644 --- a/third_party/xla/xla/service/hlo_cse_test.cc +++ b/third_party/xla/xla/service/hlo_cse_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include "absl/algorithm/container.h" #include "absl/strings/substitute.h" #include "xla/hlo/ir/hlo_computation.h" @@ -918,7 +919,10 @@ TEST_F(HloCseTest, MultiOutputFusion) { ENTRY entry { p0 = f32[] parameter(0) p1 = f32[] parameter(1) - ROOT root = (f32[], f32[]) fusion(p0, p1), kind=kLoop, calls=f + fusion = (f32[], f32[]) fusion(p0, p1), kind=kLoop, calls=f + gte0 = f32[] get-tuple-element(fusion), index=0 + gte1 = f32[] get-tuple-element(fusion), index=1 + ROOT res = (f32[], f32[]) tuple(gte0, gte1) } )"; @@ -928,10 +932,18 @@ TEST_F(HloCseTest, MultiOutputFusion) { SCOPED_TRACE(absl::StrCat("Module after CSE:\n", m->ToString())); EXPECT_EQ(changed, true); + HloInstruction* root = m->entry_computation()->root_instruction(); HloInstruction* add0; HloInstruction* add1; + HloInstruction* gte0; + HloInstruction* gte1; + ASSERT_THAT(root, GmockMatch(m::Tuple(m::GetTupleElement(>e0), + m::GetTupleElement(>e1)))); + EXPECT_EQ(gte0, gte1); + EXPECT_EQ(gte0->tuple_index(), 0); + const HloInstruction* fusion = gte0->operand(0); ASSERT_THAT( - m->entry_computation()->root_instruction()->fused_expression_root(), + fusion->fused_expression_root(), GmockMatch(m::Tuple(m::Add(&add0, m::Parameter(0), m::Parameter(1)), m::Add(&add1, m::Parameter(0), m::Parameter(1))))); EXPECT_EQ(add0, add1); diff --git a/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc b/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc index 5158e068501d5e..967c2955db20c2 100644 --- a/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_dataflow_analysis_test.cc @@ -41,8 +41,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/hlo_dce_test.cc b/third_party/xla/xla/service/hlo_dce_test.cc index c80c286c17ea2d..38a170ae77160d 100644 --- a/third_party/xla/xla/service/hlo_dce_test.cc +++ b/third_party/xla/xla/service/hlo_dce_test.cc @@ -35,9 +35,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/hlo_domain_test.cc b/third_party/xla/xla/service/hlo_domain_test.cc index 63bc4aba494d79..13f80fdf6b441b 100644 --- a/third_party/xla/xla/service/hlo_domain_test.cc +++ b/third_party/xla/xla/service/hlo_domain_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/service/sharding_propagation.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc b/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc index 5f583c1f38f0bb..8b7a99b385db67 100644 --- a/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc +++ b/third_party/xla/xla/service/hlo_input_output_alias_config_test.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_instruction_test.cc b/third_party/xla/xla/service/hlo_instruction_test.cc index 6147be229323e0..7709bda6032e7f 100644 --- a/third_party/xla/xla/service/hlo_instruction_test.cc +++ b/third_party/xla/xla/service/hlo_instruction_test.cc @@ -15,6 +15,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" +#include +#include +#include +#include +#include #include #include #include @@ -22,24 +27,32 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" +#include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" -#include "xla/literal.h" +#include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" #include "xla/protobuf_util.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -2706,6 +2719,16 @@ TEST_F(HloInstructionTest, VerifyBodyComputationPointsToWhile) { } } EXPECT_EQ(num_while_body_comp, 1); + + for (HloInstruction* instruction : + module->entry_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kWhile) { + HloComputation* while_body = instruction->while_body(); + EXPECT_TRUE(while_body->IsWhileBodyComputation()); + HloInstruction* while_back_ref = while_body->WhileCallInstruction(); + EXPECT_EQ(while_back_ref->while_body(), while_body); + } + } } TEST_F(HloInstructionTest, @@ -2752,7 +2775,7 @@ TEST_F(HloInstructionTest, module->AddEntryComputation(main_builder.Build()); // Should find conditional branch computations in the graph and it should - // point to the conditonal instruction. + // point to the conditional instruction. int num_conditional_branch_comp = 0; for (HloComputation* comp : module->MakeComputationPostOrder()) { if (comp->IsConditionalBranchComputation()) { @@ -2827,7 +2850,7 @@ TEST_F(HloInstructionTest, module->AddEntryComputation(main_builder.Build()); // Should find conditional branch computations in the graph and it should - // point to the conditonal instruction. + // point to the conditional instruction. int num_conditional_branch_comp = 0; for (HloComputation* comp : module->MakeComputationPostOrder()) { if (comp->IsConditionalBranchComputation()) { diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.cc b/third_party/xla/xla/service/hlo_memory_scheduler.cc index 283b82e23ec738..83e40723895289 100644 --- a/third_party/xla/xla/service/hlo_memory_scheduler.cc +++ b/third_party/xla/xla/service/hlo_memory_scheduler.cc @@ -90,11 +90,8 @@ class ListScheduler { static absl::StatusOr Run( HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation) { - ListScheduler scheduler(computation, points_to_analysis, size_function, - memory_by_computation); + const BufferValue::SizeFunction& size_function) { + ListScheduler scheduler(computation, points_to_analysis, size_function); return scheduler.CreateSchedule(); } @@ -115,13 +112,10 @@ class ListScheduler { ListScheduler(HloComputation* computation, const TuplePointsToAnalysis& points_to_analysis, - const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation) + const BufferValue::SizeFunction& size_function) : computation_(computation), points_to_analysis_(points_to_analysis), - size_function_(size_function), - memory_by_computation_(memory_by_computation) { + size_function_(size_function) { // Create a map containing the LogicalBuffer uses for each HLO // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by @@ -242,29 +236,7 @@ class ListScheduler { freed_bytes += size_function_(*buffer); } } - // We only count the memory usage of the largest subcomputation, instead of - // adding them all, because subcomputations won't execute in parallel. - int64_t max_subcomputation_bytes = 0; - for (const auto* c : instruction->called_computations()) { - auto it = memory_by_computation_.find(c); - if (it != memory_by_computation_.end()) { - int64_t subcomputation_bytes = it->second; - if (subcomputation_bytes > max_subcomputation_bytes) { - max_subcomputation_bytes = subcomputation_bytes; - } - } - } - int64_t bytes_defined; - if (max_subcomputation_bytes > 0 && - (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || - opcode == HloOpcode::kConditional)) { - // The output buffer of while/call/conditional is always aliased with the - // output buffer of the root instruction in the body. Don't double count. - bytes_defined = max_subcomputation_bytes; - } else { - bytes_defined = entry.bytes_defined + max_subcomputation_bytes; - } - return freed_bytes - bytes_defined; + return freed_bytes - entry.bytes_defined; } // Constructs the scheduling priority of the given instruction. @@ -392,11 +364,6 @@ class ListScheduler { HloComputation* computation_; const TuplePointsToAnalysis& points_to_analysis_; const BufferValue::SizeFunction& size_function_; - // Computations are analyzed in post-order. When scheduling an instruction - // that includes subcomputations, such as a while loop, we use this map to - // look up the memory needed by subcomputations. - const absl::flat_hash_map& - memory_by_computation_; // A map containing the LogicalBuffers that each instruction uses. absl::flat_hash_map> @@ -426,19 +393,15 @@ absl::StatusOr ScheduleComputationHelper( const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, const MemorySchedulerAlgorithm& algorithm, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { VLOG(2) << "Computation: " << computation->name(); if (algorithm) { return algorithm(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, postprocessor, - peak_memory); + size_function, postprocessor, peak_memory); } return DefaultMemoryScheduler(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, - postprocessor, peak_memory); + size_function, postprocessor, peak_memory); } } // namespace @@ -448,8 +411,6 @@ absl::StatusOr DFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { // These variables are a hack to prevent overflows. int64_t cumulative_total_size = 0; @@ -526,9 +487,9 @@ absl::StatusOr DFSMemoryScheduler( CHECK_EQ(sequence.size(), computation->instruction_count()); if (peak_memory) { TF_ASSIGN_OR_RETURN( - *peak_memory, HeapSimulator::MinimumMemoryForComputation( - *computation, sequence, alias_analysis, size_function, - &memory_by_computation)); + *peak_memory, + HeapSimulator::MinimumMemoryForComputation( + *computation, sequence, alias_analysis, size_function)); } return sequence; } @@ -538,8 +499,6 @@ absl::StatusOr BFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { // Index of HloInstruction in the `computation`. absl::flat_hash_map inst_index; @@ -586,9 +545,9 @@ absl::StatusOr BFSMemoryScheduler( CHECK_EQ(sequence.size(), computation->instruction_count()); if (peak_memory) { TF_ASSIGN_OR_RETURN( - *peak_memory, HeapSimulator::MinimumMemoryForComputation( - *computation, sequence, alias_analysis, size_function, - &memory_by_computation)); + *peak_memory, + HeapSimulator::MinimumMemoryForComputation( + *computation, sequence, alias_analysis, size_function)); } return sequence; @@ -605,16 +564,14 @@ ModuleSchedulerAlgorithm ComputationSchedulerToModuleScheduler( const absl::flat_hash_set& execution_threads, int64_t* peak_memory) -> absl::StatusOr { HloSchedule schedule(module); - absl::flat_hash_map memory_by_computation; for (auto* computation : module->MakeComputationPostOrder(execution_threads)) { if (!computation->IsFusionComputation()) { - TF_ASSIGN_OR_RETURN( - HloInstructionSequence computation_sequence, - ScheduleComputationHelper( - computation, points_to_analysis, alias_analysis, size_func, - computation_scheduler, memory_by_computation, postprocessor, - /*peak_memory=*/nullptr)); + TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence, + ScheduleComputationHelper( + computation, points_to_analysis, alias_analysis, + size_func, computation_scheduler, postprocessor, + /*peak_memory=*/nullptr)); schedule.set_sequence(computation, std::move(computation_sequence)); } } @@ -631,20 +588,18 @@ absl::StatusOr ListMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { - TF_ASSIGN_OR_RETURN(HloInstructionSequence sequence, - ListScheduler::Run(computation, points_to_analysis, - size_function, memory_by_computation)); + TF_ASSIGN_OR_RETURN( + HloInstructionSequence sequence, + ListScheduler::Run(computation, points_to_analysis, size_function)); if (postprocessor) { sequence = postprocessor(sequence); } if (peak_memory) { TF_ASSIGN_OR_RETURN( - *peak_memory, HeapSimulator::MinimumMemoryForComputation( - *computation, sequence, alias_analysis, size_function, - &memory_by_computation)); + *peak_memory, + HeapSimulator::MinimumMemoryForComputation( + *computation, sequence, alias_analysis, size_function)); } return sequence; } @@ -654,8 +609,6 @@ absl::StatusOr PostOrderMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { HloInstructionSequence sequence(computation->MakeInstructionPostOrder()); if (postprocessor) { @@ -663,9 +616,9 @@ absl::StatusOr PostOrderMemoryScheduler( } if (peak_memory) { TF_ASSIGN_OR_RETURN( - *peak_memory, HeapSimulator::MinimumMemoryForComputation( - *computation, sequence, alias_analysis, size_function, - &memory_by_computation)); + *peak_memory, + HeapSimulator::MinimumMemoryForComputation( + *computation, sequence, alias_analysis, size_function)); } return sequence; } @@ -675,8 +628,6 @@ absl::StatusOr DefaultMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory) { // We try a few schedulers and choose whichever returns a lower min-memory, // not accounting for fragmentation. @@ -690,24 +641,21 @@ absl::StatusOr DefaultMemoryScheduler( TF_ASSIGN_OR_RETURN( HloInstructionSequence list_sequence, ListMemoryScheduler(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, postprocessor, - &list_memory)); + size_function, postprocessor, &list_memory)); VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); int64_t dfs_memory; TF_ASSIGN_OR_RETURN( HloInstructionSequence dfs_sequence, DFSMemoryScheduler(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, postprocessor, - &dfs_memory)); + size_function, postprocessor, &dfs_memory)); VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); int64_t post_order_memory; - TF_ASSIGN_OR_RETURN( - HloInstructionSequence post_order_sequence, - PostOrderMemoryScheduler(computation, points_to_analysis, alias_analysis, - size_function, memory_by_computation, - postprocessor, &post_order_memory)); + TF_ASSIGN_OR_RETURN(HloInstructionSequence post_order_sequence, + PostOrderMemoryScheduler( + computation, points_to_analysis, alias_analysis, + size_function, postprocessor, &post_order_memory)); VLOG(2) << "Min-memory post order sequence: " << HumanReadableNumBytes(post_order_memory); @@ -815,21 +763,6 @@ absl::StatusOr ScheduleModule( return std::move(schedule); } -absl::StatusOr ScheduleComputation( - HloComputation* computation, const BufferValue::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor) { - CHECK(!computation->IsFusionComputation()); - TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, - TuplePointsToAnalysis::Run(computation->parent())); - TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, - HloAliasAnalysis::Run(computation->parent())); - absl::flat_hash_map empty_map; - return ScheduleComputationHelper( - computation, *points_to_analysis, *alias_analysis, size_function, - /*algorithm=*/nullptr, empty_map, postprocessor, - /*peak_memory=*/nullptr); -} - HloMemoryScheduler::HloMemoryScheduler( const BufferValue::SizeFunction& size_function, const ModuleSchedulerAlgorithm& algorithm) diff --git a/third_party/xla/xla/service/hlo_memory_scheduler.h b/third_party/xla/xla/service/hlo_memory_scheduler.h index 112ced3ee95112..2fb211ac6531a2 100644 --- a/third_party/xla/xla/service/hlo_memory_scheduler.h +++ b/third_party/xla/xla/service/hlo_memory_scheduler.h @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -51,7 +50,6 @@ using MemorySchedulerAlgorithm = std::function( HloComputation*, const TuplePointsToAnalysis&, const HloAliasAnalysis&, const LogicalBuffer::SizeFunction&, - const absl::flat_hash_map&, const MemorySchedulerPostprocessor&, /*peak_memory*/ int64_t*)>; @@ -73,8 +71,6 @@ absl::StatusOr ListMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // DFS-order scheduler @@ -83,8 +79,6 @@ absl::StatusOr DFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // BFS-order scheduler @@ -102,8 +96,6 @@ absl::StatusOr BFSMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // Naive Post Order scheduler @@ -112,8 +104,6 @@ absl::StatusOr PostOrderMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); // The default scheduling algorithm. Runs the list scheduler, the DFS scheduler, @@ -125,8 +115,6 @@ absl::StatusOr DefaultMemoryScheduler( const TuplePointsToAnalysis& points_to_analysis, const HloAliasAnalysis& alias_analysis, const LogicalBuffer::SizeFunction& size_function, - const absl::flat_hash_map& - memory_by_computation, const MemorySchedulerPostprocessor& postprocessor, int64_t* peak_memory); absl::StatusOr DefaultModuleScheduler( @@ -146,13 +134,6 @@ absl::StatusOr ScheduleModule( const absl::flat_hash_set& execution_threads = {}, int64_t* peak_memory = nullptr); -// Computes the schedule for a single computation. -// Currently only used by the GPU backend. -absl::StatusOr ScheduleComputation( - HloComputation* computation, - const LogicalBuffer::SizeFunction& size_function, - const MemorySchedulerPostprocessor& postprocessor); - // A pass which schedules the HLO instructions in a module. The HloModule's // schedule field is set to the resulting HloSchedule using // HloModule::set_schedule. diff --git a/third_party/xla/xla/service/hlo_memory_scheduler_test.cc b/third_party/xla/xla/service/hlo_memory_scheduler_test.cc index fef4b71c55b7a9..62a13d14097887 100644 --- a/third_party/xla/xla/service/hlo_memory_scheduler_test.cc +++ b/third_party/xla/xla/service/hlo_memory_scheduler_test.cc @@ -41,8 +41,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/hlo_module_dce_test.cc b/third_party/xla/xla/service/hlo_module_dce_test.cc index c192429c2f30e0..4b1a7b7e2e4409 100644 --- a/third_party/xla/xla/service/hlo_module_dce_test.cc +++ b/third_party/xla/xla/service/hlo_module_dce_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/hlo_module_group_test.cc b/third_party/xla/xla/service/hlo_module_group_test.cc index b56b53b4952e05..007df88bdcc9d9 100644 --- a/third_party/xla/xla/service/hlo_module_group_test.cc +++ b/third_party/xla/xla/service/hlo_module_group_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/service/hlo_module_group_metadata.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_module_test.cc b/third_party/xla/xla/service/hlo_module_test.cc index 8af2621174e7b0..f2375751a90f55 100644 --- a/third_party/xla/xla/service/hlo_module_test.cc +++ b/third_party/xla/xla/service/hlo_module_test.cc @@ -36,10 +36,10 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" -#include "tsl/lib/strings/proto_serialization.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/hlo_ordering.cc b/third_party/xla/xla/service/hlo_ordering.cc index 466f64cee1d49f..388de97291fab1 100644 --- a/third_party/xla/xla/service/hlo_ordering.cc +++ b/third_party/xla/xla/service/hlo_ordering.cc @@ -363,14 +363,13 @@ bool HloOrdering::UsesBeforeValueDefinition( return true; } } - // The use at an async call occurs before values that are defined in the - // called computation of the async wrapped instruction. - if (use.instruction->IsAsynchronous() && - use.instruction->async_wrapped_opcode() == HloOpcode::kCall) { + // The use at an async op occurs before values that are defined in the async + // wrapped computation or any of its nested computations. + if (use.instruction->IsAsynchronous()) { const HloInstruction* async = use.instruction; if (call_graph_->InstructionIsNestedIn( value.defining_instruction(), - async->async_wrapped_instruction()->to_apply())) { + async->async_wrapped_computation())) { VLOG(4) << " use is async " << use.instruction->name() << " and def is in called computation"; return true; diff --git a/third_party/xla/xla/service/hlo_ordering_test.cc b/third_party/xla/xla/service/hlo_ordering_test.cc index dc003777ecef6e..c0b1dc9c0c6bb7 100644 --- a/third_party/xla/xla/service/hlo_ordering_test.cc +++ b/third_party/xla/xla/service/hlo_ordering_test.cc @@ -28,9 +28,9 @@ limitations under the License. #include "xla/service/hlo_value.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { @@ -675,6 +675,7 @@ ENTRY %main { HloInstruction* async_wrapped_call = FindInstruction(module.get(), "async_wrapped_call"); HloInstruction* p0 = FindInstruction(module.get(), "p0"); + HloInstruction* broadcast1 = FindInstruction(module.get(), "broadcast.1"); ASSERT_NE(async_start, nullptr); ASSERT_NE(async_done, nullptr); @@ -685,13 +686,16 @@ ENTRY %main { HloUse async_done_use = HloUse{async_done, 0, {0, 0}}; HloUse call_use = HloUse{async_wrapped_call, 0}; const HloValue& value = dataflow->GetUniqueValueAt(async_wrapped_call, {}); + const HloValue& broadcast_value = dataflow->GetUniqueValueAt(broadcast1, {}); DependencyHloOrdering ordering(module.get()); - EXPECT_FALSE( + EXPECT_TRUE( ordering.UsesBeforeValueDefinition({&async_start_use}, value, *dataflow)); + EXPECT_TRUE(ordering.UsesBeforeValueDefinition({&async_start_use}, + broadcast_value, *dataflow)); EXPECT_FALSE( ordering.UsesBeforeValueDefinition({&call_use}, value, *dataflow)); - EXPECT_FALSE( + EXPECT_TRUE( ordering.UsesBeforeValueDefinition({&async_done_use}, value, *dataflow)); } @@ -795,11 +799,11 @@ ENTRY %main { const HloValue& value = dataflow->GetUniqueValueAt(async_wrapped_call, {}); DependencyHloOrdering ordering(module.get()); - EXPECT_FALSE( + EXPECT_TRUE( ordering.UsesBeforeValueDefinition({&async_start_use}, value, *dataflow)); EXPECT_FALSE( ordering.UsesBeforeValueDefinition({&call_use}, value, *dataflow)); - EXPECT_FALSE( + EXPECT_TRUE( ordering.UsesBeforeValueDefinition({&async_done_use}, value, *dataflow)); } diff --git a/third_party/xla/xla/service/hlo_parser.cc b/third_party/xla/xla/service/hlo_parser.cc index ab144fa6eb34da..2ff069772c771e 100644 --- a/third_party/xla/xla/service/hlo_parser.cc +++ b/third_party/xla/xla/service/hlo_parser.cc @@ -46,6 +46,7 @@ limitations under the License. #include "absl/strings/strip.h" #include "absl/types/span.h" #include "Eigen/Core" +#include "xla/array.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_casting_utils.h" @@ -54,9 +55,11 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/ir/hlo_original_value.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" +#include "xla/hlo/ir/tile_assignment.h" #include "xla/layout.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -77,6 +80,7 @@ limitations under the License. #include "tsl/lib/gtl/map_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" namespace xla { @@ -311,6 +315,7 @@ class HloParserImpl : public HloParser { // enclosed in matching curly braces (returned value includes the curlies). kStringOrJsonDict, kCollectiveDeviceList, + kOriginalValue, }; struct AttrConfig { @@ -446,7 +451,7 @@ class HloParserImpl : public HloParser { // bool ParseAttributes( const absl::flat_hash_map& attrs, - bool allow_attributes = true); + bool allow_attributes = true, const std::optional& shape = {}); // sub_attributes ::= '{' (','? attribute)* '}' // @@ -460,7 +465,8 @@ class HloParserImpl : public HloParser { // Do not call this except in ParseAttributes or ParseSubAttributes. bool ParseAttributeHelper( const absl::flat_hash_map& attrs, - absl::flat_hash_set* seen_attrs); + absl::flat_hash_set* seen_attrs, + const std::optional& shape = {}); // Copy attributes from `attrs` to `message`, unless the attribute name is in // `non_proto_attrs`. @@ -487,12 +493,11 @@ class HloParserImpl : public HloParser { bool ParseWindow(Window* window, bool expect_outer_curlies); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); - bool ParseMetadata(OpMetadata* metadata); - bool ParseSingleOrListMetadata( - tsl::protobuf::RepeatedPtrField* metadata); + bool ParseMetadata(OpMetadata& metadata); + bool ParseSingleOrListMetadata(std::vector& metadata); bool ParseOpShardingType(OpSharding::Type* type); bool ParseListShardingType(std::vector* types); - bool ParseSharding(OpSharding* sharding); + bool ParseSharding(std::optional& sharding); bool ParseCollectiveDeviceList(CollectiveDeviceList* device_list); bool ParseFrontendAttributes(FrontendAttributes* frontend_attributes); bool ParseStatisticsViz(StatisticsViz* statistics_viz); @@ -500,7 +505,8 @@ class HloParserImpl : public HloParser { std::vector& iota_reshape_dims, std::vector& iota_transpose_perm, std::vector* devices); - bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); + bool ParseSingleSharding(std::optional& sharding, + bool lbrace_pre_lexed); bool ParseParameterReplication(ParameterReplication* parameter_replication); bool ParseBooleanListOrSingleBoolean(BoolList* boolean_list); bool ParseReplicaGroupsOnly(std::vector* replica_groups); @@ -564,6 +570,9 @@ class HloParserImpl : public HloParser { bool ParseBool(bool* result); bool ParseToken(TokKind kind, const std::string& msg); bool ParseUnsignedIntegerType(PrimitiveType* primitive_type); + bool ParseOriginalValue( + optional>* original_value, + const Shape& shape); using AliasingData = absl::flat_hash_map; @@ -1356,7 +1365,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, // Add optional attributes. These are added to any HloInstruction type if // present. absl::flat_hash_map attrs; - optional sharding; + optional sharding; optional frontend_attributes; optional statistics_viz; attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; @@ -1371,6 +1380,11 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, optional> predecessors; attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, &predecessors}; + + optional> original_value; + attrs["original_value"] = {/*required=*/false, AttrTy::kOriginalValue, + &original_value}; + optional metadata; attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; @@ -1412,9 +1426,8 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, // TODO(b/257495070): Eliminate tuple sharding normalization in HLO parser. // Allow existing HLO text with invalid sharding on tuple shapes by // normalizing tuple sharding. - HloSharding hlo_sharding = HloSharding::FromProto(sharding.value()).value(); - hlo_sharding = hlo_sharding.NormalizeTupleSharding(instruction->shape()); - instruction->set_sharding(std::move(hlo_sharding)); + instruction->set_sharding( + sharding->NormalizeTupleSharding(instruction->shape())); } if (parameter_replication) { int leaf_count = ShapeUtil::GetLeafCount(instruction->shape()); @@ -1440,6 +1453,9 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, if (metadata) { instruction->set_metadata(*metadata); } + if (original_value) { + instruction->set_original_value(*original_value); + } if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } @@ -1492,7 +1508,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT return nullptr; } if (!ParseToken(TokKind::kRparen, "expects ')' after parameter number") || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } std::string param_name(name); @@ -1510,7 +1526,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT "expects '(' before constant literal") || !ParseLiteral(&literal, *shape) || !ParseToken(TokKind::kRparen, "expects ')' after constant literal") || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -1522,7 +1538,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &iota_dimension}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/0)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -1535,7 +1551,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["largest"] = {/*required=*/false, AttrTy::kBool, &largest}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -1582,7 +1598,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kTanh: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -1613,7 +1629,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kStochasticConvert: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -1630,7 +1646,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kSelect: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/3)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -1646,7 +1662,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kConvert: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -1655,7 +1671,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kBitcastConvert: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -1678,7 +1694,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["use_global_device_ids"] = {/*required=*/false, AttrTy::kBool, &use_global_device_ids}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (opcode == HloOpcode::kAllGather) { @@ -1715,7 +1731,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; } if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (opcode == HloOpcode::kAllReduce) { @@ -1748,7 +1764,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["constrain_layout"] = {/*required=*/false, AttrTy::kBool, &constrain_layout}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes) || + !ParseAttributes(attrs, allow_attributes, shape) || (dimensions && dimensions->size() != 1)) { return nullptr; } @@ -1768,7 +1784,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT optional channel_id; attrs["channel_id"] = {/*required=*/false, AttrTy::kInt64, &channel_id}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateCollectiveBroadcast( @@ -1785,7 +1801,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["slice_sizes"] = {/*required=*/false, AttrTy::kBracedInt64ListList, &slice_sizes}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } std::vector> pairs(source_targets->size()); @@ -1900,6 +1916,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT std::vector async_wrapped_operands; std::vector async_wrapped_operand_shapes; Shape async_wrapped_root_shape; + async_wrapped_operand_shapes.reserve(operands.size()); for (const HloInstruction* operand : operands) { async_wrapped_operand_shapes.push_back(operand->shape()); } @@ -1941,7 +1958,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT // Attributes would have already been consumed when constructing the // async wrapped computation for async-start. if (!(async_wrapped_opcode && opcode == HloOpcode::kAsyncStart)) { - if (!ParseAttributes(attrs, allow_attributes)) { + if (!ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } } @@ -1999,7 +2016,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT /*required=*/false, AttrTy::kInt32, &cross_program_prefetch_index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateCopyStart( @@ -2008,7 +2025,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kReplicaId: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/0)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (shape.has_value()) { @@ -2019,7 +2036,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kPartitionId: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/0)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (shape.has_value()) { @@ -2030,7 +2047,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } case HloOpcode::kDynamicReshape: { if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateDynamicReshape( @@ -2043,7 +2060,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &inferred_dimension}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateReshape( @@ -2051,7 +2068,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } case HloOpcode::kAfterAll: { if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operands.empty()) { @@ -2062,7 +2079,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT case HloOpcode::kAddDependency: { if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -2078,7 +2095,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes) || + !ParseAttributes(attrs, allow_attributes, shape) || dimensions->size() != 1) { return nullptr; } @@ -2101,7 +2118,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT !(shape.has_value() ? ParseOperands(&operands, builder, shape->tuple_shapes_size()) : ParseOperands(&operands, builder))) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2127,7 +2144,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2149,7 +2166,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &is_host_transfer}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } // If the is_host_transfer attribute is not present then default to false. @@ -2165,7 +2182,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &is_host_transfer}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -2187,7 +2204,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &is_host_transfer}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateSend( @@ -2202,7 +2219,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &is_host_transfer}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -2220,7 +2237,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2234,10 +2251,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } case HloOpcode::kCall: { optional to_apply; + optional is_composite = false; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; + attrs["is_composite"] = {/*required=*/false, AttrTy::kBool, + &is_composite}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2251,8 +2271,10 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT })) { return nullptr; } - return builder->AddInstruction( - HloInstruction::CreateCall(*shape, operands, *to_apply)); + + auto call_op = HloInstruction::CreateCall(*shape, operands, *to_apply); + call_op->set_is_composite(is_composite.value()); + return builder->AddInstruction(std::move(call_op)); } case HloOpcode::kReduceWindow: { optional reduce_computation; @@ -2261,7 +2283,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &reduce_computation}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!window) { @@ -2305,7 +2327,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &operand_precision}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!window) { @@ -2346,7 +2368,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &fft_length}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2383,7 +2405,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["type"] = {/*required=*/false, AttrTy::kComparisonType, &type}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2423,7 +2445,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/!operand_is_scalar, AttrTy::kBracedInt64List, &broadcast_dimensions}; - if (!ParseAttributes(attrs, allow_attributes)) { + if (!ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operand_is_scalar && !broadcast_dimensions.has_value()) { @@ -2444,7 +2466,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes) || + !ParseAttributes(attrs, allow_attributes, shape) || dimensions->size() != 1) { return nullptr; } @@ -2470,7 +2492,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["dimensions"] = {/*required=*/false, AttrTy::kBracedInt64List, &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2496,7 +2518,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, &dimensions_to_reduce}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operands.size() % 2) { @@ -2531,7 +2553,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2552,7 +2574,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/3)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!window) { @@ -2575,7 +2597,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateSlice( @@ -2587,7 +2609,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["dynamic_slice_sizes"] = { /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operands.empty()) { @@ -2606,7 +2628,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT } case HloOpcode::kDynamicUpdateSlice: { if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (operands.size() < 2) { @@ -2628,7 +2650,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2648,7 +2670,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &feature_index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/3)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2670,7 +2692,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &feature_index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/5)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2694,7 +2716,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &feature_index}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/5)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2715,7 +2737,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -2740,7 +2762,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT AttrTy::kInstructionAliasing, &output_to_operand_aliasing}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } auto instr = builder->AddInstruction(HloInstruction::CreateFusion( @@ -2757,7 +2779,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } // We need to know the infeed data shape to construct the infeed @@ -2781,7 +2803,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &outfeed_shape}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } HloInstruction* const outfeed_input = operands[0]; @@ -2796,7 +2818,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, &distribution}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -2807,7 +2829,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["delta"] = {/*required=*/true, AttrTy::kInt64, &delta}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/0)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction( @@ -2818,7 +2840,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["algorithm"] = {/*required=*/true, AttrTy::kRandomAlgorithm, &algorithm}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateRngBitGenerator( @@ -2833,7 +2855,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &mantissa_bits}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } return builder->AddInstruction(HloInstruction::CreateReducePrecision( @@ -2867,7 +2889,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT AttrTy::kBracedHloComputationList, &branch_computations}; } - if (!ParseAttributes(attrs, allow_attributes)) { + if (!ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (branch_index_is_bool) { @@ -2958,7 +2980,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["api_version"] = {/*required=*/false, AttrTy::kCustomCallApiVersion, &api_version}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -3088,7 +3110,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT LocTy loc = lexer_.GetLoc(); if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -3163,10 +3185,17 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT optional indices_are_sorted = false; attrs["indices_are_sorted"] = {/*required=*/false, AttrTy::kBool, &indices_are_sorted}; + optional> operand_batching_dims; + attrs["operand_batching_dims"] = { + /*required=*/false, AttrTy::kBracedInt64List, &operand_batching_dims}; + optional> start_indices_batching_dims; + attrs["start_indices_batching_dims"] = {/*required=*/false, + AttrTy::kBracedInt64List, + &start_indices_batching_dims}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -3175,7 +3204,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT /*offset_dims=*/*offset_dims, /*collapsed_slice_dims=*/*collapsed_slice_dims, /*start_index_map=*/*start_index_map, - /*index_vector_dim=*/*index_vector_dim); + /*index_vector_dim=*/*index_vector_dim, + /*operand_batching_dims=*/ + operand_batching_dims ? *operand_batching_dims + : std::vector(), + /*start_indices_batching_dims=*/ + start_indices_batching_dims ? *start_indices_batching_dims + : std::vector()); if (!maybe_infer_shape([&] { return ShapeInference::InferGatherShape(operands[0]->shape(), operands[1]->shape(), @@ -3211,9 +3246,16 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT optional unique_indices = false; attrs["unique_indices"] = {/*required=*/false, AttrTy::kBool, &unique_indices}; + optional> input_batching_dims; + attrs["input_batching_dims"] = { + /*required=*/false, AttrTy::kBracedInt64List, &input_batching_dims}; + optional> scatter_indices_batching_dims; + attrs["scatter_indices_batching_dims"] = {/*required=*/false, + AttrTy::kBracedInt64List, + &scatter_indices_batching_dims}; if ((!preset_operands && !ParseOperands(&operands, builder)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } @@ -3228,7 +3270,13 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT /*update_window_dims=*/*update_window_dims, /*inserted_window_dims=*/*inserted_window_dims, /*scatter_dims_to_operand_dims=*/*scatter_dims_to_operand_dims, - /*index_vector_dim=*/*index_vector_dim); + /*index_vector_dim=*/*index_vector_dim, + /*input_batching_dims=*/ + input_batching_dims ? *input_batching_dims + : std::vector(), + /*scatter_indices_batching_dims=*/ + scatter_indices_batching_dims ? *scatter_indices_batching_dims + : std::vector()); if (!maybe_infer_shape([&] { absl::InlinedVector arg_shapes; @@ -3254,7 +3302,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -3272,7 +3320,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/1)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -3290,7 +3338,7 @@ HloInstruction* HloParserImpl::CreateInstruction( // NOLINT &dimensions}; if ((!preset_operands && !ParseOperands(&operands, builder, /*expected_size=*/2)) || - !ParseAttributes(attrs, allow_attributes)) { + !ParseAttributes(attrs, allow_attributes, shape)) { return nullptr; } if (!maybe_infer_shape([&] { @@ -3355,7 +3403,7 @@ bool HloParserImpl::ParseCollectiveDeviceList( // ::= '{' (single_sharding | tuple_sharding) '}' // // tuple_sharding ::= single_sharding* (',' single_sharding)* -bool HloParserImpl::ParseSharding(OpSharding* sharding) { +bool HloParserImpl::ParseSharding(std::optional& sharding) { // A single sharding starts with '{' and is not followed by '{'. // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for // an empty tuple. @@ -3371,15 +3419,18 @@ bool HloParserImpl::ParseSharding(OpSharding* sharding) { // Tuple sharding. // Allow empty tuple shardings. + std::vector tuple_shardings; if (lexer_.GetKind() != TokKind::kRbrace) { do { - if (!ParseSingleSharding(sharding->add_tuple_shardings(), + std::optional tuple_sharding; + if (!ParseSingleSharding(tuple_sharding, /*lbrace_pre_lexed=*/false)) { return false; } + tuple_shardings.push_back(std::move(*tuple_sharding)); } while (EatIfPresent(TokKind::kComma)); } - sharding->set_type(OpSharding::TUPLE); + sharding = HloSharding::FlatTuple(std::move(tuple_shardings)); return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); } @@ -3403,11 +3454,21 @@ bool HloParserImpl::ParseFrontendAttributes( if (!ParseAttributeName(&attribute)) { return false; } - if (lexer_.GetKind() != TokKind::kString) { + + std::string result; + if (lexer_.GetKind() == TokKind::kString) { + if (!ParseString(&result)) { + return false; + } + } else if (lexer_.GetKind() == TokKind::kLbrace) { + if (!ParseJsonDict(&result)) { + return false; + } + } else { return false; } - (*frontend_attributes->mutable_map())[attribute] = lexer_.GetStrVal(); - lexer_.Lex(); + + (*frontend_attributes->mutable_map())[attribute] = result; } while (EatIfPresent(TokKind::kComma)); } return ParseToken(TokKind::kRbrace, @@ -3561,7 +3622,7 @@ bool HloParserImpl::ParseTileAssignment( // metadata ::= single_metadata | // ('{' [single_metadata (',' single_metadata)*] '}') // last_tile_dims ::= sharding_type_list -bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, +bool HloParserImpl::ParseSingleSharding(std::optional& sharding, bool lbrace_pre_lexed) { if (!lbrace_pre_lexed && !ParseToken(TokKind::kLbrace, @@ -3584,6 +3645,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, std::vector iota_reshape_dims; std::vector iota_transpose_perm; std::vector subgroup_types; + std::vector metadata; while (lexer_.GetKind() != TokKind::kRbrace) { switch (lexer_.GetKind()) { case TokKind::kw_maximal: @@ -3618,7 +3680,7 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, } } else if (lexer_.GetStrVal() == "metadata") { lexer_.Lex(); - if (!ParseSingleOrListMetadata(sharding->mutable_metadata())) { + if (!ParseSingleOrListMetadata(metadata)) { return false; } } else if (lexer_.GetStrVal() == "last_tile_dims") { @@ -3666,26 +3728,25 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, return Error(loc, "replicated shardings should not have any devices assigned"); } - sharding->set_type(OpSharding::REPLICATED); + sharding = HloSharding::Replicate(metadata); } else if (maximal) { if (devices.size() != 1) { return Error(loc, "maximal shardings should have exactly one device assigned"); } - sharding->set_type(OpSharding::MAXIMAL); - sharding->add_tile_assignment_devices(devices[0]); + sharding = HloSharding::AssignDevice(devices[0], metadata); } else if (manual) { if (!devices.empty()) { return Error(loc, "manual shardings should not have any devices assigned"); } - sharding->set_type(OpSharding::MANUAL); + sharding = HloSharding::Manual(metadata); } else if (unknown) { if (!devices.empty()) { return Error(loc, "unknown shardings should not have any devices assigned"); } - sharding->set_type(OpSharding::UNKNOWN); + sharding = HloSharding::Unknown(metadata); } else { if (tile_assignment_dimensions.empty()) { return Error( @@ -3693,10 +3754,6 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, "non-maximal shardings must have a tile assignment list including " "dimensions"); } - sharding->set_type(OpSharding::OTHER); - for (int64_t dim : tile_assignment_dimensions) { - sharding->add_tile_assignment_dimensions(dim); - } if (iota_transpose_perm.size() != iota_reshape_dims.size()) { return Error(loc, absl::StrFormat( @@ -3704,44 +3761,41 @@ bool HloParserImpl::ParseSingleSharding(OpSharding* sharding, "iota_reshape_dims : expected %lld, saw %lld.", iota_reshape_dims.size(), iota_transpose_perm.size())); } + if (last_tile_dim_replicate) { + CHECK(subgroup_types.empty()); + subgroup_types.push_back(OpSharding::REPLICATED); + } if (!iota_reshape_dims.empty()) { CHECK(devices.empty()); - absl::c_copy(iota_reshape_dims, - tsl::protobuf::RepeatedFieldBackInserter( - sharding->mutable_iota_reshape_dims())); - absl::c_copy(iota_transpose_perm, - tsl::protobuf::RepeatedFieldBackInserter( - sharding->mutable_iota_transpose_perm())); + sharding = + subgroup_types.empty() + ? HloSharding::IotaTile(tile_assignment_dimensions, + iota_reshape_dims, iota_transpose_perm, + metadata) + : HloSharding::Subgroup( + TileAssignment(tile_assignment_dimensions, + iota_reshape_dims, iota_transpose_perm), + subgroup_types, metadata); } else { if (devices.size() <= 1) { return Error( loc, "non-maximal shardings must have more than one device assigned"); } - for (int64_t device : devices) { - sharding->add_tile_assignment_devices(device); - } - } - - if (last_tile_dims) { - for (OpSharding::Type type : subgroup_types) { - sharding->add_last_tile_dims(type); - } - } else { - sharding->set_replicate_on_last_tile_dim(last_tile_dim_replicate); + auto tiles = std::make_shared>(tile_assignment_dimensions); + absl::c_copy(devices, tiles->begin()); + sharding = + subgroup_types.empty() + ? HloSharding::Tile(TileAssignment(std::move(tiles)), metadata) + : HloSharding::Subgroup(TileAssignment(std::move(tiles)), + subgroup_types, metadata); } } if (shard_as || shard_like) { - sharding->set_is_shard_group(true); - sharding->set_shard_group_id(shard_group_id); - if (shard_as) { - sharding->set_shard_group_type(OpSharding::AS); - } else { - sharding->set_shard_group_type(OpSharding::LIKE); - } - } else { - sharding->set_is_shard_group(false); + sharding = sharding->SetShardGroup( + shard_as ? HloSharding::ShardAs(shard_group_id) + : HloSharding::ShardLike(shard_group_id)); } lexer_.Lex(); @@ -3839,8 +3893,8 @@ bool HloParserImpl::ParseReplicaGroupsOnly( bool HloParserImpl::ParseDomain(DomainData* domain) { absl::flat_hash_map attrs; optional kind; - optional entry_sharding; - optional exit_sharding; + optional entry_sharding; + optional exit_sharding; attrs["kind"] = {/*required=*/true, AttrTy::kString, &kind}; attrs["entry"] = {/*required=*/true, AttrTy::kSharding, &entry_sharding}; attrs["exit"] = {/*required=*/true, AttrTy::kSharding, &exit_sharding}; @@ -3848,10 +3902,10 @@ bool HloParserImpl::ParseDomain(DomainData* domain) { return false; } if (*kind == ShardingMetadata::KindName()) { - auto entry_sharding_ptr = std::make_unique( - HloSharding::FromProto(*entry_sharding).value()); - auto exit_sharding_ptr = std::make_unique( - HloSharding::FromProto(*exit_sharding).value()); + auto entry_sharding_ptr = + std::make_unique(std::move(*entry_sharding)); + auto exit_sharding_ptr = + std::make_unique(std::move(*exit_sharding)); domain->entry_metadata = std::make_unique(std::move(entry_sharding_ptr)); domain->exit_metadata = @@ -4626,12 +4680,12 @@ bool HloParserImpl::ParseSubAttributes( // attributes ::= (',' attribute)* bool HloParserImpl::ParseAttributes( const absl::flat_hash_map& attrs, - bool allow_attributes) { + bool allow_attributes, const std::optional& shape) { LocTy loc = lexer_.GetLoc(); absl::flat_hash_set seen_attrs; if (allow_attributes) { while (EatIfPresent(TokKind::kComma)) { - if (!ParseAttributeHelper(attrs, &seen_attrs)) { + if (!ParseAttributeHelper(attrs, &seen_attrs, shape)) { return false; } } @@ -4645,12 +4699,14 @@ bool HloParserImpl::ParseAttributes( attr_it.first)); } } + return true; } bool HloParserImpl::ParseAttributeHelper( const absl::flat_hash_map& attrs, - absl::flat_hash_set* seen_attrs) { + absl::flat_hash_set* seen_attrs, + const std::optional& shape) { LocTy loc = lexer_.GetLoc(); std::string name; if (!ParseAttributeName(&name)) { @@ -4807,11 +4863,12 @@ bool HloParserImpl::ParseAttributeHelper( return true; } case AttrTy::kSharding: { - OpSharding sharding; - if (!ParseSharding(&sharding)) { + std::optional sharding; + if (!ParseSharding(sharding)) { return false; } - static_cast*>(attr_out_ptr)->emplace(sharding); + static_cast*>(attr_out_ptr) + ->emplace(std::move(*sharding)); return true; } case AttrTy::kCollectiveDeviceList: { @@ -4929,12 +4986,24 @@ bool HloParserImpl::ParseAttributeHelper( ->emplace(std::move(result)); return true; } + case AttrTy::kOriginalValue: { + // By the time this attribute is added, the instruciton shape should + // have been inferred. + if (!shape) { + return TokenError("expects instruction shape"); + } + return ParseOriginalValue( + static_cast>*>( + attr_out_ptr), + *shape); + } case AttrTy::kMetadata: { OpMetadata result; - if (!ParseMetadata(&result)) { + if (!ParseMetadata(result)) { return false; } - static_cast*>(attr_out_ptr)->emplace(result); + static_cast*>(attr_out_ptr) + ->emplace(std::move(result)); return true; } case AttrTy::kDistribution: { @@ -6225,8 +6294,59 @@ bool HloParserImpl::ParsePaddingConfig(PaddingConfig* padding) { return true; } +// original_value ::= original_value | '{' [shape_index] ',' original_array '}' +// [','] +bool HloParserImpl::ParseOriginalValue( + optional>* original_value, + const Shape& shape) { + VLOG(3) << "ParseOriginalValue"; + + if (!ParseToken(TokKind::kLbrace, "Expects '{'")) { + return false; + } + + *original_value = std::make_shared(shape); + + ShapeIndex leaf_shape_index; + while (lexer_.GetKind() != TokKind::kRbrace) { + if (lexer_.GetKind() == TokKind::kLparen) { + lexer_.Lex(); + leaf_shape_index.push_back(0); + } else if (lexer_.GetKind() == TokKind::kRparen) { + lexer_.Lex(); + leaf_shape_index.pop_back(); + } else if (lexer_.GetKind() == TokKind::kComma) { + lexer_.Lex(); + ++leaf_shape_index.back(); + } else if (lexer_.GetKind() == TokKind::kLbrace) { + lexer_.Lex(); + std::string instruction_name; + ShapeIndex shape_index; + if (!ParseString(&instruction_name)) { + return false; + } + if (lexer_.GetKind() != TokKind::kRbrace) { + if (!ParseShapeIndex(&shape_index)) { + return false; + } + } + *(**original_value)->mutable_element(leaf_shape_index) = { + instruction_name, shape_index}; + if (!ParseToken(TokKind::kRbrace, + "Expects '} at end of each OriginalArray'")) { + return false; + } + } else { + return false; + } + } + + lexer_.Lex(); + return true; +} + // '{' metadata_string '}' -bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { +bool HloParserImpl::ParseMetadata(OpMetadata& metadata) { absl::flat_hash_map attrs; optional op_type; optional op_name; @@ -6252,42 +6372,42 @@ bool HloParserImpl::ParseMetadata(OpMetadata* metadata) { return false; } if (op_type) { - metadata->set_op_type(*op_type); + metadata.set_op_type(*op_type); } if (op_name) { - metadata->set_op_name(*op_name); + metadata.set_op_name(*op_name); } if (source_file) { - metadata->set_source_file(*source_file); + metadata.set_source_file(*source_file); } if (source_line) { - metadata->set_source_line(*source_line); + metadata.set_source_line(*source_line); } if (profile_type) { for (const auto& type : *profile_type) { if (!ProfileType_IsValid(type)) { return false; } - metadata->add_profile_type(static_cast(type)); + metadata.add_profile_type(static_cast(type)); } } if (deduplicated_name) { - metadata->set_deduplicated_name(*deduplicated_name); + metadata.set_deduplicated_name(*deduplicated_name); } if (preserve_layout) { - metadata->set_preserve_layout(*preserve_layout); + metadata.set_preserve_layout(*preserve_layout); } else { - metadata->set_preserve_layout(false); + metadata.set_preserve_layout(false); } if (scheduling_name) { - metadata->set_scheduling_name(*scheduling_name); + metadata.set_scheduling_name(*scheduling_name); } return true; } // ::= single_metadata | ('{' [single_metadata (',' single_metadata)*] '}') bool HloParserImpl::ParseSingleOrListMetadata( - tsl::protobuf::RepeatedPtrField* metadata) { + std::vector& metadata) { if (lexer_.GetKind() == TokKind::kLbrace && lexer_.LookAhead() == TokKind::kLbrace) { if (!ParseToken(TokKind::kLbrace, "expected '{' to start metadata list")) { @@ -6296,7 +6416,7 @@ bool HloParserImpl::ParseSingleOrListMetadata( if (lexer_.GetKind() != TokKind::kRbrace) { do { - if (!ParseMetadata(metadata->Add())) { + if (!ParseMetadata(metadata.emplace_back())) { return false; } } while (EatIfPresent(TokKind::kComma)); @@ -6305,7 +6425,7 @@ bool HloParserImpl::ParseSingleOrListMetadata( return ParseToken(TokKind::kRbrace, "expected '}' to end metadata list"); } - return ParseMetadata(metadata->Add()); + return ParseMetadata(metadata.emplace_back()); } bool HloParserImpl::ParseOpShardingType(OpSharding::Type* type) { @@ -6680,14 +6800,14 @@ absl::StatusOr HloParserImpl::ParseLayoutOnly() { absl::StatusOr HloParserImpl::ParseShardingOnly() { lexer_.Lex(); - OpSharding op_sharding; - if (!ParseSharding(&op_sharding)) { + std::optional sharding; + if (!ParseSharding(sharding)) { return InvalidArgument("Syntax error:\n%s", GetError()); } if (lexer_.GetKind() != TokKind::kEof) { return InvalidArgument("Syntax error:\nExtra content after sharding"); } - return HloSharding::FromProto(op_sharding); + return std::move(*sharding); } absl::StatusOr diff --git a/third_party/xla/xla/service/hlo_parser_test.cc b/third_party/xla/xla/service/hlo_parser_test.cc index 1f50e26133b9e3..6378f08744e76f 100644 --- a/third_party/xla/xla/service/hlo_parser_test.cc +++ b/third_party/xla/xla/service/hlo_parser_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" +#include #include #include #include @@ -22,24 +23,34 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "xla/array.h" #include "xla/hlo/ir/collective_device_list.h" #include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_frontend_attributes.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/service/hlo_lexer.h" +#include "xla/service/hlo_module_config.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/window_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -68,7 +79,7 @@ std::string TestDataToString(const ::testing::TestParamInfo& data) { // // In general we want to avoid these because we want HLO text to be // round-trippable! But nested instructions, e.g. add(sqrt(x), y), cannot be -// round-triped without modification. +// round-tripped without modification. struct NonRoundtripTestData { std::string test_name; std::string input_module_string; @@ -461,6 +472,96 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1 } +)" +}, +// composite call +{ +"CompositeCall", +R"(HloModule CompositeCall, entry_computation_layout={()->f32[]} + +%add (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + %constant = f32[] constant(2) + ROOT %z = f32[] add(f32[] %x, f32[] %constant) +} + +ENTRY %CompositeCall.v2 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} +} + +)" +}, +// composite call with extra frontend attributes +{ +"CompositeCallWithExtraFrontendAttributes", +R"(HloModule CompositeCall, entry_computation_layout={()->f32[]} + +%add (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + %constant = f32[] constant(2) + ROOT %z = f32[] add(f32[] %x, f32[] %constant) +} + +ENTRY %CompositeCall.v2 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1",foo="bar"} +} + +)" +}, +// composite call optional composite.attributes and composite.version +{ +"CompositeCallOptionalAttributesAndVersion", +R"(HloModule CompositeCall, entry_computation_layout={()->f32[]} + +%add (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + %constant = f32[] constant(2) + ROOT %z = f32[] add(f32[] %x, f32[] %constant) +} + +ENTRY %CompositeCall.v2 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.name="foo.bar"} +} + +)" +}, +// composite call optional composite.attributes +{ +"CompositeCallOptionalAttributes", +R"(HloModule CompositeCall, entry_computation_layout={()->f32[]} + +%add (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + %constant = f32[] constant(2) + ROOT %z = f32[] add(f32[] %x, f32[] %constant) +} + +ENTRY %CompositeCall.v2 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.name="foo.bar",composite.version="1"} +} + +)" +}, +// composite call optional composite.version +{ +"CompositeCallOptionalVersion", +R"(HloModule CompositeCall, entry_computation_layout={()->f32[]} + +%add (x: f32[]) -> f32[] { + %x = f32[] parameter(0) + %constant = f32[] constant(2) + ROOT %z = f32[] add(f32[] %x, f32[] %constant) +} + +ENTRY %CompositeCall.v2 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call = f32[] call(f32[] %constant.1), to_apply=%add, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar"} +} + )" }, // CustomCall with backend_config. @@ -1059,6 +1160,18 @@ ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5] ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}, indices_are_sorted=true } +)" +}, +{ +"BatchGather", +R"(HloModule StringifyGather, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0})->f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0}} + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46,512], start_indices: s64[10,9,8,7,5,512]) -> f32[10,9,8,7,30,29,28,27,26,512] { + %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0) + %start_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,5,512]{5,4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, operand_batching_dims={5}, start_indices_batching_dims={5}, index_vector_dim=4, slice_sizes={30,29,28,27,26,1} +} + )" }, { @@ -1078,6 +1191,25 @@ ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3 } +)" +}, +{ +"BatchScatter", +R"(HloModule StringifyScatter, entry_computation_layout={(f32[50,49,48,47,46,512]{5,4,3,2,1,0}, s64[10,9,8,7,5,512]{5,4,3,2,1,0}, f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0})->f32[50,49,48,47,46,512]{5,4,3,2,1,0}} + +%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) +} + +ENTRY %Scatter (input_tensor: f32[50,49,48,47,46,512], scatter_indices: s64[10,9,8,7,5,512], updates: f32[10,9,8,7,30,29,28,27,26,512]) -> f32[50,49,48,47,46,512] { + %input_tensor = f32[50,49,48,47,46,512]{5,4,3,2,1,0} parameter(0) + %scatter_indices = s64[10,9,8,7,5,512]{5,4,3,2,1,0} parameter(1) + %updates = f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} parameter(2) + ROOT %scatter = f32[50,49,48,47,46,512]{5,4,3,2,1,0} scatter(f32[50,49,48,47,46,512]{5,4,3,2,1,0} %input_tensor, s64[10,9,8,7,5,512]{5,4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26,512]{9,8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, input_batching_dims={5}, scatter_indices_batching_dims={5}, index_vector_dim=4, to_apply=%add_F32.v3 +} + )" }, { @@ -1390,6 +1522,21 @@ ENTRY %test (p: f32[100]) -> u32[100] { ROOT %root = u32[100]{0} bitcast-convert(f32[100]{0} %p), metadata={op_type="a" op_name="b" source_file="c" source_line=1 profile_type={1} deduplicated_name="d" preserve_layout=true} } +)" +}, + +{ +"OriginalValue", +R"(HloModule test, entry_computation_layout={(f32[], f32[3]{0}, f32[2,3]{1,0})->((f32[], f32[3]{0}), f32[2,3]{1,0})} + +ENTRY %test (v1: f32[], v2: f32[3], v3: f32[2,3]) -> ((f32[], f32[3]), f32[2,3]) { + %v1 = f32[] parameter(0), original_value={{"v1"}} + %v2 = f32[3]{0} parameter(1), original_value={{"v2"}} + %tuple = (f32[], f32[3]{0}) tuple(f32[] %v1, f32[3]{0} %v2), original_value={({"v1"}, {"v2"})} + %v3 = f32[2,3]{1,0} parameter(2), original_value={{"v3"}} + ROOT %nested_tuple = ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple((f32[], f32[3]{0}) %tuple, f32[2,3]{1,0} %v3), original_value={(({"v1"}, {"v2"}), {"v3"})} +} + )" }, }); @@ -5360,5 +5507,20 @@ TEST_F(HloParserTest, ReplicaIdWithLayout) { .empty()); } +TEST_F(HloParserTest, OriginalValueWithoutShape) { + const std::string hlo_string = R"(HloModule test + +ENTRY %test { + %a = f32[2,10]{1,0} parameter(0), original_value={{"a"}} + ROOT %v = abs(%a), original_value={{"v"}} +} + + +)"; + EXPECT_THAT(ParseAndReturnUnverifiedModule(hlo_string).status(), + tsl::testing::StatusIs(tsl::error::INVALID_ARGUMENT, + HasSubstr("expects instruction shape"))); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/hlo_pass_pipeline_test.cc b/third_party/xla/xla/service/hlo_pass_pipeline_test.cc index d5ad880f72c7a7..502406bb54d1fc 100644 --- a/third_party/xla/xla/service/hlo_pass_pipeline_test.cc +++ b/third_party/xla/xla/service/hlo_pass_pipeline_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/hlo_rematerialization_test.cc b/third_party/xla/xla/service/hlo_rematerialization_test.cc index b30cf8293e48e9..c3a945345b3101 100644 --- a/third_party/xla/xla/service/hlo_rematerialization_test.cc +++ b/third_party/xla/xla/service/hlo_rematerialization_test.cc @@ -35,8 +35,8 @@ limitations under the License. #include "xla/service/hlo_rematerialization_test_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/hlo_replication_analysis_test.cc b/third_party/xla/xla/service/hlo_replication_analysis_test.cc index 4cb5b9b8c43792..e57e7112226072 100644 --- a/third_party/xla/xla/service/hlo_replication_analysis_test.cc +++ b/third_party/xla/xla/service/hlo_replication_analysis_test.cc @@ -23,8 +23,8 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/hlo_runner_pjrt.cc b/third_party/xla/xla/service/hlo_runner_pjrt.cc index ccb239aee5f351..3965bf61870f3a 100644 --- a/third_party/xla/xla/service/hlo_runner_pjrt.cc +++ b/third_party/xla/xla/service/hlo_runner_pjrt.cc @@ -369,7 +369,9 @@ absl::StatusOr> HloRunnerPjRt::CreateExecutable( CreateExecutable(module.get(), compile_options)); auto executable = std::make_unique( - std::shared_ptr(std::move(module)), pjrt_executable.release()); + std::shared_ptr( + std::move(pjrt_executable->GetHloModules().value()[0])), + pjrt_executable.release()); std::unique_ptr exec = static_cast>(executable.release()); diff --git a/third_party/xla/xla/service/hlo_schedule_test.cc b/third_party/xla/xla/service/hlo_schedule_test.cc index 4ba1a982def9ef..4f96b30498b1c6 100644 --- a/third_party/xla/xla/service/hlo_schedule_test.cc +++ b/third_party/xla/xla/service/hlo_schedule_test.cc @@ -31,9 +31,9 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/hlo_unstacker.cc b/third_party/xla/xla/service/hlo_unstacker.cc index 024a41b0c48417..21d0eb9d42a27f 100644 --- a/third_party/xla/xla/service/hlo_unstacker.cc +++ b/third_party/xla/xla/service/hlo_unstacker.cc @@ -54,6 +54,7 @@ namespace { // TODO: b/352400145 - Unify the patterns, handlers and their type into a class // or struct. enum class PatternType { + DSFusionNoBitcastPattern, DSFusionPattern, NestedDSFusionPattern, Other, @@ -61,6 +62,8 @@ enum class PatternType { static std::string PatternTypeToString(PatternType pattern_type) { switch (pattern_type) { + case PatternType::DSFusionNoBitcastPattern: + return "DSFusionNoBitcastPattern"; case PatternType::DSFusionPattern: return "DSFusionPattern"; case PatternType::NestedDSFusionPattern: @@ -97,7 +100,8 @@ struct PatternInfo { // information for unstacking that is fixed across different unstacker // instastances. struct UnstackerMetadata { - static absl::StatusOr Create(HloModule* module) { + static absl::StatusOr Create( + HloModule* module, std::function unfuse_slice) { UnstackerMetadata metadata; TF_ASSIGN_OR_RETURN( bool prepared, @@ -111,6 +115,7 @@ struct UnstackerMetadata { metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config; metadata.bodies[instr->while_body()] = instr; } + metadata.unfuse_slice = unfuse_slice; return metadata; } absl::flat_hash_map unrollable_loop_bodies; @@ -123,6 +128,7 @@ struct UnstackerMetadata { const UnstackerMetadata&, const HloInstruction*, int64_t)>, std::function>> custom_handlers; + std::function unfuse_slice; }; // Performs the two-step unstacking. Each instance of this class is responsible @@ -198,7 +204,7 @@ class UnstackerTransformer { return {}; } - const UnstackerMetadata& GetMetadata() { return metadata_; } + const UnstackerMetadata& GetMetadata() const { return metadata_; } std::vector& GetUnstackedInstructions() { return unstacked_instrs_; @@ -440,9 +446,18 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker, // later prefetched using async-slice by MSA. For other patterns, we // resort to the original unstacking computation until we find benefit in // doing otherwise. + HloInstruction* slice = nullptr; if (unstacker.GetPatternType() == PatternType::DSFusionPattern || - unstacker.GetPatternType() == PatternType::NestedDSFusionPattern) { - HloInstruction* dynamic_slice = root_instr->mutable_operand(0); + unstacker.GetPatternType() == PatternType::NestedDSFusionPattern || + unstacker.GetPatternType() == PatternType::DSFusionNoBitcastPattern) { + HloInstruction* dynamic_slice = nullptr; + if (unstacker.GetPatternType() == PatternType::DSFusionPattern || + unstacker.GetPatternType() == PatternType::NestedDSFusionPattern) { + dynamic_slice = root_instr->mutable_operand(0); + } else if (unstacker.GetPatternType() == + PatternType::DSFusionNoBitcastPattern) { + dynamic_slice = root_instr; + } std::vector new_start_indices; new_start_indices.reserve(dynamic_slice->shape().rank()); std::vector new_limit_indices; @@ -458,25 +473,22 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker, dynamic_slice->mutable_operand(0)->shape().dimensions(j)); new_strides.push_back(1); } - HloInstruction* slice = - while_instr->AddInstruction(HloInstruction::CreateSlice( - dynamic_slice->shape(), old_while_input, new_start_indices, - new_limit_indices, new_strides)); - - slices.push_back(slice); - } else { + slice = while_instr->AddInstruction(HloInstruction::CreateSlice( + dynamic_slice->shape(), old_while_input, new_start_indices, + new_limit_indices, new_strides)); + } + if (slice == nullptr || !unstacker.GetMetadata().unfuse_slice(slice)) { std::vector operands = { old_while_input, while_instr->AddInstruction(MakeScalarConstantWithShape( unstacking_computation->parameter_instruction(1)->shape(), i))}; - HloInstruction* slice = - while_instr->AddInstruction(HloInstruction::CreateFusion( - slice_shape, HloInstruction::FusionKind::kLoop, operands, - while_instr->GetModule()->AddEmbeddedComputation( - unstacking_computation->Clone()), - "hoisted")); - slices.push_back(slice); + slice = while_instr->AddInstruction(HloInstruction::CreateFusion( + slice_shape, HloInstruction::FusionKind::kLoop, operands, + while_instr->GetModule()->AddEmbeddedComputation( + unstacking_computation->Clone()), + "hoisted")); } + slices.push_back(slice); } } HloInstruction* new_operand_element = @@ -778,14 +790,58 @@ absl::Status UnstackDSFusionPattern( HloInstruction* bitcast = mutable_dynamic_slicing_fusion->AddInstruction( HloInstruction::CreateBitcast(mutable_dynamic_slicing_fusion->shape(), new_operand)); - HloInstruction* bitcast_fusion = - mutable_dynamic_slicing_fusion->AddInstruction( - HloInstruction::CreateFusion(mutable_dynamic_slicing_fusion->shape(), - HloInstruction::FusionKind::kLoop, - bitcast)); + return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape( + bitcast); +} + +// This function recognizes fusions with the following pattern: +// fusion(stacked, f(loop_iteration_var)) +// computation { +// p0 = parameter(0) +// p1 = parameter(1) +// ROOT slice = dynamic_slice(p0, p1, zero, ...) +// } +// where f is a function of loop_iteration_var. It indicates that the slicing +// offset is effectively static after unrolling. +std::optional GetDSFusionNoBitcastPattern( + const UnstackerMetadata& metadata, const HloInstruction* instr, + int64_t stacked_operand_idx) { + VLOG(3) << "Checking DSFusionNoBitcast"; + HloInstruction* shape_covering_instr = + GetMostMajorEffectivelyStaticDynamicSliceInFusion(metadata, instr, 2, + stacked_operand_idx); + if (shape_covering_instr == nullptr) { + return std::nullopt; + } + if (instr->fused_instructions_computation()->root_instruction() != + shape_covering_instr) { + return std::nullopt; + } + PatternInfo pattern_info; + pattern_info.type = PatternType::DSFusionNoBitcastPattern; + pattern_info.instr = instr; + const Shape& slice_shape = shape_covering_instr->shape(); + const int64_t num_layers = instr->operand(0)->shape().dimensions(0); + pattern_info.unstacked_shape = + MakeUnstackedShapeFromSlice(slice_shape, num_layers); + pattern_info.unstacking_computation = instr->fused_instructions_computation(); + pattern_info.unstacked_instrs.push_back(instr); + return pattern_info; +} + +absl::Status UnstackDSFusionNoBitcastPattern( + HloInstruction* mutable_dynamic_slicing_fusion, const Shape& slice_shape) { + HloComputation* parent_loop = mutable_dynamic_slicing_fusion->parent(); + + HloInstruction* stacked = mutable_dynamic_slicing_fusion->mutable_operand(0); + HloInstruction* offset = mutable_dynamic_slicing_fusion->mutable_operand(1); + + HloInstruction* new_operand = + parent_loop->AddInstruction(HloInstruction::CreateCustomCall( + slice_shape, {stacked, offset}, "DynamicGte")); return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape( - bitcast_fusion); + new_operand); } // This function recognizes fusions with the following pattern: @@ -1290,7 +1346,8 @@ absl::Status UnstackReduceFusionPattern(HloInstruction* mutable_reduce_fusion, absl::StatusOr HloUnstacker::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { - TF_ASSIGN_OR_RETURN(auto metadata, UnstackerMetadata::Create(module)); + TF_ASSIGN_OR_RETURN(auto metadata, + UnstackerMetadata::Create(module, unfuse_slice_)); // The order of the patterns below is important, as it determines the order // in which the unstacking custom handlers are called. For example, applying // GetDSAndDUSPattern after GetDSFusionPattern would result in patterns of @@ -1310,6 +1367,8 @@ absl::StatusOr HloUnstacker::Run( std::make_pair(GetReduceFusionPattern, UnstackReduceFusionPattern)); metadata.custom_handlers.push_back( std::make_pair(GetNestedDSFusionPattern, UnstackNestedDSFusionPattern)); + metadata.custom_handlers.push_back(std::make_pair( + GetDSFusionNoBitcastPattern, UnstackDSFusionNoBitcastPattern)); std::vector entry_loops; for (HloInstruction* instr : @@ -1365,6 +1424,7 @@ absl::StatusOr HloUnstacker::Run( /*force_unroll=*/true, /*prepare=*/false)); CHECK(unrolled); } + VLOG(3) << "after unstacking \n" << module->ToString(); return true; } diff --git a/third_party/xla/xla/service/hlo_unstacker.h b/third_party/xla/xla/service/hlo_unstacker.h index eaa74ffc003468..222a1e511e6d47 100644 --- a/third_party/xla/xla/service/hlo_unstacker.h +++ b/third_party/xla/xla/service/hlo_unstacker.h @@ -18,6 +18,8 @@ limitations under the License. #include +#include + #include "absl/container/flat_hash_set.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" @@ -79,13 +81,18 @@ class HloUnstacker : public HloModulePass { public: ~HloUnstacker() override = default; - explicit HloUnstacker() = default; + explicit HloUnstacker(std::function unfuse_slice = + [](HloInstruction* instr) { return true; }) + : unfuse_slice_(unfuse_slice) {} absl::string_view name() const override { return "hlo_unstacker"; } using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + std::function unfuse_slice_; }; } // namespace xla diff --git a/third_party/xla/xla/service/hlo_unstacker_test.cc b/third_party/xla/xla/service/hlo_unstacker_test.cc index 84724550052dc1..3b00f9236a1ae7 100644 --- a/third_party/xla/xla/service/hlo_unstacker_test.cc +++ b/third_party/xla/xla/service/hlo_unstacker_test.cc @@ -34,18 +34,18 @@ namespace { using UnstackerTest = HloTestBase; -int64_t GetSliceCountInEntry(HloModule* module) { - int64_t slice_instrs_count = 0; +int64_t GetInstrCountWithOpcodeInEntry(HloModule* module, HloOpcode opcode) { + int64_t instr_with_opcode_count = 0; for (HloInstruction* instr : module->entry_computation()->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kSlice) { - slice_instrs_count++; + if (instr->opcode() == opcode) { + instr_with_opcode_count++; } } - return slice_instrs_count; + return instr_with_opcode_count; } -TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) { +TEST_F(UnstackerTest, UnstackDSFusionPattern) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { @@ -63,7 +63,8 @@ TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) { p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf + %fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) } @@ -80,7 +81,7 @@ TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) { init = s32[] constant(0) while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body - while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } )"; @@ -90,12 +91,15 @@ TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); + // Check that the bitcast is unfused and there are not fusions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 0); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), - std::nullopt)); + std::nullopt, false)); } -TEST_F(UnstackerTest, UnstackLoopSingleFusionUser2) { +TEST_F(UnstackerTest, UnstackReduceFusionPattern) { std::string hlo_string = R"( HloModule SimpleLoop dynamic-slice.609.reduce_sub_computation { @@ -148,28 +152,138 @@ TEST_F(UnstackerTest, UnstackLoopSingleFusionUser2) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), - std::nullopt)); + std::nullopt, false)); } -TEST_F(UnstackerTest, UnstackLoopSingleFusionUserDifferentLayout) { +TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcast) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[1,128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + bitcast.102 = s8[128,128] bitcast(s8[1,128,128] %fusion.67830) + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] bitcast.102), dim_labels=bf_io->bf + ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); + // Check that all the fusions are removed. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 0); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt, false)); +} + +TEST_F(UnstackerTest, UnstackDSFusionPatternNoBitcastKeepFused) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] { + %param_0.51117 = s8[3,128,128] parameter(0) + p1 = s32[] parameter(1) + %constant.85694 = s32[] constant(0) + ROOT %dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128] %param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694), dynamic_slice_sizes={1,128,128} + } + + %while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[], bf16[8,128], s8[3,128,128]) { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + p0 = bf16[8,128] get-tuple-element(wide_p), index=1 + p1 = s8[3,128,128] get-tuple-element(wide_p), index=2 + one = s32[] constant(1) + inc = s32[] add(i, one) + %fusion.67830 = s8[1,128,128] fusion(s8[3,128,128] p1, i), kind=kLoop, calls=%fused_computation.slice + bitcast.102 = s8[128,128] bitcast(s8[1,128,128] %fusion.67830) + conv = bf16[8,128] convolution(bf16[8,128] p0, s8[128,128] bitcast.102), dim_labels=bf_io->bf + ROOT out = (s32[], bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1) + } + + %while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] { + wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0) + i = s32[] get-tuple-element(wide_p), index=0 + %constant.12857 = s32[] constant(3) + ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT + } + + ENTRY main { + p0 = s8[3,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(0) + while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0) + while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input), condition=%while.cond , body=%while.body + while_use = s8[3,128,128] get-tuple-element(while.out), index=2 + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + auto unfuse = [](HloInstruction* instruction) { return false; }; + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, + HloUnstacker(unfuse).Run(module.get())); + EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 0); + // Check that dynamic-slices are still fused. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 3); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt, false)); +} + +TEST_F(UnstackerTest, UnstackDSFusionPatternWithDifferentLayout) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.30.clone (param_0.153: bf16[32,4,64,64,3], param_1.123: s32[]) -> bf16[64,4,64,3] { - %param_0.153 = bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} parameter(0) + %param_0.153 = bf16[32,4,64,64,3]{2,1,4,3,0} parameter(0) %param_1.123 = s32[]{:T(128)} parameter(1) %constant.227 = s32[]{:T(128)} constant(0) - %dynamic-slice.5 = bf16[1,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} dynamic-slice(bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} %param_0.153, s32[]{:T(128)} %param_1.123, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, /*index=5*/s32[]{:T(128)} %constant.227), dynamic_slice_sizes={1,4,64,64,3}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"0","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}]},"used_scoped_memory_configs":[]} - ROOT %bitcast.102 = bf16[64,4,64,3]{0,1,3,2:T(4,128)(2,1)} bitcast(bf16[1,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} %dynamic-slice.5) + %dynamic-slice.5 = bf16[1,4,64,64,3]{2,1,4,3,0} dynamic-slice(bf16[32,4,64,64,3]{2,1,4,3,0} %param_0.153, s32[]{:T(128)} %param_1.123, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, s32[]{:T(128)} %constant.227, /*index=5*/s32[]{:T(128)} %constant.227), dynamic_slice_sizes={1,4,64,64,3} + ROOT %bitcast.102 = bf16[64,4,64,3]{0,1,3,2} bitcast(bf16[1,4,64,64,3]{2,1,4,3,0} %dynamic-slice.5) } %while.body (wide_param: (s32[], bf16[8,128], bf16[32,4,64,64,3])) -> (s32[], bf16[8,128], bf16[32,4,64,64,3]) { wide_p = (s32[], bf16[8,128], bf16[32,4,64,64,3]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 p0 = bf16[8,128] get-tuple-element(wide_p), index=1 - p1 = bf16[32,4,64,64,3]{2,1,4,3,0:T(4,128)(2,1)} get-tuple-element(wide_p), index=2 + p1 = bf16[32,4,64,64,3]{2,1,4,3,0} get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - %fusion.67830 = bf16[64,4,64,3]{0,1,3,2:T(4,128)(2,1)} fusion(p1, i), kind=kLoop, calls=%fused_computation.30.clone + %fusion.67830 = bf16[64,4,64,3]{0,1,3,2} fusion(p1, i), kind=kLoop, calls=%fused_computation.30.clone ROOT out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) tuple(inc, p0, p1) } @@ -185,7 +299,7 @@ TEST_F(UnstackerTest, UnstackLoopSingleFusionUserDifferentLayout) { p1 = bf16[8,128] parameter(1) init = s32[] constant(0) while.input = (s32[], bf16[8,128], bf16[32,4,64,64,3]) tuple(init, p1, p0) - while.out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) while(while.input), condition=%while.cond , body=%while.body + while.out = (s32[], bf16[8,128], bf16[32,4,64,64,3]) while(while.input), condition=%while.cond , body=%while.body while_use = bf16[32,4,64,64,3] get-tuple-element(while.out), index=2 ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 } @@ -195,11 +309,17 @@ TEST_F(UnstackerTest, UnstackLoopSingleFusionUserDifferentLayout) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), + 32); + // Check that dynamic-slices are still fused. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kFusion), + 0); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt)); } -TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUser) { +TEST_F(UnstackerTest, UnstackNestedDSFusionPattern) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { @@ -252,14 +372,14 @@ TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUser) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } // Instead of slicing the entire shape, this test slices only even elements from // the first parameter. -TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDynamicIndex) { +TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithDynamicIndex) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice (param_0.51117: s8[6,128,128], p1: s32[]) -> s8[128,128] { @@ -317,7 +437,7 @@ TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDynamicIndex) { std::nullopt, false)); } -TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserMultipleIndex) { +TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithMultipleIndex) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice.1 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] { @@ -391,12 +511,12 @@ TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserMultipleIndex) { EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. For each unstacked input, we // create 4 slices, 8 in total. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 8); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 8); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } -TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDiffereOperandsOrder) { +TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithDiffereOperandsOrder) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { @@ -449,12 +569,12 @@ TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUserDiffereOperandsOrder) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } -TEST_F(UnstackerTest, UnstackLoopMultipleNestedFusionUsersSameUnstackingComps) { +TEST_F(UnstackerTest, UnstackNestedDSFusionPatternWithSameUnstackingComps) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice.1 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[128,128] { @@ -525,12 +645,12 @@ TEST_F(UnstackerTest, UnstackLoopMultipleNestedFusionUsersSameUnstackingComps) { TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 3); + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 3); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } -TEST_F(UnstackerTest, NotUnstackLoopMultipleDifferentUnstackingComps) { +TEST_F(UnstackerTest, NotUnstackNestedDSFusionPatternWithSameUnstackingComps) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice.1 (param_0.51117: s8[3,128,128], p1: s32[]) -> s8[1,128,128] { @@ -585,10 +705,10 @@ TEST_F(UnstackerTest, NotUnstackLoopMultipleDifferentUnstackingComps) { EXPECT_FALSE(unstacked); } -TEST_F(UnstackerTest, UnstackMultipleLoops) { +TEST_F(UnstackerTest, UnstackNestedDSFusionPatternSingleNestedLoop) { std::string hlo_string = R"( HloModule SimpleLoop - %fused_computation.slice1 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] { + %fused_computation.slice (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] { %param_0.51117 = s8[4,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) @@ -596,33 +716,33 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) { ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } - %fused_computation.inner1 (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] { + %fused_computation.inner (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] { %param_0.34523 = bf16[8,128] parameter(0) %param_1.30691 = s8[4,128,128] parameter(1) p2 = s32[] parameter(2) - %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice1 + %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf } - %while.body.inner1 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { + %while.body.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 inner_param_0 = bf16[8,128] get-tuple-element(wide_p), index=1 inner_param_1 = s8[4,128,128] get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner1 + fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(inc, fusion.conv, inner_param_1) } - %while.cond.inner1 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { + %while.cond.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 %constant.12857 = s32[] constant(4) ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT } - %while.body1 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { + %while.body (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 param0 = bf16[8,128] get-tuple-element(wide_p), index=1 @@ -630,13 +750,13 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) { one = s32[] constant(2) zero = s32[] constant(0) mult = s32[] multiply(i, one) - inner.in.1 = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1) - inner.out.1 = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in.1), condition=%while.cond.inner1, body=%while.body.inner1 - fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out.1), index=1 + inner.in = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1) + inner.out = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in), condition=%while.cond.inner, body=%while.body.inner + fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out), index=1 ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(mult, fusion.conv.inner, param1) } - %while.cond1 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { + %while.cond (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 %constant.12857 = s32[] constant(20) @@ -644,7 +764,30 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) { ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, add), direction=LT } - %fused_computation.slice2 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] { + ENTRY main { + weight = s8[4,128,128] parameter(0) + p1 = bf16[8,128] parameter(1) + init = s32[] constant(1) + while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight) + while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond , body=%while.body + ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + auto original = module->Clone(); + TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); + EXPECT_TRUE(unstacked); + // Check for the creation of slice instructions. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 4); + EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), + std::nullopt, false)); +} + +TEST_F(UnstackerTest, UnstackNestedDSFusionPatternTwoNestedLoops) { + std::string hlo_string = R"( + HloModule SimpleLoop + %fused_computation.slice1 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] { %param_0.51117 = s8[4,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) @@ -652,33 +795,33 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) { ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } - %fused_computation.inner2 (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] { + %fused_computation.inner1 (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] { %param_0.34523 = bf16[8,128] parameter(0) %param_1.30691 = s8[4,128,128] parameter(1) p2 = s32[] parameter(2) - %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice2 + %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice1 ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf } - %while.body.inner2 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { + %while.body.inner1 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 inner_param_0 = bf16[8,128] get-tuple-element(wide_p), index=1 inner_param_1 = s8[4,128,128] get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner2 + fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner1 ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(inc, fusion.conv, inner_param_1) } - %while.cond.inner2 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { + %while.cond.inner1 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 %constant.12857 = s32[] constant(4) ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT } - %while.body2 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { + %while.body1 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 param0 = bf16[8,128] get-tuple-element(wide_p), index=1 @@ -686,13 +829,13 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) { one = s32[] constant(2) zero = s32[] constant(0) mult = s32[] multiply(i, one) - inner.in.2 = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1) - inner.out.2 = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in.2), condition=%while.cond.inner2, body=%while.body.inner2 - fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out.2), index=1 + inner.in.1 = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1) + inner.out.1 = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in.1), condition=%while.cond.inner1, body=%while.body.inner1 + fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out.1), index=1 ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(mult, fusion.conv.inner, param1) } - %while.cond2 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { + %while.cond1 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 %constant.12857 = s32[] constant(20) @@ -700,36 +843,7 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) { ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, add), direction=LT } - ENTRY main { - weight = s8[4,128,128] parameter(0) - p1 = bf16[8,128] parameter(1) - init = s32[] constant(1) - while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight) - while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond1 , body=%while.body1 - init2 = s32[] get-tuple-element(while.out), index=0 - second.while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init2, p1, weight) - second.while.out = (s32[], bf16[8,128], s8[4,128,128]) while(second.while.input), condition=%while.cond2 , body=%while.body2 - out = bf16[8,128] get-tuple-element(while.out), index=1 - second.out = bf16[8,128] get-tuple-element(second.while.out), index=1 - ROOT result = bf16[8,128] add(out, second.out) - } - )"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo_string)); - auto original = module->Clone(); - TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); - EXPECT_TRUE(unstacked); - // Check for the creation of slice instructions. For each loop there is one - // unstacked input that creates 4 slices, in total 8 slices for two loops. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 8); - EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), - std::nullopt, false)); -} - -TEST_F(UnstackerTest, UnstackNestedLoopSingleNestedFusionUser) { - std::string hlo_string = R"( - HloModule SimpleLoop - %fused_computation.slice (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] { + %fused_computation.slice2 (param_0.51117: s8[4,128,128], p1: s32[]) -> s8[128,128] { %param_0.51117 = s8[4,128,128] parameter(0) p1 = s32[] parameter(1) %constant.85694 = s32[] constant(0) @@ -737,33 +851,33 @@ TEST_F(UnstackerTest, UnstackNestedLoopSingleNestedFusionUser) { ROOT %bitcast.31250 = s8[128,128] bitcast(s8[1,128,128] %dynamic-slice.22040) } - %fused_computation.inner (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] { + %fused_computation.inner2 (param_0.34523: bf16[8,128], param_1.30691: s8[4,128,128], p2: s32[]) -> bf16[8,128] { %param_0.34523 = bf16[8,128] parameter(0) %param_1.30691 = s8[4,128,128] parameter(1) p2 = s32[] parameter(2) - %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice + %fusion.67830 = s8[128,128] fusion(s8[4,128,128] %param_1.30691, p2), kind=kLoop, calls=%fused_computation.slice2 ROOT %convolution.3447 = bf16[8,128] convolution(bf16[8,128] %param_0.34523, s8[128,128] %fusion.67830), dim_labels=bf_io->bf } - %while.body.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { + %while.body.inner2 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 inner_param_0 = bf16[8,128] get-tuple-element(wide_p), index=1 inner_param_1 = s8[4,128,128] get-tuple-element(wide_p), index=2 one = s32[] constant(1) inc = s32[] add(i, one) - fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner + fusion.conv = bf16[8,128] fusion(inner_param_0, inner_param_1, i), kind=kOutput, calls=%fused_computation.inner2 ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(inc, fusion.conv, inner_param_1) } - %while.cond.inner (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { + %while.cond.inner2 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 %constant.12857 = s32[] constant(4) ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[] %constant.12857), direction=LT } - %while.body (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { + %while.body2 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> (s32[], bf16[8,128], s8[4,128,128]) { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 param0 = bf16[8,128] get-tuple-element(wide_p), index=1 @@ -771,13 +885,13 @@ TEST_F(UnstackerTest, UnstackNestedLoopSingleNestedFusionUser) { one = s32[] constant(2) zero = s32[] constant(0) mult = s32[] multiply(i, one) - inner.in = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1) - inner.out = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in), condition=%while.cond.inner, body=%while.body.inner - fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out), index=1 + inner.in.2 = (s32[], bf16[8,128], s8[4,128,128]) tuple(zero, param0, param1) + inner.out.2 = (s32[], bf16[8,128], s8[4,128,128]) while(inner.in.2), condition=%while.cond.inner2, body=%while.body.inner2 + fusion.conv.inner = bf16[8,128] get-tuple-element(inner.out.2), index=1 ROOT out = (s32[], bf16[8,128], s8[4,128,128]) tuple(mult, fusion.conv.inner, param1) } - %while.cond (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { + %while.cond2 (wide_param: (s32[], bf16[8,128], s8[4,128,128])) -> pred[] { wide_p = (s32[], bf16[8,128], s8[4,128,128]) parameter(0) i = s32[] get-tuple-element(wide_p), index=0 %constant.12857 = s32[] constant(20) @@ -790,8 +904,13 @@ TEST_F(UnstackerTest, UnstackNestedLoopSingleNestedFusionUser) { p1 = bf16[8,128] parameter(1) init = s32[] constant(1) while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight) - while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond , body=%while.body - ROOT out = bf16[8,128] get-tuple-element(while.out), index=1 + while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond1 , body=%while.body1 + init2 = s32[] get-tuple-element(while.out), index=0 + second.while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init2, p1, weight) + second.while.out = (s32[], bf16[8,128], s8[4,128,128]) while(second.while.input), condition=%while.cond2 , body=%while.body2 + out = bf16[8,128] get-tuple-element(while.out), index=1 + second.out = bf16[8,128] get-tuple-element(second.while.out), index=1 + ROOT result = bf16[8,128] add(out, second.out) } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, @@ -799,13 +918,14 @@ TEST_F(UnstackerTest, UnstackNestedLoopSingleNestedFusionUser) { auto original = module->Clone(); TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get())); EXPECT_TRUE(unstacked); - // Check for the creation of slice instructions. - EXPECT_EQ(GetSliceCountInEntry(module.get()), 4); + // Check for the creation of slice instructions. For each loop there is one + // unstacked input that creates 4 slices, in total 8 slices for two loops. + EXPECT_EQ(GetInstrCountWithOpcodeInEntry(module.get(), HloOpcode::kSlice), 8); EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original), std::nullopt, false)); } -TEST_F(UnstackerTest, UnstackSingleLoopOnlyWithDSAndDUS) { +TEST_F(UnstackerTest, UnstackDSAndDUSPattern) { std::string hlo_string = R"( HloModule SimpleLoop %fused_computation.slice (param_0.51117: s32[4,3], offset: s32[]) -> s32[3] { @@ -869,7 +989,7 @@ TEST_F(UnstackerTest, UnstackSingleLoopOnlyWithDSAndDUS) { // Unstacking outer loop at index 1 forces to unstacked inner while at index 1 // as well. This is because the output of the outer loop at index 1 is aliased // to the output of the inner while at index 1. -TEST_F(UnstackerTest, UnstackNestedLoopWithDSAndDUS) { +TEST_F(UnstackerTest, UnstackDSAndDUSPatternNestedLoop) { std::string hlo_string = R"( HloModule SimpleLoop @@ -953,7 +1073,7 @@ TEST_F(UnstackerTest, UnstackNestedLoopWithDSAndDUS) { // Unstacking the first loop at index 1 forces to unstack the second loop at // index 1 as well. -TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUS) { +TEST_F(UnstackerTest, UnstackDSAndDUSPatternLoopFeedingLoop) { std::string hlo_string = R"( HloModule SimpleLoop @@ -970,45 +1090,43 @@ TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUS) { %param_0.51117 = bf16[4,1,8,257,128] parameter(0) offset = s32[] parameter(1) zero = s32[] constant(0) - %dynamic-slice.22040 = bf16[1,1,8,257,128] - dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, - zero, zero), dynamic_slice_sizes={1,1,8,257,128} ROOT %bitcast.31250 = - bf16[1,8,257,128] bitcast(%dynamic-slice.22040) + %dynamic-slice.22040 = bf16[1,1,8,257,128] dynamic-slice(bf16[4,1,8,257,128] %param_0.51117, offset, zero, zero, zero, zero), dynamic_slice_sizes={1,1,8,257,128} + ROOT %bitcast.31250 = bf16[1,8,257,128] bitcast(%dynamic-slice.22040) } first.body { loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) - get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0 - get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0 + get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 constant = bf16[1,8,257,128] constant({...}) sliced = bf16[1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1), kind=kLoop, calls=%fused_computation.slice tmp = bf16[1,8,257,128] add(sliced, sliced) one = s32[] constant(1) - idx = s32[] add(get-tuple-element.1, one) + idx = s32[] add(get-tuple-element.1, one) ROOT out = tuple(idx, get-tuple-element.2) } first.condition { - loop_var.1 = (s32[], bf16[4,1,8,257,128]) - parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), - index=0 constant.2 = s32[] constant(4) ROOT less-than = pred[] - compare(get-tuple-element.1, constant.2), direction=LT + loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.2 = s32[] constant(4) + ROOT less-than = pred[] compare(get-tuple-element.1, constant.2), direction=LT } next.body { loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) - get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0 - get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1),index=0 + get-tuple-element.2 = bf16[4,1,8,257,128] get-tuple-element(loop_var.1), index=1 constant = bf16[1,8,257,128] constant({...}) - update.sliced = bf16[4,1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1, constant), kind=kLoop, calls=%fused_computation.update.slice + update.sliced = bf16[4,1,8,257,128] fusion(get-tuple-element.2, get-tuple-element.1, constant), kind=kLoop, calls=%fused_computation.update.slice one = s32[] constant(1) - idx = s32[] add(get-tuple-element.1, one) + idx = s32[] add(get-tuple-element.1, one) ROOT out = tuple(idx, update.sliced) } next.condition { - loop_var.1 = (s32[], bf16[4,1,8,257,128]) - parameter(0) get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), - index=0 constant.2 = s32[] constant(4) ROOT less-than = pred[] - compare(get-tuple-element.1, constant.2), direction=LT + loop_var.1 = (s32[], bf16[4,1,8,257,128]) parameter(0) + get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0 + constant.2 = s32[] constant(4) + ROOT less-than = pred[] compare(get-tuple-element.1, constant.2), direction=LT } ENTRY SimpleLoop { @@ -1032,7 +1150,7 @@ TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUS) { EXPECT_TRUE(unstacked); } -TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUSFusionWithPad) { +TEST_F(UnstackerTest, UnstackDUSFusionWithPadPatternLoopFeedingLoop) { std::string hlo_string = R"( HloModule SimpleLoop fused_computation.75.clone { @@ -1107,7 +1225,7 @@ TEST_F(UnstackerTest, UnstackLoopFeedingLoopWithDUSFusionWithPad) { EXPECT_TRUE(unstacked); } -TEST_F(UnstackerTest, UnstackSingleLoopWithDSFusionWithAdd) { +TEST_F(UnstackerTest, UnstackDUSFusionWithAddPattern) { std::string hlo_string = R"( HloModule SimpleLoop diff --git a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc index 67259f886f3383..06ae99051ddcce 100644 --- a/third_party/xla/xla/service/hlo_value_semantics_analysis.cc +++ b/third_party/xla/xla/service/hlo_value_semantics_analysis.cc @@ -1757,6 +1757,7 @@ absl::Status HloValueSemanticsPropagation::HandleConditional( [&](const ShapeIndex& index, const HloValueSemantics* semantics) -> absl::Status { std::vector semantics_vector; + semantics_vector.reserve(semantics_tree_vec.size()); for (size_t i = 0; i < semantics_tree_vec.size(); ++i) { semantics_vector.push_back( *(semantics_tree_vec[i].find(index)->second)); diff --git a/third_party/xla/xla/service/hlo_verifier.cc b/third_party/xla/xla/service/hlo_verifier.cc index ebf03de9e146b0..3a8d958c2ddba4 100644 --- a/third_party/xla/xla/service/hlo_verifier.cc +++ b/third_party/xla/xla/service/hlo_verifier.cc @@ -35,6 +35,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" @@ -1322,6 +1323,34 @@ absl::Status ShapeVerifier::HandleCall(HloInstruction* call) { for (int64_t i = 0; i < call->to_apply()->num_parameters(); ++i) { TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i)); } + if (call->is_composite()) { + TF_RET_CHECK(call->has_frontend_attributes()) + << "A composite call op must have frontend attributes"; + auto map = call->frontend_attributes().map(); + if (auto name = map.find("composite.name"); + name == map.end() || name->second.empty()) { + return InvalidArgument( + "A composite call op must have frontend attributes with key " + "composite.name whose value is non-empty"); + } + if (auto attributes = map.find("composite.attributes"); + attributes != map.end() && attributes->second.empty()) { + return InvalidArgument( + "A composite call op must have frontend attributes with key " + "composite.attributes whose value is default: {} or non-empty"); + } + if (auto version_str = map.find("composite.version"); + version_str != map.end()) { + int64_t version = 0; + if (!absl::SimpleAtoi(version_str->second, &version) || version < 0) { + return InvalidArgument( + "A composite call op must have frontend attributes with a " + "composite.version whose value is a non-negative integer but got: " + "%s", + version_str->second); + } + } + } // The shape of kCall should match the shape of the computation it calls. return CheckShape(call, call->to_apply()->root_instruction()->shape()); } @@ -1920,6 +1949,26 @@ absl::Status ShapeVerifier::CheckShape( } return ShapesSame(instruction->shape(), inferred_shape, equal); } + case HloOpcode::kCopy: { + // Disallow host offloading copies which change FpPrecision. + if (opts_.IsLayoutSensitive()) { + if (instruction->shape().has_layout() && + inferred_shape.has_layout()) { + int64_t instruction_memory_space = + instruction->shape().layout().memory_space(); + int64_t operand_memory_space = + inferred_shape.layout().memory_space(); + if (instruction_memory_space != operand_memory_space && + (instruction_memory_space == Layout::kHostMemorySpace || + operand_memory_space == Layout::kHostMemorySpace)) { + // Is a host->device copy for a device->host copy. + return Shape::Equal().IgnoreMemorySpaceInLayout()( + instruction->shape(), inferred_shape); + } + } + } + [[fallthrough]]; + } // We allow arbitrary layout and f32->bf16 transformations on all other // instructions, although this may be made more strict pending discussion @@ -2907,6 +2956,15 @@ class InstructionVerifier : public DfsHloVisitorWithDefault { } } + if (instruction->has_to_apply() && + instruction->to_apply()->execution_thread() != + instruction->parent()->execution_thread()) { + return Internal( + "%s top_apply computation execution thread does not match (%s vs %s)", + instruction->name(), instruction->to_apply()->execution_thread(), + instruction->parent()->execution_thread()); + } + return absl::OkStatus(); } diff --git a/third_party/xla/xla/service/hlo_verifier_test.cc b/third_party/xla/xla/service/hlo_verifier_test.cc index 874077fce3e6b2..b7055e9ed31649 100644 --- a/third_party/xla/xla/service/hlo_verifier_test.cc +++ b/third_party/xla/xla/service/hlo_verifier_test.cc @@ -43,9 +43,9 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla.pb.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/platform.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -83,6 +83,15 @@ class HloVerifierTestLayoutSensitive : public HloTestBase { LayoutAssignment::InstructionCanChangeLayout) {} }; +class HloVerifierTestLayoutSensitiveAndAllowMixedPrecision + : public HloTestBase { + public: + HloVerifierTestLayoutSensitiveAndAllowMixedPrecision() + : HloTestBase(/*verifier_layout_sensitive=*/true, + /*allow_mixed_precision_in_hlo_verifier=*/true, + LayoutAssignment::InstructionCanChangeLayout) {} +}; + class HloVerifierTestLayoutFusion : public HloTestBase { public: HloVerifierTestLayoutFusion() @@ -216,8 +225,164 @@ TEST_F(HloVerifierTest, CheckCallThreadMismatch) { auto status = verifier().Run(module.get()).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - HasSubstr("expects parent computation thread name same as called " - "computation's thread name")); + HasSubstr("mycall top_apply computation execution thread does " + "not match (parallel_thread vs main)")); +} + +TEST_F(HloVerifierTest, CompositeCall) { + constexpr absl::string_view hlo = R"( + HloModule Module + + add_n { + x = f32[] parameter(0) + constant = f32[] constant(2) + ROOT z = f32[] add(f32[] x, f32[] constant) + } + + ENTRY entry { + constant = f32[] constant(42) + ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar",composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.version="1"} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + + auto status = verifier().Run(module.get()).status(); + EXPECT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, CompositeCallMissingFrontendAttributes) { + constexpr absl::string_view hlo = R"( + HloModule Module + + add_n { + x = f32[] parameter(0) + constant = f32[] constant(2) + ROOT z = f32[] add(f32[] x, f32[] constant) + } + + ENTRY entry { + constant = f32[] constant(42) + ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("A composite call op must have frontend attributes")); +} + +TEST_F(HloVerifierTest, CompositeCallOptionalAttributesAndVersion) { + constexpr absl::string_view hlo = R"( + HloModule Module + + add_n { + x = f32[] parameter(0) + constant = f32[] constant(2) + ROOT z = f32[] add(f32[] x, f32[] constant) + } + + ENTRY entry { + constant = f32[] constant(42) + ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar"} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + + auto status = verifier().Run(module.get()).status(); + EXPECT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, CompositeCallOptionalAttributes) { + constexpr absl::string_view hlo = R"( + HloModule Module + + add_n { + x = f32[] parameter(0) + constant = f32[] constant(2) + ROOT z = f32[] add(f32[] x, f32[] constant) + } + + ENTRY entry { + constant = f32[] constant(42) + ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.name="foo.bar",composite.version="1"} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + + auto status = verifier().Run(module.get()).status(); + EXPECT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, CompositeCallMissingName) { + constexpr absl::string_view hlo = R"( + HloModule Module + + add_n { + x = f32[] parameter(0) + constant = f32[] constant(2) + ROOT z = f32[] add(f32[] x, f32[] constant) + } + + ENTRY entry { + constant = f32[] constant(42) + ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.version="1"} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("A composite call op must have frontend attributes " + "with key composite.name whose value is non-empty")); +} + +TEST_F(HloVerifierTest, CompositeCallOptionalVersion) { + constexpr absl::string_view hlo = R"( + HloModule Module + + add_n { + x = f32[] parameter(0) + constant = f32[] constant(2) + ROOT z = f32[] add(f32[] x, f32[] constant) + } + + ENTRY entry { + constant = f32[] constant(42) + ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar"} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + + auto status = verifier().Run(module.get()).status(); + EXPECT_TRUE(status.ok()); +} + +TEST_F(HloVerifierTest, CompositeCallNonNegativeVersion) { + constexpr absl::string_view hlo = R"( + HloModule Module + + add_n { + x = f32[] parameter(0) + constant = f32[] constant(2) + ROOT z = f32[] add(f32[] x, f32[] constant) + } + + ENTRY entry { + constant = f32[] constant(42) + ROOT mycall = f32[] call(constant), is_composite=true, to_apply=add_n, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="-1"} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT( + status.message(), + HasSubstr("A composite call op must have frontend attributes with a " + "composite.version whose value is a non-negative integer")); } TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) { @@ -2000,10 +2165,10 @@ TEST_F(HloVerifierTest, FusionNestedComputationThreadVerifier) { )"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kModuleStr)); - EXPECT_THAT(verifier().Run(module.get()).status().message(), - HasSubstr("Nested computations expects same computation's thread " - "name: parallel_thread vs main, in called computation " - "`add` vs caller computation `fused_computation`")); + EXPECT_THAT( + verifier().Run(module.get()).status().message(), + HasSubstr("crs0 top_apply computation execution thread does not match " + "(parallel_thread vs main)")); } TEST_F(HloVerifierTest, AllReduceVerifier) { @@ -2639,8 +2804,8 @@ TEST_F(HloVerifierTest, VerifyCustomCallThread) { .status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.message(), - HasSubstr("expects parent computation thread name same as called " - "computation's thread name")); + HasSubstr("custom top_apply computation execution thread does " + "not match (parallel_thread vs main)")); } TEST_F(HloVerifierTest, CheckWhileThread) { @@ -3133,6 +3298,49 @@ TEST_F(HloVerifierTestLayoutSensitive, "memory space from device to host")); } +TEST_F(HloVerifierTestLayoutSensitiveAndAllowMixedPrecision, + HostOffloadingCopyCannotChangeType) { + const char* const hlo_string = R"( +HloModule m + +ENTRY main { + param = f32[1024,1024]{1,0:T(8,128)S(5)} parameter(0) + copy = bf16[1024,1024]{1,0:T(8,128)} copy(param) + ROOT dot = f32[1024,1024]{1,0:T(8,128)} dot(copy, copy), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("Expected instruction to have shape equal to " + "f32[1024,1024]{1,0:T(8,128)S(5)}, actual shape is " + "bf16[1024,1024]{1,0:T(8,128)}")); +} + +TEST_F(HloVerifierTestLayoutSensitiveAndAllowMixedPrecision, + HostOffloadingCopyCannotChangeLayout) { + const char* const hlo_string = R"( +HloModule m + +ENTRY main { + param = f32[1024,1024]{1,0:T(8,128)S(5)} parameter(0) + ROOT copy = f32[1024,1024]{0,1:T(8,128)} copy(param) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.message(), + HasSubstr("Expected instruction to have shape equal to " + "f32[1024,1024]{1,0:T(8,128)S(5)}, actual shape is " + "f32[1024,1024]{0,1:T(8,128)}")); +} + TEST_F(HloVerifierTestLayoutSensitive, MismatchedMinorToMajorSizeAndDimensionSize) { const char* const hlo_string = R"( diff --git a/third_party/xla/xla/service/host_memory_transfer_asyncifier_test.cc b/third_party/xla/xla/service/host_memory_transfer_asyncifier_test.cc index d054cd115c1370..fd85488a2239ec 100644 --- a/third_party/xla/xla/service/host_memory_transfer_asyncifier_test.cc +++ b/third_party/xla/xla/service/host_memory_transfer_asyncifier_test.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/host_offload_legalize_test.cc b/third_party/xla/xla/service/host_offload_legalize_test.cc index 096f9a10560b44..0322c80a7504cf 100644 --- a/third_party/xla/xla/service/host_offload_legalize_test.cc +++ b/third_party/xla/xla/service/host_offload_legalize_test.cc @@ -33,8 +33,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace m = ::xla::match; diff --git a/third_party/xla/xla/service/host_offload_utils.cc b/third_party/xla/xla/service/host_offload_utils.cc new file mode 100644 index 00000000000000..203c08e9d0c39a --- /dev/null +++ b/third_party/xla/xla/service/host_offload_utils.cc @@ -0,0 +1,243 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offload_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/service/call_graph.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/shape_util.h" +#include "xla/util.h" + +namespace xla { +namespace host_offload_utils { + +namespace { + +using ::xla::host_memory_offload_annotations::kMoveToDeviceCustomCallTarget; +using ::xla::host_memory_offload_annotations::kMoveToHostCustomCallTarget; + +bool CustomCallReusesBuffer(const HloInstruction* custom_call, + int64_t operand_index) { + if (custom_call->custom_call_target() == kMoveToDeviceCustomCallTarget || + custom_call->custom_call_target() == kMoveToHostCustomCallTarget) { + // Does not define a new buffer. + return true; + } + // Check the custom call's output_to_operand_aliasing. + const std::vector>>& + aliases = custom_call->output_operand_aliasing(); + for (const std::pair>& alias : + aliases) { + int64_t alias_operand_index = alias.second.first; + if (alias_operand_index == operand_index) { + // This operand aliases with the output. + return true; + } + } + // By default, assume custom calls define new buffers. + return false; +} + +} // namespace + +absl::StatusOr> GetSuccessors( + const InstructionAndShapeIndex& instruction_and_shape_index) { + std::vector result; + HloInstruction* instruction = instruction_and_shape_index.instruction; + if (instruction->IsRoot()) { + // Successor of the root is the call instruction(s). + std::unique_ptr call_graph = + CallGraph::Build(instruction->GetModule()); + auto callers = call_graph->GetComputationCallers(instruction->parent()); + for (HloInstruction* caller : callers) { + result.push_back({caller, instruction_and_shape_index.shape_index}); + } + } + for (HloInstruction* user : instruction->users()) { + if (user->opcode() == HloOpcode::kTuple) { + auto operand_indices = user->OperandIndices(instruction); + for (const auto i : operand_indices) { + auto tmp_shape_index = instruction_and_shape_index.shape_index; + tmp_shape_index.push_back(i); + result.push_back({user, std::move(tmp_shape_index)}); + } + } else if (user->opcode() == HloOpcode::kGetTupleElement) { + ShapeIndex tmp_shape_index = instruction_and_shape_index.shape_index; + const auto index = tmp_shape_index.front(); + if (index == user->tuple_index()) { + // This GTE is for the buffer we're tracking. + tmp_shape_index.pop_front(); + result.push_back({user, std::move(tmp_shape_index)}); + } + } else if (user->opcode() == HloOpcode::kCall) { + auto operand_indices = user->OperandIndices(instruction); + CHECK(user->called_computations().size() == 1) + << "Expect call to only have one called computation."; + for (const auto i : operand_indices) { + HloComputation* called_computation = + user->called_computations().front(); + HloInstruction* parameter_instruction = + called_computation->parameter_instruction(i); + result.push_back( + {parameter_instruction, instruction_and_shape_index.shape_index}); + } + } else if (user->opcode() == HloOpcode::kWhile) { + auto operand_indices = user->OperandIndices(instruction); + HloComputation* while_body_computation = user->while_body(); + HloComputation* while_condition_computation = user->while_condition(); + for (const auto i : operand_indices) { + HloInstruction* parameter_instruction = + while_body_computation->parameter_instruction(i); + result.push_back( + {parameter_instruction, instruction_and_shape_index.shape_index}); + + HloInstruction* condition_instruction = + while_condition_computation->parameter_instruction(i); + result.push_back( + {condition_instruction, instruction_and_shape_index.shape_index}); + } + } else if (user->opcode() == HloOpcode::kAsyncStart) { + auto operand_indices = user->OperandIndices(instruction); + CHECK(user->called_computations().size() == 1) + << "Expect async-start to only have one called computation."; + for (const auto i : operand_indices) { + HloComputation* called_computation = + user->called_computations().front(); + HloInstruction* parameter_instruction = + called_computation->parameter_instruction(i); + result.push_back( + {parameter_instruction, instruction_and_shape_index.shape_index}); + } + } else if (user->opcode() == HloOpcode::kCustomCall) { + const auto operand_indices = user->OperandIndices(instruction); + // TODO(b/342650757): Rather than a boolean indicating whether the + // instruction reuses the buffer, return the shape index of the output + // that the operand aliases with. + bool found_one = false; + for (const auto i : operand_indices) { + if (CustomCallReusesBuffer(user, i)) { + if (found_one) { + return absl::InternalError( + "Found multiple operands of a custom call that reuse the same " + "output buffer."); + } + result.push_back({user, instruction_and_shape_index.shape_index}); + found_one = true; + } + } + } else { + result.push_back({user, instruction_and_shape_index.shape_index}); + } + } + return result; +} + +std::vector GetPredecessors( + const InstructionAndShapeIndex& instruction_and_shape_index) { + std::vector result; + HloInstruction* instruction = instruction_and_shape_index.instruction; + if (instruction->opcode() == HloOpcode::kGetTupleElement) { + const int64_t index = instruction->tuple_index(); + auto tmp_shape_index = instruction_and_shape_index.shape_index; + tmp_shape_index.push_front(index); + result.push_back({instruction->mutable_operand(0), tmp_shape_index}); + } else if (instruction->opcode() == HloOpcode::kTuple) { + CHECK(!instruction_and_shape_index.shape_index.empty()) + << "Did not store an index before encountering a tuple."; + auto tmp_shape_index = instruction_and_shape_index.shape_index; + const int64_t index = tmp_shape_index.front(); + tmp_shape_index.pop_front(); + result.push_back({instruction->mutable_operand(index), tmp_shape_index}); + } else if (instruction->opcode() == HloOpcode::kCall) { + // Predecessor of a call is its computation's root instruction. + CHECK(instruction->called_computations().size() == 1) + << "Expect call to only have one called computation."; + HloComputation* called_computation = + instruction->called_computations().front(); + result.push_back({called_computation->root_instruction(), + instruction_and_shape_index.shape_index}); + } else if (instruction->opcode() == HloOpcode::kParameter) { + std::unique_ptr call_graph = + CallGraph::Build(instruction->GetModule()); + auto callers = call_graph->GetComputationCallers(instruction->parent()); + for (HloInstruction* caller : callers) { + result.push_back( + {caller->mutable_operand(instruction->parameter_number()), + instruction_and_shape_index.shape_index}); + } + } else if (instruction->opcode() == HloOpcode::kDynamicSlice) { + result.push_back({instruction->mutable_operand(0), + instruction_and_shape_index.shape_index}); + } else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { + result.push_back({instruction->mutable_operand(0), + instruction_and_shape_index.shape_index}); + } else if (instruction->opcode() == HloOpcode::kWhile) { + HloComputation* while_body_computation = instruction->while_body(); + result.push_back({while_body_computation->root_instruction(), + instruction_and_shape_index.shape_index}); + } else { + CHECK(instruction->operand_count() == 1) << absl::StreamFormat( + "Expecting instruction %s to have 1 operand, but it has %d.", + instruction->name(), instruction->operand_count()); + result.push_back({instruction->mutable_operand(0), + instruction_and_shape_index.shape_index}); + } + return result; +} + +bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction) { + static constexpr std::array allowed_opcodes = { + HloOpcode::kGetTupleElement, + HloOpcode::kBitcast, + HloOpcode::kTuple, + HloOpcode::kCall, + HloOpcode::kWhile, + HloOpcode::kParameter, + HloOpcode::kOptimizationBarrier, + HloOpcode::kAsyncStart, + HloOpcode::kAsyncDone, + HloOpcode::kCustomCall}; + return absl::c_linear_search(allowed_opcodes, instruction->opcode()); +} + +bool operator==(const InstructionAndShapeIndex& lhs, + const InstructionAndShapeIndex& rhs) { + return lhs.instruction == rhs.instruction && + lhs.shape_index == rhs.shape_index; +} + +std::string InstructionAndShapeIndex::ToString() const { + return absl::StrFormat("{Instr: %s, ShapeIndex: %s}", instruction->name(), + shape_index.ToString()); +} + +} // namespace host_offload_utils +} // namespace xla diff --git a/third_party/xla/xla/service/host_offload_utils.h b/third_party/xla/xla/service/host_offload_utils.h new file mode 100644 index 00000000000000..22e1c359dca83e --- /dev/null +++ b/third_party/xla/xla/service/host_offload_utils.h @@ -0,0 +1,101 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_HOST_OFFLOAD_UTILS_H_ +#define XLA_SERVICE_HOST_OFFLOAD_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/literal_util.h" +#include "xla/service/hlo_buffer.h" +#include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/pattern_matcher.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace host_offload_utils { + +struct InstructionAndShapeIndex { + explicit InstructionAndShapeIndex(HloInstruction* instruction) + : instruction(instruction) {} + InstructionAndShapeIndex(HloInstruction* instruction, ShapeIndex shape_index) + : instruction(instruction), shape_index(shape_index) {} + HloInstruction* instruction; + ShapeIndex shape_index; + std::string ToString() const; + + template + static H Hash(H h, const InstructionAndShapeIndex& i) { + h = H::combine(std::move(h), i.instruction); + h = H::combine(std::move(h), i.shape_index); + return std::move(h); + } + + template + friend H AbslHashValue(H h, const InstructionAndShapeIndex& i) { + return InstructionAndShapeIndex::Hash(std::move(h), i); + } +}; + +bool operator==(const InstructionAndShapeIndex& lhs, + const InstructionAndShapeIndex& rhs); + +// If an instruction's user is a call, we descend into the call first. +// Eventually, a later invocation of this function while walking the graph will +// return the call itself as a successor of the ROOT instruction of the +// computation. +absl::StatusOr> GetSuccessors( + const InstructionAndShapeIndex& instruction_and_shape_index); + +// If an instruction's operand is a call, return the call now. A follow up call +// of this function on that call returns the ROOT. Eventually, once the given +// instruction is a parameter, the returned predecessor will be the appropriate +// operand of the call (not the call itself, since we already returned it). +std::vector GetPredecessors( + const InstructionAndShapeIndex& instruction_and_shape_index); + +// Returns true if the instruction is allowed to be in the +// middle of a pure memory offload path. +bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction); + +} // namespace host_offload_utils +} // namespace xla + +#endif // XLA_SERVICE_HOST_OFFLOAD_UTILS_H_ diff --git a/third_party/xla/xla/service/host_offload_utils_test.cc b/third_party/xla/xla/service/host_offload_utils_test.cc new file mode 100644 index 00000000000000..6f38b45ab09544 --- /dev/null +++ b/third_party/xla/xla/service/host_offload_utils_test.cc @@ -0,0 +1,114 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/host_offload_utils.h" + +#include +#include + +#include +#include "xla/shape_util.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace host_offload_utils { +namespace { + +class HostOffloadUtilsTest : public HloTestBase {}; + +TEST_F(HostOffloadUtilsTest, SimpleGetSuccessorsGetPredecessorsTest) { + const std::string& hlo_string = R"( +HloModule my_module +ENTRY main { + data_param = f32[1,2048,2048] parameter(0) + index_param = s32[] parameter(1) + constant_f32_0 = f32[] constant(0) + constant_s32_0 = s32[] constant(0) + broadcast = f32[2,2048,2048] broadcast(constant_f32_0), dimensions={} + offload_custom_call = f32[1,2048,2048] custom-call(data_param), custom_call_target="MoveToHost" + dynamic_update_slice = f32[2,2048,2048] dynamic-update-slice(broadcast, offload_custom_call, index_param, constant_s32_0, constant_s32_0) + dynamic_slice = f32[1,2048,2048] dynamic-slice(dynamic_update_slice, index_param, constant_s32_0, constant_s32_0), dynamic_slice_sizes={1,2048,2048} + ROOT load_custom_call = f32[1,2048,2048] custom-call(dynamic_slice), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* data_param = FindInstruction(module.get(), "data_param"); + ASSERT_NE(data_param, nullptr); + HloInstruction* offload_custom_call = + FindInstruction(module.get(), "offload_custom_call"); + ASSERT_NE(offload_custom_call, nullptr); + + TF_ASSERT_OK_AND_ASSIGN( + std::vector succ, + GetSuccessors(InstructionAndShapeIndex(data_param, {}))); + std::vector expected_succ = { + InstructionAndShapeIndex(offload_custom_call, {})}; + EXPECT_EQ(succ, expected_succ); + + std::vector pred = + GetPredecessors(InstructionAndShapeIndex(offload_custom_call, {})); + std::vector expected_pred = { + InstructionAndShapeIndex(data_param, {})}; + EXPECT_EQ(pred, expected_pred); +} + +TEST_F(HostOffloadUtilsTest, ComputationGetSuccessorsGetPredecessorsTest) { + const std::string& hlo_string = R"( +HloModule my_module +other_computation { + param_0 = f32[2048] parameter(0) + param_1 = f32[2048] parameter(1) + ROOT tuple = (f32[2048], f32[2048]) tuple(param_0, param_1) +} +ENTRY main { + data_param = f32[2048] parameter(0) + other_param = f32[2048] parameter(1) + offload_custom_call = f32[2048] custom-call(data_param), custom_call_target="MoveToHost" + call = (f32[2048], f32[2048]) call(offload_custom_call, other_param), to_apply=other_computation + gte_0 = f32[2048] get-tuple-element(call), index=0 + gte_1 = f32[2048] get-tuple-element(call), index=1 + ROOT load_custom_call = f32[2048] custom-call(gte_0), custom_call_target="MoveToDevice" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + HloInstruction* call = FindInstruction(module.get(), "call"); + ASSERT_NE(call, nullptr); + HloInstruction* gte_0 = FindInstruction(module.get(), "gte_0"); + ASSERT_NE(gte_0, nullptr); + HloInstruction* tuple = FindInstruction(module.get(), "tuple"); + ASSERT_NE(tuple, nullptr); + + TF_ASSERT_OK_AND_ASSIGN(std::vector succ, + GetSuccessors(InstructionAndShapeIndex(call, {0}))); + std::vector expected_succ = { + InstructionAndShapeIndex(gte_0, {})}; + EXPECT_EQ(succ, expected_succ); + + std::vector pred = + GetPredecessors(InstructionAndShapeIndex(call, {0})); + std::vector expected_pred = { + InstructionAndShapeIndex(tuple, {0})}; + EXPECT_EQ(pred, expected_pred); +} + +} // namespace +} // namespace host_offload_utils +} // namespace xla diff --git a/third_party/xla/xla/service/host_offloader.cc b/third_party/xla/xla/service/host_offloader.cc index 95c97e94c704da..7e1971302c981b 100644 --- a/third_party/xla/xla/service/host_offloader.cc +++ b/third_party/xla/xla/service/host_offloader.cc @@ -43,6 +43,7 @@ limitations under the License. #include "xla/service/hlo_cse.h" #include "xla/service/hlo_value.h" #include "xla/service/host_memory_offload_annotations.h" +#include "xla/service/host_offload_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/shape.h" #include "xla/shape_util.h" @@ -58,6 +59,7 @@ namespace { using ::xla::host_memory_offload_annotations::kMoveToDeviceCustomCallTarget; using ::xla::host_memory_offload_annotations::kMoveToHostCustomCallTarget; +using ::xla::host_offload_utils::InstructionAndShapeIndex; void SetMemorySpace(Shape* shape, int64_t memory_space_color) { CHECK(shape->has_layout()); @@ -85,210 +87,8 @@ bool SetBuffersToMemorySpaceColor( return changed; } -bool CustomCallReusesBuffer(const HloInstruction* custom_call, - int64_t operand_index) { - if (custom_call->custom_call_target() == kMoveToDeviceCustomCallTarget || - custom_call->custom_call_target() == kMoveToHostCustomCallTarget) { - // Does not define a new buffer. - return true; - } - // Check the custom call's output_to_operand_aliasing. - const std::vector>>& - aliases = custom_call->output_operand_aliasing(); - for (const std::pair>& alias : - aliases) { - int64_t alias_operand_index = alias.second.first; - if (alias_operand_index == operand_index) { - // This operand aliases with the output. - return true; - } - } - // By default, assume custom calls define new buffers. - return false; -} - -// If an instruction's user is a call, we descend into the call first. -// Eventually, a later invocation of this function while walking the graph will -// return the call itself as a successor of the ROOT instruction of the -// computation. -absl::StatusOr> GetSuccessors( - const InstructionAndShapeIndex& instruction_and_shape_index) { - std::vector result; - HloInstruction* instruction = instruction_and_shape_index.instruction; - if (instruction->IsRoot()) { - // Successor of the root is the call instruction(s). - std::unique_ptr call_graph = - CallGraph::Build(instruction->GetModule()); - auto callers = call_graph->GetComputationCallers(instruction->parent()); - for (HloInstruction* caller : callers) { - result.push_back({caller, instruction_and_shape_index.shape_index}); - } - } - for (HloInstruction* user : instruction->users()) { - if (user->opcode() == HloOpcode::kTuple) { - auto operand_indices = user->OperandIndices(instruction); - for (const auto i : operand_indices) { - auto tmp_shape_index = instruction_and_shape_index.shape_index; - tmp_shape_index.push_back(i); - result.push_back({user, std::move(tmp_shape_index)}); - } - } else if (user->opcode() == HloOpcode::kGetTupleElement) { - ShapeIndex tmp_shape_index = instruction_and_shape_index.shape_index; - const auto index = tmp_shape_index.front(); - if (index == user->tuple_index()) { - // This GTE is for the buffer we're tracking. - tmp_shape_index.pop_front(); - result.push_back({user, std::move(tmp_shape_index)}); - } - } else if (user->opcode() == HloOpcode::kCall) { - auto operand_indices = user->OperandIndices(instruction); - CHECK(user->called_computations().size() == 1) - << "Expect call to only have one called computation."; - for (const auto i : operand_indices) { - HloComputation* called_computation = - user->called_computations().front(); - HloInstruction* parameter_instruction = - called_computation->parameter_instruction(i); - result.push_back( - {parameter_instruction, instruction_and_shape_index.shape_index}); - } - } else if (user->opcode() == HloOpcode::kWhile) { - auto operand_indices = user->OperandIndices(instruction); - HloComputation* while_body_computation = user->while_body(); - HloComputation* while_condition_computation = user->while_condition(); - for (const auto i : operand_indices) { - HloInstruction* parameter_instruction = - while_body_computation->parameter_instruction(i); - result.push_back( - {parameter_instruction, instruction_and_shape_index.shape_index}); - - HloInstruction* condition_instruction = - while_condition_computation->parameter_instruction(i); - result.push_back( - {condition_instruction, instruction_and_shape_index.shape_index}); - } - } else if (user->opcode() == HloOpcode::kAsyncStart) { - auto operand_indices = user->OperandIndices(instruction); - CHECK(user->called_computations().size() == 1) - << "Expect async-start to only have one called computation."; - for (const auto i : operand_indices) { - HloComputation* called_computation = - user->called_computations().front(); - HloInstruction* parameter_instruction = - called_computation->parameter_instruction(i); - result.push_back( - {parameter_instruction, instruction_and_shape_index.shape_index}); - } - } else if (user->opcode() == HloOpcode::kCustomCall) { - const auto operand_indices = user->OperandIndices(instruction); - // TODO(b/342650757): Rather than a boolean indicating whether the - // instruction reuses the buffer, return the shape index of the output - // that the operand aliases with. - bool found_one = false; - for (const auto i : operand_indices) { - if (CustomCallReusesBuffer(user, i)) { - if (found_one) { - return absl::InternalError( - "Found multiple operands of a custom call that reuse the same " - "output buffer."); - } - result.push_back({user, instruction_and_shape_index.shape_index}); - found_one = true; - } - } - } else { - result.push_back({user, instruction_and_shape_index.shape_index}); - } - } - return result; -} - -// If an instruction's operand is a call, return the call now. A follow up call -// of this function on that call returns the ROOT. Eventually, once the given -// instruction is a parameter, the returned predecessor will be the appropriate -// operand of the call (not the call itself, since we already returned it). -std::vector GetPredecessors( - const InstructionAndShapeIndex& instruction_and_shape_index) { - std::vector result; - HloInstruction* instruction = instruction_and_shape_index.instruction; - if (instruction->opcode() == HloOpcode::kGetTupleElement) { - const int64_t index = instruction->tuple_index(); - auto tmp_shape_index = instruction_and_shape_index.shape_index; - tmp_shape_index.push_front(index); - result.push_back({instruction->mutable_operand(0), tmp_shape_index}); - } else if (instruction->opcode() == HloOpcode::kTuple) { - CHECK(!instruction_and_shape_index.shape_index.empty()) - << "Did not store an index before encountering a tuple."; - auto tmp_shape_index = instruction_and_shape_index.shape_index; - const int64_t index = tmp_shape_index.front(); - tmp_shape_index.pop_front(); - result.push_back({instruction->mutable_operand(index), tmp_shape_index}); - } else if (instruction->opcode() == HloOpcode::kCall) { - // Predecessor of a call is its computation's root instruction. - CHECK(instruction->called_computations().size() == 1) - << "Expect call to only have one called computation."; - HloComputation* called_computation = - instruction->called_computations().front(); - result.push_back({called_computation->root_instruction(), - instruction_and_shape_index.shape_index}); - } else if (instruction->opcode() == HloOpcode::kParameter) { - std::unique_ptr call_graph = - CallGraph::Build(instruction->GetModule()); - auto callers = call_graph->GetComputationCallers(instruction->parent()); - for (HloInstruction* caller : callers) { - result.push_back( - {caller->mutable_operand(instruction->parameter_number()), - instruction_and_shape_index.shape_index}); - } - } else if (instruction->opcode() == HloOpcode::kDynamicSlice) { - result.push_back({instruction->mutable_operand(0), - instruction_and_shape_index.shape_index}); - } else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { - result.push_back({instruction->mutable_operand(0), - instruction_and_shape_index.shape_index}); - } else if (instruction->opcode() == HloOpcode::kWhile) { - HloComputation* while_body_computation = instruction->while_body(); - result.push_back({while_body_computation->root_instruction(), - instruction_and_shape_index.shape_index}); - } else { - CHECK(instruction->operand_count() == 1) << absl::StreamFormat( - "Expecting instruction %s to have 1 operand, but it has %d.", - instruction->name(), instruction->operand_count()); - result.push_back({instruction->mutable_operand(0), - instruction_and_shape_index.shape_index}); - } - return result; -} - } // namespace -bool operator==(const InstructionAndShapeIndex& lhs, - const InstructionAndShapeIndex& rhs) { - return lhs.instruction == rhs.instruction && - lhs.shape_index == rhs.shape_index; -} - -std::string InstructionAndShapeIndex::ToString() const { - return absl::StrFormat("{Instr: %s, ShapeIndex: %s}", instruction->name(), - shape_index.ToString()); -} - -bool HostOffloader::IsValidDuringPureMemoryOffload( - const HloInstruction* instruction) const { - static constexpr std::array allowed_opcodes = { - HloOpcode::kGetTupleElement, - HloOpcode::kBitcast, - HloOpcode::kTuple, - HloOpcode::kCall, - HloOpcode::kWhile, - HloOpcode::kParameter, - HloOpcode::kOptimizationBarrier, - HloOpcode::kAsyncStart, - HloOpcode::kAsyncDone, - HloOpcode::kCustomCall}; - return absl::c_linear_search(allowed_opcodes, instruction->opcode()); -} - bool HostOffloader::InstructionIsAllowedBetweenMoveToHostAndDus( const HloInstruction* instruction) const { if (instruction->opcode() == HloOpcode::kReshape) { @@ -355,7 +155,8 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( // this so that we don't try to create an AllocateBuffer later. dynamic_update_slices_already_allocated_.insert(instruction); } - } else if (IsValidDuringPureMemoryOffload(instruction)) { + } else if (host_offload_utils::IsValidDuringPureMemoryOffload( + instruction)) { if (instruction->opcode() == HloOpcode::kAsyncStart) { // When visiting the parameter, we already set the memory space of the // input of the async-start; do not set it now. @@ -433,8 +234,9 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( } } // Push successors onto the queue to be visited. - TF_ASSIGN_OR_RETURN(const std::vector successors, - GetSuccessors(instruction_and_shape_index)); + TF_ASSIGN_OR_RETURN( + const std::vector successors, + host_offload_utils::GetSuccessors(instruction_and_shape_index)); for (const InstructionAndShapeIndex& successor : successors) { queue.push(successor); } @@ -454,7 +256,8 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( } if (insert_copy_before) { - const auto predecessors = GetPredecessors(starting_instruction_and_index); + const auto predecessors = + host_offload_utils::GetPredecessors(starting_instruction_and_index); CHECK_EQ(predecessors.size(), 1); TF_ASSIGN_OR_RETURN(bool inserted_copy, InsertCopyBetween(predecessors.front(), @@ -687,7 +490,8 @@ HostOffloader::GetStartingInstructions( std::queue queue; TF_ASSIGN_OR_RETURN( const std::vector successors_of_custom_call, - GetSuccessors(InstructionAndShapeIndex(custom_call_instruction))); + host_offload_utils::GetSuccessors( + InstructionAndShapeIndex(custom_call_instruction))); for (const InstructionAndShapeIndex& successor : successors_of_custom_call) { queue.push(successor); } @@ -707,8 +511,9 @@ HostOffloader::GetStartingInstructions( } else { // Is a logical bitcast/reshape, we won't offload this yet. } - TF_ASSIGN_OR_RETURN(const std::vector successors, - GetSuccessors(instruction_and_shape)); + TF_ASSIGN_OR_RETURN( + const std::vector successors, + host_offload_utils::GetSuccessors(instruction_and_shape)); for (const InstructionAndShapeIndex& successor : successors) { queue.push(successor); } @@ -730,7 +535,7 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall( std::queue queue; TF_ASSIGN_OR_RETURN( const std::vector successors_of_slice, - GetSuccessors(InstructionAndShapeIndex(slice))); + host_offload_utils::GetSuccessors(InstructionAndShapeIndex(slice))); for (const InstructionAndShapeIndex& successor : successors_of_slice) { queue.push(successor); } @@ -751,8 +556,9 @@ absl::Status HostOffloader::ValidateSliceLeadsToMoveToDeviceCustomCall( "the MoveToDevice custom call.", slice->name(), current_instruction->name())); } - TF_ASSIGN_OR_RETURN(const std::vector successors, - GetSuccessors(instruction_and_shape)); + TF_ASSIGN_OR_RETURN( + const std::vector successors, + host_offload_utils::GetSuccessors(instruction_and_shape)); for (const InstructionAndShapeIndex& successor : successors) { queue.push(successor); } @@ -824,7 +630,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( InstructionAndShapeIndex nested_instruction_and_shape = nested_queue.front(); nested_queue.pop(); - if (!IsValidDuringPureMemoryOffload( + if (!host_offload_utils::IsValidDuringPureMemoryOffload( nested_instruction_and_shape.instruction)) { return absl::InvalidArgumentError(absl::StrFormat( "Tensor which is moved to host is used by an invalid " @@ -838,7 +644,8 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( kHostMemorySpaceColor); TF_ASSIGN_OR_RETURN( const std::vector successors, - GetSuccessors(nested_instruction_and_shape)); + host_offload_utils::GetSuccessors( + nested_instruction_and_shape)); for (const InstructionAndShapeIndex& successor : successors) { nested_queue.push(successor); } @@ -851,7 +658,7 @@ absl::Status HostOffloader::CreateAllocateBufferForDynamicUpdateSlice( dynamic_update_slices_already_allocated_.insert(instruction); } const std::vector predecessors = - GetPredecessors(instruction_and_shape); + host_offload_utils::GetPredecessors(instruction_and_shape); for (const InstructionAndShapeIndex& predecessor : predecessors) { HloInstruction* predecessor_instruction = predecessor.instruction; if (predecessor_instruction->opcode() == HloOpcode::kBroadcast) { diff --git a/third_party/xla/xla/service/host_offloader.h b/third_party/xla/xla/service/host_offloader.h index 880cda3d77b621..8dfee6d455eb6b 100644 --- a/third_party/xla/xla/service/host_offloader.h +++ b/third_party/xla/xla/service/host_offloader.h @@ -26,36 +26,12 @@ #include "xla/service/hlo_alias_analysis.h" #include "xla/service/hlo_buffer.h" #include "xla/service/hlo_pass_interface.h" +#include "xla/service/host_offload_utils.h" namespace xla { class HloCostAnalysis; -struct InstructionAndShapeIndex { - explicit InstructionAndShapeIndex(HloInstruction* instruction) - : instruction(instruction) {} - InstructionAndShapeIndex(HloInstruction* instruction, ShapeIndex shape_index) - : instruction(instruction), shape_index(shape_index) {} - HloInstruction* instruction; - ShapeIndex shape_index; - std::string ToString() const; - - template - static H Hash(H h, const InstructionAndShapeIndex& i) { - h = H::combine(std::move(h), i.instruction); - h = H::combine(std::move(h), i.shape_index); - return std::move(h); - } - - template - friend H AbslHashValue(H h, const InstructionAndShapeIndex& i) { - return InstructionAndShapeIndex::Hash(std::move(h), i); - } -}; - -bool operator==(const InstructionAndShapeIndex& lhs, - const InstructionAndShapeIndex& rhs); - // This pass does "host memory offloading". If a tensor is annotated to be moved // to or from the host, this pass will remove the annotations and update each // tensor's layout with host memory spaces and insert copies if necessary. This @@ -90,17 +66,14 @@ class HostOffloader : public HloModulePass { absl::flat_hash_set validated_slices_; absl::flat_hash_map copies_created_after_; absl::flat_hash_set move_to_device_custom_calls_to_remove_; - absl::flat_hash_set already_inserted_copy_before_; + absl::flat_hash_set + already_inserted_copy_before_; // Sometimes previous transformations turn a DynamicSlice into a Slice. Since // we're doing a DMA between the host and device, we need to turn the Slice // back into a DynamicSlice. absl::StatusOr DynamifySlice(HloInstruction* slice); - // Returns true if the instruction is allowed to be in the - // middle of a pure memory offload path. - bool IsValidDuringPureMemoryOffload(const HloInstruction* instruction) const; - // Returns true if the instruction is allowed to be in the // middle of a path between a MoveToHost custom-call annotation and a // DynamicUpdateSlice. Ideally the custom-call should be immediately followed @@ -146,19 +119,22 @@ class HostOffloader : public HloModulePass { // Common function for doing the actual walking of the graph. Host memory // spaces are set and copies are inserted in here. absl::StatusOr WalkDownHostMemoryOffloadPaths( - const InstructionAndShapeIndex& starting_instruction_and_index, + const host_offload_utils::InstructionAndShapeIndex& + starting_instruction_and_index, bool insert_copy_before); // Given a custom call, this returns the first instruction and shape index to // start the host memory offload path from for each use of the custom call. - absl::StatusOr> GetStartingInstructions( - HloInstruction* custom_call_instruction); + absl::StatusOr> + GetStartingInstructions(HloInstruction* custom_call_instruction); // When a MoveToHost custom call is not paired with a DynamicUpdateSlice, a // copy from device to host must be inserted. absl::StatusOr InsertCopyBetween( - const InstructionAndShapeIndex& before_instruction_and_index, - const InstructionAndShapeIndex& after_instruction_and_index); + const host_offload_utils::InstructionAndShapeIndex& + before_instruction_and_index, + const host_offload_utils::InstructionAndShapeIndex& + after_instruction_and_index); // This is a fix for scheduling. Add copies to inputs of dynamic-update-slice // if the inserted value is directly a parameter of a computation. This is to diff --git a/third_party/xla/xla/service/host_offloader_test.cc b/third_party/xla/xla/service/host_offloader_test.cc index 85cc7742b3ce45..0d2ee3d295df72 100644 --- a/third_party/xla/xla/service/host_offloader_test.cc +++ b/third_party/xla/xla/service/host_offloader_test.cc @@ -37,8 +37,8 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace m = ::xla::match; diff --git a/third_party/xla/xla/service/host_offloading_prepare_test.cc b/third_party/xla/xla/service/host_offloading_prepare_test.cc index 9210d9824231c8..92d5490cfb2d15 100644 --- a/third_party/xla/xla/service/host_offloading_prepare_test.cc +++ b/third_party/xla/xla/service/host_offloading_prepare_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/host_memory_offload_annotations.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.cc b/third_party/xla/xla/service/latency_hiding_scheduler.cc index a17bc63b6f8804..dc59e5cca70151 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.cc +++ b/third_party/xla/xla/service/latency_hiding_scheduler.cc @@ -486,6 +486,15 @@ bool AsyncTracker::ReleasesSelectiveResource(const HloGraphNode* node) const { }); } +bool AsyncTracker::OccupiesSelectiveResource(const HloGraphNode* node) const { + return absl::c_any_of( + node->GetResources(), [&](const ResourcePair& resource) { + return resource.second == ResourceUsageType::kResourceOccupy && + GetResourceHazardType(resource.first) == + ResourceHazardType::kSelective; + }); +} + BufferInfoTracker::BufferInfoTracker( const HloModule* module, const HloAliasAnalysis* alias_analysis, const HloCostAnalysis::ShapeSizeFunction& shape_size_bytes) { @@ -731,6 +740,25 @@ DefaultSchedulerCore::ScheduleCandidate InitializeCandidate( namespace { +// Find the num hops to the closest selective resource overlap in ready set that +// provided node can be scheduled in between. +int64_t GetNumHopsToClosestSelectiveOverlap( + const DefaultSchedulerCore::ReadyQueueSet& ready_set, + const HloGraphNode* node) { + int64_t num_hops_to_closest_selective_resource_occupier = + std::numeric_limits::max(); + for (const HloGraphNode* n : ready_set) { + // Skip the node itself. + if (n == node) { + continue; + } + num_hops_to_closest_selective_resource_occupier = + std::min(num_hops_to_closest_selective_resource_occupier, + n->GetNumHopsToClosestSelectiveResourceOccupier()); + } + return num_hops_to_closest_selective_resource_occupier; +} + // Comparator for the ready set. This class represents the priority policies // for the nodes in the ready set. The policy can be whatever is appropriate to // reduce the execution time of the graph or achieve interesting properties @@ -802,7 +830,8 @@ class ReadySetLt { return *value; } } - // Otherwise pick a node that increases the pressure from the list. + // Otherwise pick a node that increases the pressure the least from the + // list. if (auto value = DefaultSchedulerCore::ChooseBestCandidate( a_increase.first < b_increase.first, a, b_increase.first < a_increase.first, b, @@ -880,6 +909,36 @@ class ReadySetLt { } } + auto async_depth_0_candidate = + [this](DefaultSchedulerCore::ScheduleCandidate& a, + DefaultSchedulerCore::ScheduleCandidate& b) + -> std::optional { + // If an instruction releasing a resource is not resource constrained and + // has an async depth of 0, delay it as much as possible to avoid + // potential cost model inefficiencies. For example, if a pair of + // async-start and async-done have no dependencies on other ops inside a + // loop, the async-start will be pushed to the beginning of the loop. + if (auto value = DefaultSchedulerCore::ChooseBestCandidate( + /*first_cond=*/!(a.node->DoesReleaseAnyResource() && + a.node->GetAsyncDepth() == 0 && + !IsResourceConstrained(a)), + a, + /*second_cond=*/ + !(b.node->DoesReleaseAnyResource() && + b.node->GetAsyncDepth() == 0 && !IsResourceConstrained(b)), + b, "kStartAtZeroDepth")) { + return value; + } + return std::nullopt; + }; + + if (sched_state_.config.aggressive_scheduling_policies && + sched_state_.config.prioritize_async_depth_over_stall) { + if (auto value = async_depth_0_candidate(a, b)) { + return *value; + } + } + const ApproximateLatencyEstimator::TimeCost a_ready_interval = std::max(a.node->GetReadyTime() - sched_state_.current_time, 0.0); const ApproximateLatencyEstimator::TimeCost b_ready_interval = @@ -906,19 +965,9 @@ class ReadySetLt { return *value; } } - if (sched_state_.config.aggressive_scheduling_policies) { - // If an instruction releasing a resource is not resource constrained and - // has an async depth of 0, delay it as much as possible to avoid - // potential cost model inefficiencies. - if (auto value = DefaultSchedulerCore::ChooseBestCandidate( - /*first_cond=*/!(a.node->DoesReleaseAnyResource() && - a.node->GetAsyncDepth() == 0 && - !IsResourceConstrained(a)), - a, - /*second_cond=*/ - !(b.node->DoesReleaseAnyResource() && - b.node->GetAsyncDepth() == 0 && !IsResourceConstrained(b)), - b, "kStartAtZeroDepth")) { + if (sched_state_.config.aggressive_scheduling_policies && + !sched_state_.config.prioritize_async_depth_over_stall) { + if (auto value = async_depth_0_candidate(a, b)) { return *value; } } @@ -981,6 +1030,31 @@ class ReadySetLt { return *value; } } + // If there are no selective overlaps open currently and there will be + // overlaps opened in the near future, hold off scheduling instructions + // that are valuable for selective overlaps. + if (sched_state_.config.enable_selective_resources && + sched_state_.selective_resource_releasers.empty()) { + int64_t distance_to_selective_overlap_for_a = + GetNumHopsToClosestSelectiveOverlap(sched_state_.ready_set, a.node); + int64_t distance_to_selective_overlap_for_b = + GetNumHopsToClosestSelectiveOverlap(sched_state_.ready_set, b.node); + // If a is valuable for selective overlap and there is a selective + // overlap in the near future a can be scheduled inside, hold off + // scheduling a and schedule b instead. Same logic applies in reverse. + int64_t max_distance = + sched_state_.config.max_hops_to_closest_selective_overlap; + if (auto value = DefaultSchedulerCore::ChooseBestCandidate( + (a.node->GetValuableForSelectiveOverlap() && + distance_to_selective_overlap_for_a <= max_distance), + b, + (b.node->GetValuableForSelectiveOverlap() && + distance_to_selective_overlap_for_b <= max_distance), + a, "kNotValuableForSelectiveOverlap")) { + return *value; + } + } + if (sched_state_.config.aggressive_scheduling_policies) { // Favor nodes that unlock other nodes to be scheduled if possible. // This makes us more flexible in what we can use in scheduling. @@ -1672,6 +1746,8 @@ HloScheduleGraph::HloScheduleGraph( new_node_it->second->GetResources()); new_node_it->second->releases_selective_resource_ = async_tracker->ReleasesSelectiveResource(new_node_it->second.get()); + new_node_it->second->occupies_selective_resource_ = + async_tracker->OccupiesSelectiveResource(new_node_it->second.get()); // Gather while instructions for subsequent send-done dependency checks. if (instr->opcode() == HloOpcode::kWhile) { while_instrs.push_back(instr); @@ -1879,6 +1955,25 @@ void HloScheduleGraph::InitializeGraphAnalysis( while (!stack.empty()) { auto* node = stack.back(); stack.pop_back(); + // If a node occupies a selective resource, it is the closest selective + // resource occupier to itself and is 0 hops away. Otherwise, the num hops + // to closest selective resource occupier is the minimum of that of all + // predecessors plus 1. + if (async_tracker->OccupiesSelectiveResource(node)) { + node->num_hops_to_closest_selective_resource_occupier_ = 0; + } else { + int64_t closest_predecessor_distance = + std::numeric_limits::max(); + for (auto& pred : node->GetPredecessors()) { + closest_predecessor_distance = std::min( + closest_predecessor_distance, + pred.Target().num_hops_to_closest_selective_resource_occupier_); + } + if (closest_predecessor_distance != std::numeric_limits::max()) { + node->num_hops_to_closest_selective_resource_occupier_ = + closest_predecessor_distance + 1; + } + } if (async_tracker->IsSupportedAsyncDone(node->GetInstr())) { for (auto& pred : node->GetPredecessors()) { node->SetAsyncDepth( diff --git a/third_party/xla/xla/service/latency_hiding_scheduler.h b/third_party/xla/xla/service/latency_hiding_scheduler.h index 76ce8b307f7184..b0d8a8d08e9886 100644 --- a/third_party/xla/xla/service/latency_hiding_scheduler.h +++ b/third_party/xla/xla/service/latency_hiding_scheduler.h @@ -16,7 +16,6 @@ limitations under the License. #ifndef XLA_SERVICE_LATENCY_HIDING_SCHEDULER_H_ #define XLA_SERVICE_LATENCY_HIDING_SCHEDULER_H_ -#include #include #include #include @@ -132,11 +131,13 @@ struct SchedulerConfig { bool force_send_recv_to_use_same_resource = false; bool use_real_cost_model = false; bool aggressive_scheduling_policies = false; + bool prioritize_async_depth_over_stall = false; bool enable_release_start_policy = false; bool resource_sharing = false; bool resource_serializing = false; bool depth_based_memory_pressure_reduction = false; bool enable_selective_resources = false; + int64_t max_hops_to_closest_selective_overlap = 0; int64_t rerun = 0; }; @@ -284,6 +285,9 @@ class AsyncTracker { // Returns whether the provided node releases a selective resource. bool ReleasesSelectiveResource(const HloGraphNode* node) const; + // Returns whether the provided node occupies a selective resource. + bool OccupiesSelectiveResource(const HloGraphNode* node) const; + inline CanonicalAsyncOp GetCanonicalAsyncOp(const HloInstruction& hlo) const { return get_canonical_async_op_(hlo); } @@ -386,6 +390,17 @@ class HloGraphNode { bool ReleasesSelectiveResource() const { return releases_selective_resource_; } + bool OccupiesSelectiveResource() const { + return occupies_selective_resource_; + } + int64_t GetNumHopsToClosestSelectiveResourceOccupier() const { + return num_hops_to_closest_selective_resource_occupier_; + } + void SetNumHopsToClosestSelectiveResourceOccupier( + int64_t num_hops_to_closest_selective_resource_occupier) { + num_hops_to_closest_selective_resource_occupier_ = + num_hops_to_closest_selective_resource_occupier; + } ResourcesVector GetResources() const { return resources_; } bool DoesOccupyAnyResource() const { @@ -525,6 +540,11 @@ class HloGraphNode { bool valuable_for_selective_overlap_ = true; // Whether this node releases a selective resource. bool releases_selective_resource_ = false; + // Whether this node occupies a selective resource. + bool occupies_selective_resource_ = false; + // Nums hops to closest selective resource occupier. + int64_t num_hops_to_closest_selective_resource_occupier_ = + std::numeric_limits::max(); }; // Schedule graph that can be used to drive scheduling @@ -920,7 +940,6 @@ class DefaultSchedulerCore : public SchedulerCore { virtual absl::StatusOr FindAndExtractBestNodeAvailable( SchedulingState& sched_state, DefaultSchedulerCore::ShouldSkipNodeFunction should_skip_node); - bool DoesNodeReleaseSelectiveResource(const HloGraphNode* node) const; void DumpLatencyHidingSchedule( const HloComputation* computation, const HloScheduleGraph& schedule_graph, const std::vector& instructions, diff --git a/third_party/xla/xla/service/layout_assignment.cc b/third_party/xla/xla/service/layout_assignment.cc index 688af31615c710..f749cdba55d57a 100644 --- a/third_party/xla/xla/service/layout_assignment.cc +++ b/third_party/xla/xla/service/layout_assignment.cc @@ -482,7 +482,8 @@ absl::Status LayoutAssignment::SetInstructionLayout( absl::Status LayoutAssignment::SetInstructionLayout( const Shape& shape_with_layout, const HloInstruction* instruction, - bool mandatory, bool dfs, bool allow_alias, int64_t priority) { + bool mandatory, bool dfs, bool allow_alias, int64_t priority, + ShapeIndexView subshape_index) { VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", " << ShapeUtil::HumanStringWithLayout(shape_with_layout) << ": priority = " << priority << " : mandatory = " << mandatory @@ -499,8 +500,12 @@ absl::Status LayoutAssignment::SetInstructionLayout( // instruction. TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( shape_with_layout, - [this, dfs, instruction, mandatory, allow_alias, priority]( - const Shape& subshape, const ShapeIndex& index) -> absl::Status { + [this, dfs, instruction, mandatory, allow_alias, priority, + subshape_index](const Shape& subshape, + const ShapeIndex& index) -> absl::Status { + if (!subshape_index.empty() && index != subshape_index) { + return absl::OkStatus(); + } auto buffers = points_to_analysis_->GetPointsToSet(instruction).element(index); CHECK_EQ(1, buffers.size()); diff --git a/third_party/xla/xla/service/layout_assignment.h b/third_party/xla/xla/service/layout_assignment.h index ba12a2a325bc99..ba59743018c386 100644 --- a/third_party/xla/xla/service/layout_assignment.h +++ b/third_party/xla/xla/service/layout_assignment.h @@ -378,14 +378,16 @@ class LayoutAssignment : public HloModulePass { absl::Status SetInstructionLayout(const Shape& shape_with_layout, const HloInstruction* instruction, bool mandatory = true, bool dfs = true, - bool allow_alias = false) { + bool allow_alias = false, + ShapeIndexView subshape_index = {}) { return SetInstructionLayout(shape_with_layout, instruction, mandatory, dfs, - allow_alias, current_priority_); + allow_alias, current_priority_, subshape_index); } absl::Status SetInstructionLayout(const Shape& shape_with_layout, const HloInstruction* instruction, bool mandatory, bool dfs, bool allow_alias, - int64_t priority); + int64_t priority, + ShapeIndexView subshape_index = {}); // Set the same given layout across all components of the instruction output. // It works the same as the API above if the output is a single array. absl::Status SetInstructionLayout(const Layout& layout, diff --git a/third_party/xla/xla/service/layout_assignment_test.cc b/third_party/xla/xla/service/layout_assignment_test.cc index 139124bd6c09bb..0b294c46ddef17 100644 --- a/third_party/xla/xla/service/layout_assignment_test.cc +++ b/third_party/xla/xla/service/layout_assignment_test.cc @@ -44,9 +44,9 @@ limitations under the License. #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/layout_normalization.cc b/third_party/xla/xla/service/layout_normalization.cc index 2dce620c81b267..16781509e22c60 100644 --- a/third_party/xla/xla/service/layout_normalization.cc +++ b/third_party/xla/xla/service/layout_normalization.cc @@ -742,21 +742,31 @@ class LayoutNormalizationVisitor : public DfsHloRewriteVisitor { Shape s = hlo->shape(); HloOpcode opcode = hlo->opcode(); TF_RET_CHECK(opcode == HloOpcode::kClamp || opcode == HloOpcode::kSelect); - HloInstruction* p = hlo->mutable_operand(0); - HloInstruction* i1 = hlo->mutable_operand(1); - HloInstruction* i2 = hlo->mutable_operand(2); - TF_RET_CHECK(p->shape().layout() == s.layout()); - TF_RET_CHECK(i1->shape().layout() == s.layout()); - TF_RET_CHECK(i2->shape().layout() == s.layout()); + HloInstruction* arg0 = hlo->mutable_operand(0); + HloInstruction* arg1 = hlo->mutable_operand(1); + HloInstruction* arg2 = hlo->mutable_operand(2); + if (opcode == HloOpcode::kClamp) { + TF_RET_CHECK(arg1->shape().layout() == s.layout()); + } else if (opcode == HloOpcode::kSelect) { + TF_RET_CHECK(arg1->shape().layout() == s.layout()); + TF_RET_CHECK(arg2->shape().layout() == s.layout()); + } else { + TF_RET_CHECK(false); + } - TF_ASSIGN_OR_RETURN(HloInstruction * p_0, GetNormalizedInput(p)); - TF_ASSIGN_OR_RETURN(HloInstruction * i1_0, GetNormalizedInput(i1)); - TF_ASSIGN_OR_RETURN(HloInstruction * i2_0, GetNormalizedInput(i2)); + TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg0, + GetNormalizedInput(arg0)); + TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg1, + GetNormalizedInput(arg1)); + TF_ASSIGN_OR_RETURN(HloInstruction * normalized_arg2, + GetNormalizedInput(arg2)); TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferTernaryOpShape( - opcode, p_0, i1_0, i2_0)); + opcode, normalized_arg0, + normalized_arg1, normalized_arg2)); HloInstruction* normalized = hlo->parent()->AddInstruction( - HloInstruction::CreateTernary(new_shape, opcode, p_0, i1_0, i2_0)); + HloInstruction::CreateTernary(new_shape, opcode, normalized_arg0, + normalized_arg1, normalized_arg2)); hlo->SetupDerivedInstruction(normalized); SetVisited(*normalized); diff --git a/third_party/xla/xla/service/layout_normalization_test.cc b/third_party/xla/xla/service/layout_normalization_test.cc index d2b9d92d2fb934..88ea4828ec597a 100644 --- a/third_party/xla/xla/service/layout_normalization_test.cc +++ b/third_party/xla/xla/service/layout_normalization_test.cc @@ -644,10 +644,26 @@ TEST_F(LayoutNormalizationTest, Select) { HloModule module ENTRY main { - p0 = f32[1,17,9,9]{1,3,2,0} parameter(0) - p1 = f32[1,17,9,9]{1,3,2,0} parameter(1) - b = pred[1,17,9,9]{1,3,2,0} parameter(2) - ROOT out = f32[1,17,9,9]{1,3,2,0} select(b, p0, p1), metadata={op_name="test"} + lhs = f32[1,17,9,9]{1,3,2,0} parameter(0) + rhs = f32[1,17,9,9]{1,3,2,0} parameter(1) + p = pred[1,17,9,9]{1,3,2,0} parameter(2) + ROOT out = f32[1,17,9,9]{1,3,2,0} select(p, lhs, rhs), metadata={op_name="test"} +} +)"; + CheckLayoutNormalization(hlo, R"( +// CHECK: f32[1,9,9,17]{3,2,1,0} select({{.*}}, {{.*}}, {{.*}}), metadata={op_name="test"} +)"); +} + +TEST_F(LayoutNormalizationTest, SelectScalarPredicate) { + const char* hlo = R"( +HloModule module + +ENTRY main { + lhs = f32[1,17,9,9]{1,3,2,0} parameter(0) + rhs = f32[1,17,9,9]{1,3,2,0} parameter(1) + p = pred[] parameter(2) + ROOT out = f32[1,17,9,9]{1,3,2,0} select(p, lhs, rhs), metadata={op_name="test"} } )"; CheckLayoutNormalization(hlo, R"( @@ -734,10 +750,44 @@ TEST_F(LayoutNormalizationTest, Clamp) { HloModule m ENTRY main { - p0 = f32[64,1,32]{1,0,2} parameter(0) - p1 = f32[64,1,32]{1,0,2} parameter(1) - p2 = f32[64,1,32]{1,0,2} parameter(2) - ROOT out = f32[64,1,32]{1,0,2} clamp(f32[64,1,32]{1,0,2} p0, f32[64,1,32]{1,0,2} p1, f32[64,1,32]{1,0,2} p2), metadata={op_name="test"} + lb = f32[64,1,32]{1,0,2} parameter(0) + in = f32[64,1,32]{1,0,2} parameter(1) + ub = f32[64,1,32]{1,0,2} parameter(2) + ROOT out = f32[64,1,32]{1,0,2} clamp(f32[64,1,32]{1,0,2} lb, f32[64,1,32]{1,0,2} in, f32[64,1,32]{1,0,2} ub), metadata={op_name="test"} +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: f32[32,64,1]{2,1,0} clamp({{.*}}, {{.*}}, {{.*}}), metadata={op_name="test"} +)"); +} + +TEST_F(LayoutNormalizationTest, ClampScalarBounds) { + const char* hlo = R"( +HloModule m + +ENTRY main { + lb = f32[] parameter(0) + in = f32[64,1,32]{1,0,2} parameter(1) + ub = f32[] parameter(2) + ROOT out = f32[64,1,32]{1,0,2} clamp(f32[] lb, f32[64,1,32]{1,0,2} in, f32[] ub), metadata={op_name="test"} +} +)"; + + CheckLayoutNormalization(hlo, R"( +// CHECK: f32[32,64,1]{2,1,0} clamp({{.*}}, {{.*}}, {{.*}}), metadata={op_name="test"} +)"); +} + +TEST_F(LayoutNormalizationTest, ClampScalarLb) { + const char* hlo = R"( +HloModule m + +ENTRY main { + lb = f32[] parameter(0) + in = f32[64,1,32]{1,0,2} parameter(1) + ub = f32[64,1,32]{1,0,2} parameter(2) + ROOT out = f32[64,1,32]{1,0,2} clamp(f32[] lb, f32[64,1,32]{1,0,2} in, f32[64,1,32]{1,0,2} ub), metadata={op_name="test"} } )"; diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.cc b/third_party/xla/xla/service/llvm_ir/ir_array.cc index 29a4f4b467ebf4..8c9c290000f81c 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.cc +++ b/third_party/xla/xla/service/llvm_ir/ir_array.cc @@ -527,6 +527,7 @@ llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index, if (!index.LinearValidOnShape(shape_)) { // Create a valid linear index. std::vector dimensions; + dimensions.reserve(shape_.rank()); for (int64_t i = 0; i < shape_.rank(); ++i) { dimensions.push_back(shape_.dimensions(i)); } diff --git a/third_party/xla/xla/service/llvm_ir/ir_array.h b/third_party/xla/xla/service/llvm_ir/ir_array.h index 9ec78b09aaac8c..691f93fa1570d8 100644 --- a/third_party/xla/xla/service/llvm_ir/ir_array.h +++ b/third_party/xla/xla/service/llvm_ir/ir_array.h @@ -250,9 +250,9 @@ class IrArray { IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape); // Default implementations of copying and moving. - IrArray(IrArray&& other) = default; + IrArray(IrArray&& other) noexcept = default; IrArray(const IrArray& other) = default; - IrArray& operator=(IrArray&& other) = default; + IrArray& operator=(IrArray&& other) noexcept = default; IrArray& operator=(const IrArray& other) = default; llvm::Value* GetBasePointer() const { return base_ptr_; } diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index 399c335ff387ca..0ed7bacdd94a99 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -715,17 +715,6 @@ std::map MergeMetadata( return result; } -static absl::Status CreateAndWriteStringToFile( - const std::string& directory_name, const std::string& file_name, - const std::string& text) { - std::unique_ptr f; - TF_RETURN_IF_ERROR(tsl::Env::Default()->RecursivelyCreateDir(directory_name)); - TF_RETURN_IF_ERROR(tsl::Env::Default()->NewWritableFile(file_name, &f)); - TF_RETURN_IF_ERROR(f->Append(text)); - TF_RETURN_IF_ERROR(f->Close()); - return absl::OkStatus(); -} - void DumpIrIfEnabled(const HloModule& hlo_module, const llvm::Module& llvm_module, bool optimized, absl::string_view filename_suffix) { diff --git a/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc b/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc index 22f4cd0b8c9205..ca738619aa8ab8 100644 --- a/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc +++ b/third_party/xla/xla/service/mapped_ptr_container_sorter_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "absl/functional/bind_front.h" #include "absl/log/log.h" #include "xla/test.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/memory_space_assignment/BUILD b/third_party/xla/xla/service/memory_space_assignment/BUILD index 0473d757ed9359..56209b12778ed2 100644 --- a/third_party/xla/xla/service/memory_space_assignment/BUILD +++ b/third_party/xla/xla/service/memory_space_assignment/BUILD @@ -110,6 +110,7 @@ xla_cc_test( "//xla/tests:test_utils", "//xla/tests:verified_hlo_module", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -120,7 +121,6 @@ xla_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:status", @@ -310,11 +310,11 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service:hlo_cost_analysis", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -355,12 +355,12 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/service/heap_simulator", "//xla/tests:hlo_test_base", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -465,6 +465,7 @@ xla_cc_test( "//xla/service:hlo_value", "//xla/tests:hlo_test_base", "//xla/tests:verified_hlo_module", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", @@ -474,7 +475,6 @@ xla_cc_test( "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", "@com_googlesource_code_re2//:re2", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc index 4611453fef48ac..67371364f1cd0d 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.cc +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.cc @@ -1493,7 +1493,8 @@ void MsaAlgorithm::CreateAllocationValuesForJointProcessedIntervals( continue; } - if (interval.size > available_heap_size()) { + if (!options_.enable_window_prefetch && + interval.size > available_heap_size()) { VLOG(3) << "Skip " << interval.buffer->ToShortString() << " because the buffer is larger than the heap size."; continue; @@ -2152,6 +2153,12 @@ absl::StatusOr MsaAlgorithm::AllocateAllocationValues( options_.alternate_memory_space; VLOG(4) << "require_no_copy_alternate_mem_allocation = " << require_no_copy_alternate_mem_allocation; + if (require_no_copy_alternate_mem_allocation && + allocation_value.size() > available_heap_size()) { + VLOG(3) << "Skip " << allocation_value.value()->ToShortString() + << " because the buffer is larger than the heap size."; + continue; + } if (!options_.is_position_allowed_in_alternate_mem_fn( allocation_value.defining_position())) { if (require_no_copy_alternate_mem_allocation) { @@ -3018,8 +3025,12 @@ void MsaAlgorithm::CreateOrAddToAliasedOffset( const AllocationSequence& allocations, int64_t time) { for (auto allocation_it = allocations.rbegin(); allocation_it != allocations.rend(); ++allocation_it) { + // The use case of GetLiveAllocationAt is to find the allocation that + // corresponds to the full buffer. Window prefetched allocations allocates + // only partial buffers, so we want to skip them. if ((*allocation_it)->start_time() <= time && - (*allocation_it)->end_time() >= time) { + (*allocation_it)->end_time() >= time && + !(*allocation_it)->is_window_prefetched_allocation()) { return allocation_it->get(); } } @@ -4197,6 +4208,11 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { << "Not trying to prefetch because use requires buffer in default mem."; (*prev_allocation_in_default_mem_it)->set_end_time(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use); + + // If the buffer is placed in default memory, we can also try window + // prefetching it, which will try to prefetch only a window worth of data to + // alternate memory. + WindowPrefetch(request, **prev_allocation_in_default_mem_it); return Result::kSuccess; } @@ -4286,9 +4302,28 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateSegment(AllocationRequest& request) { // default memory. (*prev_allocation_in_default_mem_it)->set_end_time(request.end_time); (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use); + + // If the buffer is placed in default memory, we can try window prefetching + // it, which will try to prefetch only a window worth of data to alternate + // memory. + WindowPrefetch(request, **prev_allocation_in_default_mem_it); return allocation_result; } +void MsaAlgorithm::AddAsyncCopyForWindowPrefetch( + Allocation& prev_allocation, HloUse use, const Chunk& chunk, + int64_t exclusive_start_time, int64_t inclusive_end_time, + AllocationSequence* allocations, AliasedOffset* aliased_offset, + float resource, const WindowPrefetchedAllocation::Options& options) { + allocations->push_back(std::make_unique( + prev_allocation, use, chunk, exclusive_start_time, inclusive_end_time, + options)); + + RegisterAsyncCopy(MemorySpace::kAlternate, exclusive_start_time, + inclusive_end_time, allocations, aliased_offset, resource, + /*cross_program_prefetch_index=*/std::nullopt); +} + void MsaAlgorithm::AddAsyncCopy( Allocation& prev_allocation, MemorySpace memory_space, std::optional chunk, int64_t exclusive_start_time, int64_t end_time, @@ -4306,6 +4341,16 @@ void MsaAlgorithm::AddAsyncCopy( prev_allocation, memory_space, chunk, exclusive_start_time, copy_done_schedule_before_time, end_time, cross_program_prefetch_index)); + RegisterAsyncCopy(memory_space, exclusive_start_time, + copy_done_schedule_before_time, allocations, aliased_offset, + resource, cross_program_prefetch_index); +} + +void MsaAlgorithm::RegisterAsyncCopy( + MemorySpace memory_space, int64_t exclusive_start_time, + int64_t copy_done_schedule_before_time, AllocationSequence* allocations, + AliasedOffset* aliased_offset, float resource, + std::optional cross_program_prefetch_index) { // Register the additional async copy with the interval tree to keep track of // the limit at any given time. pending_async_copies_.push_back({exclusive_start_time, @@ -4445,7 +4490,8 @@ MsaAlgorithm::Result MsaAlgorithm::AllocateInAlternateMemoryNoCopy( prev_allocation = request.allocation_value->allocation_sequence()->back().get(); can_eliminate_copy = - (prev_allocation->memory_space() == MemorySpace::kAlternate); + (prev_allocation->memory_space() == MemorySpace::kAlternate && + !prev_allocation->is_window_prefetched_allocation()); } if (!can_eliminate_copy) { @@ -4718,9 +4764,41 @@ std::string DescribeSlicedBufferMove( } // namespace -MsaAlgorithm::Result MsaAlgorithm::Prefetch( +MsaAlgorithm::Result MsaAlgorithm::WindowPrefetch( const AllocationRequest& request, Allocation& prev_allocation_in_default_mem) { + if (!options_.enable_window_prefetch) { + return Result::kSuccess; + } + + const HloUse use = request.use->hlo_use; + VLOG(3) << "Considering window prefetch for use=" << use.ToString(); + + // Get the window prefetch details for this use. + WindowPrefetchDetail details = + options_.window_prefetch_detail_fn(use.instruction); + for (const WindowPrefetchDetail::WindowDetail& window : details.windows()) { + if (window.operand() != use.operand_number) { + continue; + } + + WindowPrefetchedAllocation::Options options; + options.bytes = window.size(); + options.uid = window.uid(); + options.alternate_memory_space = options_.alternate_memory_space; + options.notify_operand_appended_fn = options_.notify_operand_appended_fn; + AllocationRequest window_prefetch_request = request; + window_prefetch_request.window_prefetch_options = &options; + window_prefetch_request.size = window.size(); + const Shape shape = ShapeUtil::MakeShape(U8, {window.size()}); + Prefetch(window_prefetch_request, prev_allocation_in_default_mem, &shape); + } + return Result::kSuccess; +} + +MsaAlgorithm::Result MsaAlgorithm::Prefetch( + const AllocationRequest& request, + Allocation& prev_allocation_in_default_mem, const Shape* shape) { // Try partially placing the buffer in the alternate space. The time that is // overlapped will be used to asynchronously copy the buffer from the // default memory to the alternate memory. @@ -4743,6 +4821,10 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( PrefetchContext context; context.request = &request; context.prev_allocation_in_default_mem = &prev_allocation_in_default_mem; + // If the request has window prefetch options, it is called from window + // prefetch. + context.window_prefetch = (request.window_prefetch_options != nullptr); + CHECK(!context.window_prefetch || options_.enable_window_prefetch); // Create a SliceProposal and WorkingIntervals. SetupPrefetchWorkingIntervalsAndSliceProposal(context); @@ -4757,8 +4839,13 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( return check_result; } const HloUse& use = request.use->hlo_use; - context.full_shape = &ShapeUtil::GetSubshape( - use.instruction->operand(use.operand_number)->shape(), use.operand_index); + if (shape != nullptr) { + context.full_shape = shape; + } else { + context.full_shape = &ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), + use.operand_index); + } // While uses might be allowed to have additional outstanding prefetches. context.extra_async_copy_limit = use.instruction->opcode() == HloOpcode::kWhile @@ -4849,14 +4936,26 @@ MsaAlgorithm::Result MsaAlgorithm::Prefetch( << context.unsliced_solution->prefetch_picker_debug_string; AddToPendingChunks(context.unsliced_solution_intervals.full, context.unsliced_solution->chunk_candidate); - AddAsyncCopy( - *context.prev_allocation_in_default_mem, MemorySpace::kAlternate, - context.unsliced_solution->chunk_candidate, - context.unsliced_solution_intervals.full.start - 1, - context.request->end_time, context.prefetch_end_time, - context.request->allocation_value->mutable_allocation_sequence(), - context.request->preferred_offset, - context.unsliced_solution->prefetch_resource); + if (context.window_prefetch) { + AddAsyncCopyForWindowPrefetch( + *context.prev_allocation_in_default_mem, request.use->hlo_use, + context.unsliced_solution->chunk_candidate, + context.unsliced_solution_intervals.full.start - 1, + context.prefetch_end_time, + context.request->allocation_value->mutable_allocation_sequence(), + context.request->preferred_offset, + context.unsliced_solution->prefetch_resource, + *context.request->window_prefetch_options); + } else { + AddAsyncCopy( + *context.prev_allocation_in_default_mem, MemorySpace::kAlternate, + context.unsliced_solution->chunk_candidate, + context.unsliced_solution_intervals.full.start - 1, + context.request->end_time, context.prefetch_end_time, + context.request->allocation_value->mutable_allocation_sequence(), + context.request->preferred_offset, + context.unsliced_solution->prefetch_resource); + } request.allocation_value->allocation_sequence()->back()->AddUse( request.use->hlo_use); @@ -4929,7 +5028,9 @@ void MsaAlgorithm::SetupPrefetchWorkingIntervalsAndSliceProposal( context.sliced_solution_intervals.full; // Attempt to generate a slice proposal. - GenerateSliceProposal(context); + if (!context.window_prefetch) { + GenerateSliceProposal(context); + } // Setup the full SlicedBufferIntervals for the sliced and unsliced solutions. // If there is no slice proposal, we will not try a sliced solution. In such a diff --git a/third_party/xla/xla/service/memory_space_assignment/algorithm.h b/third_party/xla/xla/service/memory_space_assignment/algorithm.h index 5e2073bcc183ec..52d0f0ee563747 100644 --- a/third_party/xla/xla/service/memory_space_assignment/algorithm.h +++ b/third_party/xla/xla/service/memory_space_assignment/algorithm.h @@ -514,6 +514,10 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { absl::Span all_use_times; // See the comment for require_copy_allocation HloInstruction* required_copy_allocation_for; + // Data structure that contains the options for making window prefetched + // allocations. + const WindowPrefetchedAllocation::Options* window_prefetch_options = + nullptr; }; // This struct contains mandatory memory assignments at a given time. E.g., an @@ -669,6 +673,11 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // Data structures used to compute and store the unsliced solution. WorkingIntervals unsliced_solution_intervals; std::optional unsliced_solution; + + // Indicates whether the prefetch is for a windowed prefetch. A window + // prefetch only prefetches a window worth of data. Its prefetch does not + // use sliced prefetch. + bool window_prefetch = false; }; // Result of an allocation, prefetch, eviction etc. request. The result is @@ -860,7 +869,8 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { // Try prefetching to alternate memory space. Result Prefetch(const AllocationRequest& request, - Allocation& prev_allocation_in_default_mem); + Allocation& prev_allocation_in_default_mem, + const Shape* shape = nullptr); // Helper methods used to implement Prefetch(). // @@ -888,6 +898,10 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { std::string AlternateMemoryAllocationAttemptToString( bool for_sliced_solution, const PrefetchContext& context) const; + // Try to prefetch a window worth of data into the alternate memory. + Result WindowPrefetch(const AllocationRequest& request, + Allocation& prev_allocation_in_default_mem); + // Find the best possible chunk candidate, where it has the longest possible // availability if no preferred offset is given, or at the preferred_offset if // it is given. @@ -1014,6 +1028,14 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { void ImportRepackedSlicedAllocation(RepackAllocationBlock& block); absl::Status AreRepackedSlicesValid(const RepackAllocationBlock& block); + // Registers an asynchronous copy with asynchronous copy data structures to + // keep track of its state. + void RegisterAsyncCopy(MemorySpace memory_space, int64_t exclusive_start_time, + int64_t copy_done_schedule_before_time, + AllocationSequence* allocations, + AliasedOffset* aliased_offset, float resource, + std::optional cross_program_prefetch_index); + // Adds an asynchronous copy to allocations. void AddAsyncCopy( Allocation& prev_allocation, MemorySpace memory_space, @@ -1032,6 +1054,15 @@ class MsaAlgorithm : public GlobalDecreasingSizeBestFitHeap { const std::vector& slice_decisions_sorted_by_start_time, int64_t prefetch_end_time, int64_t allocation_end_time); + // For window prefetching, adds a WindowPrefetchedAllocation to allocations. + // Also updates asynchronous copy data structures, prefetch_interval_tree_, + // and aliasing data structures. + void AddAsyncCopyForWindowPrefetch( + Allocation& prev_allocation, HloUse use, const Chunk& chunk, + int64_t exclusive_start_time, int64_t inclusive_end_time, + AllocationSequence* allocations, AliasedOffset* aliased_offset, + float resource, const WindowPrefetchedAllocation::Options& options); + // This method is used for committing the chunk candidate but adding it to // pending_chunks_ so that we can "uncommit" them in case we need to roll back // this allocation sequence. diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.cc b/third_party/xla/xla/service/memory_space_assignment/allocation.cc index 50bec57e28d9c8..8699aabef99b56 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.cc +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.cc @@ -37,6 +37,8 @@ limitations under the License. #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/service/heap_simulator/allocation_block.h" #include "xla/service/heap_simulator/heap_simulator.h" #include "xla/service/hlo_value.h" @@ -854,6 +856,133 @@ bool MirroredAllocation::operator==(const Allocation& other) const { return casted_other != nullptr && (*this) == (*casted_other); } +WindowPrefetchedAllocation::WindowPrefetchedAllocation( + Allocation& prev_allocation, HloUse use, const HeapSimulator::Chunk& chunk, + int64_t prefetch_start_schedule_after_time, + int64_t prefetch_done_schedule_before_time, const Options& options) + : Allocation( + {nullptr, {}}, MemorySpace::kAlternate, chunk, + ExclusiveToInclusiveStartTime(prefetch_start_schedule_after_time), + InclusiveToExclusiveEndTime(prefetch_done_schedule_before_time), + /*is_scoped_allocation=*/false, + /*cross_program_prefetch_index=*/std::nullopt), + options_(options), + prev_allocation_(prev_allocation), + use_(use), + prefetch_start_schedule_after_(prefetch_start_schedule_after_time), + prefetch_done_schedule_before_(prefetch_done_schedule_before_time), + bytes_(chunk.size) {} + +HloPosition WindowPrefetchedAllocation::defining_position() const { + HloPosition defining_position = original_defining_position(); + if (defining_position.instruction == nullptr) { + return prev_allocation_.defining_position(); + } + return defining_position; +} + +int64_t WindowPrefetchedAllocation::earliest_available_time() const { + return prefetch_done_schedule_before_; +} + +absl::Status WindowPrefetchedAllocation::InsertWindowPrefetchInstruction( + HloInstruction* producing_instruction, HloInstruction* use_instruction, + HloComputation* computation) { + // Derive the shape for window buffer. + Shape shape = ShapeUtil::MakeShape(U8, {options_.bytes}); + Layout layout = LayoutUtil::MakeLayout({0}); + layout.set_memory_space(options_.alternate_memory_space); + *shape.mutable_layout() = layout; + + // Insert a new parameter in the fused computation. + HloComputation* fused_computation = + use_instruction->fused_instructions_computation(); + const int64_t num_parameters = fused_computation->num_parameters(); + std::string name = absl::StrCat("window-buffer.", num_parameters); + HloInstruction* param = fused_computation->AddParameter( + HloInstruction::CreateParameter(num_parameters, shape, name)); + + // Insert async WindowPrefetch instructions as operands to the fusion. + HloInstruction* prefetch = + computation->AddInstruction(HloInstruction::CreateCustomCall( + shape, {producing_instruction}, "WindowPrefetch")); + TF_ASSIGN_OR_RETURN(prefetch_instruction_, + computation->CreateAsyncInstructions(prefetch, {})); + use_instruction->AppendOperand(prefetch_instruction_); + + // Insert instruction to consume the added operands and forwards the original + // fusion output. + auto get_or_create_consumer = + [](HloComputation* computation) -> HloInstruction* { + HloInstruction* root = computation->root_instruction(); + // If the root is already a WindowPrefetchBuffer, we don't need to create + // a new one. + if (root->IsCustomCall("WindowPrefetchBuffer")) { + return root; + } + HloInstruction* new_root = + computation->AddInstruction(HloInstruction::CreateCustomCall( + root->shape(), {root}, "WindowPrefetchBuffer")); + computation->set_root_instruction(new_root); + return new_root; + }; + HloInstruction* consumer = get_or_create_consumer(fused_computation); + consumer->AppendOperand(param); + return absl::OkStatus(); +} + +absl::Status WindowPrefetchedAllocation::Process() { + HloInstruction* producing_instruction = AddGetTupleElements(); + HloComputation* computation = producing_instruction->parent(); + HloInstruction* use_instruction = use_.instruction; + CHECK_EQ(use_instruction->opcode(), HloOpcode::kFusion); + + TF_RETURN_IF_ERROR(InsertWindowPrefetchInstruction( + producing_instruction, use_instruction, computation)); + + // Notify the backend that an operand has been appended as a window prefetch + // buffer. + int64_t use_operand = use_instruction->operand_count() - 1; + options_.notify_operand_appended_fn(use_instruction, options_.uid, + use_operand); + + // Set the original defining position to the window prefetch instruction. + set_original_defining_position(HloPosition{prefetch_instruction_, {}}); + AddUse(HloUse{use_instruction, use_operand}); + return absl::OkStatus(); +} + +void WindowPrefetchedAllocation::MarkIfNeeded( + absl::flat_hash_set& needed_allocations) const { + MarkNeeded(needed_allocations); +} + +void WindowPrefetchedAllocation::MarkNeeded( + absl::flat_hash_set& needed_allocations) const { + needed_allocations.insert(this); + prev_allocation_.MarkNeeded(needed_allocations); +} + +std::string WindowPrefetchedAllocation::ToString() const { + return absl::StrCat("WindowPrefetched Allocation"); +} + +bool WindowPrefetchedAllocation::operator==( + const WindowPrefetchedAllocation& other) const { + return this->base_is_equal(static_cast(other)) && + prefetch_done_schedule_before() == + other.prefetch_done_schedule_before() && + prefetch_start_schedule_after() == + other.prefetch_start_schedule_after() && + prefetch() == other.prefetch() && bytes_ == other.bytes_; +} + +bool WindowPrefetchedAllocation::operator==(const Allocation& other) const { + const WindowPrefetchedAllocation* casted_other = + dynamic_cast(&other); + return casted_other != nullptr && (*this) == (*casted_other); +} + std::tuple GetAllocationSortTuple( const std::unique_ptr& allocation) { int64_t scheduled_on_or_before = allocation->start_time(); diff --git a/third_party/xla/xla/service/memory_space_assignment/allocation.h b/third_party/xla/xla/service/memory_space_assignment/allocation.h index d0a4d72642aba2..bb3b324e07700c 100644 --- a/third_party/xla/xla/service/memory_space_assignment/allocation.h +++ b/third_party/xla/xla/service/memory_space_assignment/allocation.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -130,6 +131,7 @@ class Allocation { virtual bool is_pinned_allocation() const = 0; virtual bool is_copy_allocation() const = 0; virtual bool is_sliced_copy_allocation() const = 0; + virtual bool is_window_prefetched_allocation() const = 0; // True if the allocation is for a copy or a sliced-copy. bool is_copy_like_allocation() const; @@ -211,6 +213,7 @@ class PinnedAllocation final : public Allocation { bool is_pinned_allocation() const override { return true; } bool is_copy_allocation() const override { return false; } bool is_sliced_copy_allocation() const override { return false; } + bool is_window_prefetched_allocation() const override { return false; } absl::Status Process() override; absl::Status PostProcess() override { return absl::OkStatus(); } void MarkIfNeeded(absl::flat_hash_set& needed_allocations) @@ -249,6 +252,7 @@ class CopyAllocation final : public Allocation { bool is_pinned_allocation() const override { return false; } bool is_copy_allocation() const override { return true; } bool is_sliced_copy_allocation() const override { return false; } + bool is_window_prefetched_allocation() const override { return false; } absl::Status Process() override; absl::Status PostProcess() override { return absl::OkStatus(); } void MarkIfNeeded(absl::flat_hash_set& needed_allocations) @@ -350,6 +354,7 @@ class SlicedCopyAllocation final : public Allocation { bool is_pinned_allocation() const override { return false; } bool is_copy_allocation() const override { return false; } bool is_sliced_copy_allocation() const override { return true; } + bool is_window_prefetched_allocation() const override { return false; } // MemorySpaceAssignment::Process() calls Process() to create asynchronous // slice copies, and a bitcast-concat call to glue the slices back together. absl::Status Process() override; @@ -393,6 +398,75 @@ class SlicedCopyAllocation final : public Allocation { absl::FunctionRef get_equivalent_s8_shape_fn_; }; +// This class represents an allocation resulting from asynchronously prefetching +// a window buffer. When a tensor is placed in the default memory, we can +// prefetch the window buffer of the tensor to the alternate memory space. This +// is called window prefetching. +class WindowPrefetchedAllocation final : public Allocation { + public: + struct Options { + int64_t bytes = 0; + int64_t uid = 0; + int64_t alternate_memory_space = 0; + std::function + notify_operand_appended_fn = + [](const HloInstruction*, int64_t, int64_t) {}; + }; + + WindowPrefetchedAllocation(Allocation& prev_allocation, HloUse use, + const HeapSimulator::Chunk& chunk, + int64_t prefetch_start_schedule_after_time, + int64_t prefetch_done_schedule_before_time, + const Options& options); + + // Overridden methods + // + HloPosition defining_position() const override; + int64_t earliest_available_time() const override; + bool is_pinned_allocation() const override { return false; } + bool is_copy_allocation() const override { return false; } + bool is_sliced_copy_allocation() const override { return false; } + bool is_window_prefetched_allocation() const override { return true; } + // MemorySpaceAssignment::Process() calls Process() to create asynchronous + // window prefetches. + absl::Status Process() override; + absl::Status PostProcess() override { return absl::OkStatus(); } + // Marks the allocation as needed. + void MarkIfNeeded(absl::flat_hash_set& needed_allocations) + const override; + void MarkNeeded(absl::flat_hash_set& needed_allocations) + const override; + std::string ToString() const override; + bool operator==(const WindowPrefetchedAllocation& other) const; + bool operator==(const Allocation& other) const override; + int64_t bytes() const { return bytes_; } + int64_t prefetch_start_schedule_after() const { + return prefetch_start_schedule_after_; + } + int64_t prefetch_done_schedule_before() const { + return prefetch_done_schedule_before_; + } + HloInstruction* prefetch() const { return prefetch_instruction_; } + + private: + // This method is called by Process() to create window prefetch instructions. + // These instructions include a pair of async WindowPrefetch outside the + // fusion and a WindowPrefetchBuffer inside the fusion. The + // WindowPrefetchBuffer is used for consuming the appended window buffer + // operands. + absl::Status InsertWindowPrefetchInstruction( + HloInstruction* producing_instruction, HloInstruction* use_instruction, + HloComputation* computation); + + Options options_; + HloInstruction* prefetch_instruction_ = nullptr; + Allocation& prev_allocation_; + HloUse use_; + int64_t prefetch_start_schedule_after_; + int64_t prefetch_done_schedule_before_; + int64_t bytes_; +}; + // An allocation in the default memory space that mirrors another Allocation // object. This is useful to model an eviction that happens before a while op // so that we don't need to redundantly evict the buffer after the while op as @@ -409,6 +483,7 @@ class MirroredAllocation final : public Allocation { bool is_pinned_allocation() const override { return false; } bool is_copy_allocation() const override { return false; } bool is_sliced_copy_allocation() const override { return false; } + bool is_window_prefetched_allocation() const override { return false; } absl::Status Process() override; absl::Status PostProcess() override { return absl::OkStatus(); } void MarkIfNeeded(absl::flat_hash_set& needed_allocations) @@ -442,6 +517,7 @@ class ParentAllocation final : public Allocation { bool is_pinned_allocation() const override { return false; } bool is_copy_allocation() const override { return false; } bool is_sliced_copy_allocation() const override { return false; } + bool is_window_prefetched_allocation() const override { return false; } absl::Status Process() override; absl::Status PostProcess() override; void MarkIfNeeded(absl::flat_hash_set& needed_allocations) diff --git a/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h index 364027c79e760a..72229fcab2d273 100644 --- a/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h +++ b/third_party/xla/xla/service/memory_space_assignment/cost_analysis.h @@ -89,7 +89,7 @@ class BaseCosts { // The bandwidth of copies to/from alternate memory. virtual float BytesPerSecond() = 0; - // The compute cost of instruction. The compute cost assumes 0 memory transer + // The compute cost of instruction. The compute cost assumes 0 memory transfer // is required. virtual float ComputeSeconds(const HloInstruction& instruction) = 0; diff --git a/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc b/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc index e4d93dd8c8f61f..39d4dbbded7bd2 100644 --- a/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/cost_analysis_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc index b0487cab8fdbb9..1dff5221026f82 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_bound_loop_optimizer_test.cc @@ -55,9 +55,9 @@ limitations under the License. #include "xla/status_macros.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto index 77faa69e74c8c8..e15d564dac8f35 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment.proto @@ -46,6 +46,21 @@ message SlicedPrefetchOptions { uint64 preferred_slice_size = 5; } +// Memory space assignment options for prefetching windows of data +message WindowPrefetchDetail { + message WindowDetail { + // Index of the operand that is window prefetched. + int64 operand = 1; + // Window buffer size in bytes. + int64 size = 2; + // Unique identifier to distinguish the buffers that are associated with the + // same operand. + int64 uid = 3; + } + + repeated WindowDetail windows = 1; +} + // Options for memory-bound loop optimizations in memory space assignment. If // enabled, this pass can optimize memory-bound unrolled loops to maximize the // bandwidth utilized and minimize the execution time. diff --git a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc index 7547901d1aa4a8..40f5e1d9f94031 100644 --- a/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/memory_space_assignment_test.cc @@ -78,9 +78,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/platform/status.h" @@ -8282,6 +8282,65 @@ TEST_F(MemorySpaceAssignmentTest, HoistCopyStart) { } } +TEST_F(MemorySpaceAssignmentTest, WindowPrefetch) { + absl::string_view hlo_string = R"( +HloModule module, is_scheduled=true + +%fused_computation { + %p0 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(0) + %p1 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(1) + %p2 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(2) + %add0 = bf16[64,8]{1,0:T(8,128)(2,1)} add(%p0, %p1) + ROOT %add1 = bf16[64,8]{1,0:T(8,128)(2,1)} add(%add0, %p2) +} + +entry { + %p0 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(0) + %p1 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(1) + %p2 = bf16[64,8]{1,0:T(8,128)(2,1)} parameter(2) + ROOT fusion = bf16[64,8]{1,0:T(8,128)(2,1)} fusion(bf16[64,8]{1,0:T(8,128)(2,1)} %p0, bf16[64,8]{1,0:T(8,128)(2,1)} %p1, bf16[64,8]{1,0:T(8,128)(2,1)} %p2), kind=kLoop, calls=%fused_computation +} + +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + // Get info about window prefetch buffers, such as which operands they + // correspond to and their sizes. + auto window_prefetch_detail_fn = [&](const HloInstruction* instruction) { + WindowPrefetchDetail window_prefetch_detail; + const HloInstruction* fusion = FindInstruction(module.get(), "fusion"); + if (instruction == fusion) { + for (int i = 0; i < 3; ++i) { + auto* operand = window_prefetch_detail.add_windows(); + operand->set_operand(i); + operand->set_size(32); + } + } + return window_prefetch_detail; + }; + + Options options = DefaultMemorySpaceOptions(); + options.enable_window_prefetch = true; + options.window_prefetch_detail_fn = window_prefetch_detail_fn; + AssignMemorySpace(module.get(), options, /*max_prefetch_interval=*/10, + /*min_prefetch_interval=*/0); + const HloInstruction* fusion = FindInstruction(module.get(), "fusion"); + // The fusion instruction should have 5 operands: the 3 original operands + // plus 2 window prefetch buffers. + EXPECT_EQ(fusion->operand_count(), 5); + + // The root of the fusion should be a WindowPrefetchBuffer. The first operand + // should be the original root, and the second and third operands should be + // the window prefetch buffers. + HloInstruction* root = fusion->fused_expression_root(); + EXPECT_TRUE(root->IsCustomCall("WindowPrefetchBuffer")); + EXPECT_EQ(root->operand_count(), 3); + EXPECT_EQ(root->operand(1), fusion->fused_parameter(3)); + EXPECT_EQ(root->operand(2), fusion->fused_parameter(4)); + VLOG(2) << "module: " << module->ToString(); +} + using AsynchronousCopyOrderingTest = ::testing::Test; TEST_F(AsynchronousCopyOrderingTest, Simple) { diff --git a/third_party/xla/xla/service/memory_space_assignment/options.h b/third_party/xla/xla/service/memory_space_assignment/options.h index 3a1d8488118afb..fb9730ced90641 100644 --- a/third_party/xla/xla/service/memory_space_assignment/options.h +++ b/third_party/xla/xla/service/memory_space_assignment/options.h @@ -57,6 +57,10 @@ using ReservedScopedMemoryFunction = std::function& /*outputs_in_alternate_memory*/)>; using PositionRequiresContiguousAllocationFunction = std::function; +using WindowPrefetchDetailFunction = + std::function; +using WindowPrefetchNotifyOperandAppendedFunction = + std::function; // The different options to be passed to the Run() API. struct Options { @@ -111,6 +115,15 @@ struct Options { position_requires_contiguous_allocation_fn = [](const HloPosition&) { return false; }; + // This function is called to get details about window prefetches. + WindowPrefetchDetailFunction window_prefetch_detail_fn = + [](const HloInstruction*) { return WindowPrefetchDetail(); }; + + // This function is called to notify that an operand has been appended as a + // window prefetch buffer. + WindowPrefetchNotifyOperandAppendedFunction notify_operand_appended_fn = + [](HloInstruction*, int64_t, int64_t) {}; + // If true, we will try to reduce scoped allocation buffer size for all // instructions if their operand/output has been allocated in alternate // memory. @@ -234,6 +247,13 @@ struct Options { // Option to always spill buffers from alternate memory to default memory // and prefetching back to alternate memory(if needed) just in time for use. bool always_spill_to_default_memory = false; + + // If true, enables window prefetching. Window prefetching is a mechanism + // where we prefetch windows of data into the alternate memory before the + // first use of the buffer. This allows large tensors to be prefetched as well + // and gives MSA more flexibility in choosing the prefetch time and how much + // data to prefetch. + bool enable_window_prefetch = false; }; } // namespace memory_space_assignment } // namespace xla diff --git a/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc b/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc index f3433ce7b569de..3b61a70f9309f5 100644 --- a/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc +++ b/third_party/xla/xla/service/memory_space_assignment/simulator_test.cc @@ -39,7 +39,7 @@ limitations under the License. #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/memory_space_propagation_test.cc b/third_party/xla/xla/service/memory_space_propagation_test.cc index 940a4ebbcc400e..98ae47c8b164f2 100644 --- a/third_party/xla/xla/service/memory_space_propagation_test.cc +++ b/third_party/xla/xla/service/memory_space_propagation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/name_uniquer.cc b/third_party/xla/xla/service/name_uniquer.cc index 6fb7351251b57a..124cd6f427e119 100644 --- a/third_party/xla/xla/service/name_uniquer.cc +++ b/third_party/xla/xla/service/name_uniquer.cc @@ -83,8 +83,8 @@ std::string NameUniquer::GetUniqueName(absl::string_view prefix) { int64_t numeric_suffix = 0; size_t separator_index = root.rfind(separator_); if (separator_index != std::string::npos && (separator_index > 0) && - (separator_index < root.size() - 1)) { - std::string after_suffix = root.substr(separator_index + 1); + (separator_index < root.size() - separator_.size())) { + std::string after_suffix = root.substr(separator_index + separator_.size()); if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) { has_numeric_suffix = true; // Remove numeric suffix from root. diff --git a/third_party/xla/xla/service/name_uniquer_test.cc b/third_party/xla/xla/service/name_uniquer_test.cc index 6ebdfffedb73d0..64e02229d1a871 100644 --- a/third_party/xla/xla/service/name_uniquer_test.cc +++ b/third_party/xla/xla/service/name_uniquer_test.cc @@ -14,17 +14,12 @@ limitations under the License. ==============================================================================*/ #include "xla/service/name_uniquer.h" - -#include -#include -#include - #include "tsl/platform/test.h" namespace xla { namespace { -class NameUniquerTest : public ::testing::Test {}; +using NameUniquerTest = ::testing::Test; TEST_F(NameUniquerTest, SimpleUniquer) { NameUniquer uniquer; @@ -126,5 +121,13 @@ TEST_F(NameUniquerTest, AvoidKeywords) { EXPECT_EQ("Pred", uniquer.GetUniqueName("Pred")); } +TEST_F(NameUniquerTest, DetectSeparator) { + NameUniquer uniquer; + + EXPECT_EQ(uniquer.GetUniqueName("a__1"), "a__1"); + EXPECT_EQ(uniquer.GetUniqueName("a"), "a"); + EXPECT_EQ(uniquer.GetUniqueName("a"), "a__2"); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/pattern_matcher.h b/third_party/xla/xla/service/pattern_matcher.h index b17c53a9baf699..76979f097ef1f9 100644 --- a/third_party/xla/xla/service/pattern_matcher.h +++ b/third_party/xla/xla/service/pattern_matcher.h @@ -16,34 +16,43 @@ limitations under the License. #ifndef XLA_SERVICE_PATTERN_MATCHER_H_ #define XLA_SERVICE_PATTERN_MATCHER_H_ -#include +#include +#include #include #include #include #include #include #include +#include #include #include #include #include "absl/algorithm/container.h" #include "absl/container/inlined_vector.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "absl/utility/utility.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/ptrvec.h" +#include "xla/layout.h" #include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/service/hlo_parser.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/util.h" #include "xla/xla_data.pb.h" namespace xla { @@ -2673,6 +2682,7 @@ XLA_UNOP_PATTERN(RoundNearestAfz) XLA_UNOP_PATTERN(Bitcast) XLA_UNOP_PATTERN(BitcastConvert) XLA_UNOP_PATTERN(Broadcast) +XLA_UNOP_PATTERN(Cbrt) XLA_UNOP_PATTERN(Ceil) XLA_UNOP_PATTERN(Convert) XLA_UNOP_PATTERN(Copy) @@ -2686,6 +2696,7 @@ XLA_UNOP_PATTERN(CollectivePermute) XLA_UNOP_PATTERN(CollectivePermuteStart) XLA_UNOP_PATTERN(CollectivePermuteDone) XLA_UNOP_PATTERN(Domain) +XLA_UNOP_PATTERN(Erf) XLA_UNOP_PATTERN(Exp) XLA_UNOP_PATTERN(Expm1) XLA_UNOP_PATTERN(Fft) @@ -2695,6 +2706,7 @@ XLA_UNOP_PATTERN(Imag) XLA_UNOP_PATTERN(Infeed) XLA_UNOP_PATTERN(IsFinite) XLA_UNOP_PATTERN(Log) +XLA_UNOP_PATTERN(Logistic) XLA_UNOP_PATTERN(Not) XLA_UNOP_PATTERN(Negate) XLA_UNOP_PATTERN(OptimizationBarrier) diff --git a/third_party/xla/xla/service/pattern_matcher_gmock.h b/third_party/xla/xla/service/pattern_matcher_gmock.h index e183211d645d50..eeb7b1caabb4e1 100644 --- a/third_party/xla/xla/service/pattern_matcher_gmock.h +++ b/third_party/xla/xla/service/pattern_matcher_gmock.h @@ -18,7 +18,10 @@ limitations under the License. #include +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/layout.h" #include "xla/service/pattern_matcher.h" +#include "xla/shape.h" #include "xla/test.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/pattern_matcher_gmock_test.cc b/third_party/xla/xla/service/pattern_matcher_gmock_test.cc index 81cff291024fe8..c0a279537f686d 100644 --- a/third_party/xla/xla/service/pattern_matcher_gmock_test.cc +++ b/third_party/xla/xla/service/pattern_matcher_gmock_test.cc @@ -15,7 +15,15 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" +#include +#include +#include + +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/service/pattern_matcher.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/test.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/pattern_matcher_test.cc b/third_party/xla/xla/service/pattern_matcher_test.cc index cd020c821b0c00..73da06ae7c1eea 100644 --- a/third_party/xla/xla/service/pattern_matcher_test.cc +++ b/third_party/xla/xla/service/pattern_matcher_test.cc @@ -15,14 +15,25 @@ limitations under the License. #include "xla/service/pattern_matcher.h" +#include +#include #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal_util.h" #include "xla/service/hlo_parser.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/util.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc b/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc index ff2d766b8b07c2..e795c475792c76 100644 --- a/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc +++ b/third_party/xla/xla/service/profile_guided_latency_estimator_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_schedule.h" #include "xla/service/latency_hiding_scheduler.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/statusor.h" #include "tsl/profiler/protobuf/profiled_instructions.pb.h" diff --git a/third_party/xla/xla/service/real_imag_expander_test.cc b/third_party/xla/xla/service/real_imag_expander_test.cc index a7349a64011d62..429042745427f0 100644 --- a/third_party/xla/xla/service/real_imag_expander_test.cc +++ b/third_party/xla/xla/service/real_imag_expander_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/reshape_mover_test.cc b/third_party/xla/xla/service/reshape_mover_test.cc index 8c1bce4d0103f7..5ad138e1a94302 100644 --- a/third_party/xla/xla/service/reshape_mover_test.cc +++ b/third_party/xla/xla/service/reshape_mover_test.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/scatter_expander_test.cc b/third_party/xla/xla/service/scatter_expander_test.cc index a74eabf4080ef7..4d135d3bb26dad 100644 --- a/third_party/xla/xla/service/scatter_expander_test.cc +++ b/third_party/xla/xla/service/scatter_expander_test.cc @@ -26,8 +26,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index 53ff55705d529c..4271cc897f41d7 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -3794,6 +3794,7 @@ ShapeInference::InferCollectivePermuteDoneShape(const Shape& operand_shape) { static absl::Status ValidateGatherDimensionNumbers( const Shape& input_shape, absl::Span start_indices_shape, const GatherDimensionNumbers& dim_numbers) { + // Validate offset_dims in GatherDimensionNumbers. if (!absl::c_is_sorted(dim_numbers.offset_dims())) { return InvalidArgument( "Output window dimensions in gather op must be ascending; got: %s.", @@ -3834,6 +3835,7 @@ static absl::Status ValidateGatherDimensionNumbers( start_indices_shape[dim_numbers.index_vector_dim()]); } + // Validate start_index_map in GatherDimensionNumbers. for (int i = 0; i < dim_numbers.start_index_map_size(); i++) { int64_t operand_dim_for_start_index_i = dim_numbers.start_index_map(i); if (operand_dim_for_start_index_i < 0 || @@ -3858,6 +3860,7 @@ static absl::Status ValidateGatherDimensionNumbers( StrJoin(dim_numbers.start_index_map(), ", ")); } + // Validate collapsed_slice_dims in GatherDimensionNumbers. for (int64_t collapsed_dim : dim_numbers.collapsed_slice_dims()) { if (collapsed_dim < 0 || collapsed_dim >= input_shape.dimensions_size()) { return InvalidArgument( @@ -3881,6 +3884,69 @@ static absl::Status ValidateGatherDimensionNumbers( StrJoin(dim_numbers.collapsed_slice_dims(), ", ")); } + // Validate operand_batching_dims and start_indices_batching_dims are of the + // same size. + if (dim_numbers.operand_batching_dims_size() != + dim_numbers.start_indices_batching_dims_size()) { + return InvalidArgument( + "operand_batching_dims and start_indices_batching_dims in gather op " + "must be of the same size; got: %d and %d.", + dim_numbers.operand_batching_dims_size(), + dim_numbers.start_indices_batching_dims_size()); + } + + // Validate operand_batching_dims in GatherDimensionNumbers. + for (int64_t operand_batching_dim : dim_numbers.operand_batching_dims()) { + if (operand_batching_dim < 0 || + operand_batching_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid operand_batching_dims set in gather op; valid range is [0, " + "%d), got: %d.", + input_shape.dimensions_size(), operand_batching_dim); + } + } + + if (!absl::c_is_sorted(dim_numbers.operand_batching_dims())) { + return InvalidArgument( + "operand_batching_dims in gather op must be sorted; got: %s", + StrJoin(dim_numbers.operand_batching_dims(), ", ")); + } + + if (absl::c_adjacent_find(dim_numbers.operand_batching_dims()) != + dim_numbers.operand_batching_dims().end()) { + return InvalidArgument( + "Repeated dimensions not allowed in operand_batching_dims in gather " + "op; " + "got: %s.", + StrJoin(dim_numbers.operand_batching_dims(), ", ")); + } + + // Validate start_indices_batching_dims in GatherDimensionNumbers. + for (int i = 0; i < dim_numbers.start_indices_batching_dims_size(); i++) { + int64_t start_indices_batching_dim_i = + dim_numbers.start_indices_batching_dims(i); + if (start_indices_batching_dim_i < 0 || + start_indices_batching_dim_i >= start_indices_shape.size()) { + return InvalidArgument( + "Invalid start_indices_batching_dims; domain is [0, %d), got: " + "%d->%d.", + start_indices_shape.size(), i, start_indices_batching_dim_i); + } + } + + std::vector sorted_start_indices_batching_dims( + dim_numbers.start_indices_batching_dims().begin(), + dim_numbers.start_indices_batching_dims().end()); + + absl::c_sort(sorted_start_indices_batching_dims); + + if (absl::c_adjacent_find(sorted_start_indices_batching_dims) != + sorted_start_indices_batching_dims.end()) { + return InvalidArgument( + "Repeated dimensions are not allowed in start_indices_batching_dims; " + "got: %s.", + StrJoin(dim_numbers.start_indices_batching_dims(), ", ")); + } return absl::OkStatus(); } @@ -3943,13 +4009,16 @@ static absl::Status ValidateGatherDimensionNumbers( if (slice_sizes.size() != gather_dim_numbers.offset_dims_size() + - gather_dim_numbers.collapsed_slice_dims_size()) { + gather_dim_numbers.collapsed_slice_dims_size() + + gather_dim_numbers.operand_batching_dims_size()) { return InvalidArgument( "All components of the offset index in a gather op must either be a " - "offset dimension or explicitly collapsed; got len(slice_sizes)=%lu, " - "output_slice_sizes=%s, collapsed_slice_dims=%s.", + "offset dimension or explicitly collapsed or explicitly batched; got " + "len(slice_sizes)=%lu, output_slice_sizes=%s, collapsed_slice_dims=%s, " + "operand_batching_dims=%s.", slice_sizes.size(), StrJoin(gather_dim_numbers.offset_dims(), ","), - StrJoin(gather_dim_numbers.collapsed_slice_dims(), ",")); + StrJoin(gather_dim_numbers.collapsed_slice_dims(), ","), + StrJoin(gather_dim_numbers.operand_batching_dims(), ",")); } for (int i = 0; i < slice_sizes.size(); i++) { @@ -3974,6 +4043,16 @@ static absl::Status ValidateGatherDimensionNumbers( } } + for (int i = 0; i < gather_dim_numbers.operand_batching_dims_size(); i++) { + if (slice_sizes[gather_dim_numbers.operand_batching_dims(i)] > 1) { + return InvalidArgument( + "Gather op can only have operand_batching_dims with bound 1 or 0, " + "but bound is %d for index %d at position %d.", + slice_sizes[gather_dim_numbers.operand_batching_dims(i)], + gather_dim_numbers.operand_batching_dims(i), i); + } + } + int64_t result_rank = gather_dim_numbers.offset_dims_size() + (expanded_start_indices_shape.size() - 1); int64_t offset_dims_seen = 0; @@ -3990,6 +4069,8 @@ static absl::Status ValidateGatherDimensionNumbers( absl::c_binary_search(gather_dim_numbers.offset_dims(), i); if (is_window_index) { while (absl::c_binary_search(gather_dim_numbers.collapsed_slice_dims(), + offset_dims_seen) || + absl::c_binary_search(gather_dim_numbers.operand_batching_dims(), offset_dims_seen)) { offset_dims_seen++; } @@ -4075,7 +4156,8 @@ absl::Status ValidateScatterDimensionNumbers( // Validate window size. auto window_size = dim_numbers.update_window_dims_size() + - dim_numbers.inserted_window_dims_size(); + dim_numbers.inserted_window_dims_size() + + dim_numbers.input_batching_dims_size(); if (window_size != operand_shape.rank()) { return InvalidArgument( "Scatter op has window of size %d; doesn't match operand of rank %d.", @@ -4117,6 +4199,61 @@ absl::Status ValidateScatterDimensionNumbers( StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ", ")); } + // Validate input_batching_dims and scatter_indices_batching_dims in + // ScatterDimensionNumbers. + if (dim_numbers.input_batching_dims_size() != + dim_numbers.scatter_indices_batching_dims_size()) { + return InvalidArgument( + "input_batching_dims and scatter_indices_batching_dims in scatter op " + "must be of the same size; got: %d and %d.", + dim_numbers.input_batching_dims_size(), + dim_numbers.scatter_indices_batching_dims_size()); + } + + // Validate input_batching_dims in ScatterDimensionNumbers. + if (!absl::c_is_sorted(dim_numbers.input_batching_dims())) { + return InvalidArgument( + "input_batching_dims in scatter op must be sorted; got: %s.", + StrJoin(dim_numbers.input_batching_dims(), ", ")); + } + if (absl::c_adjacent_find(dim_numbers.input_batching_dims()) != + dim_numbers.input_batching_dims().end()) { + return InvalidArgument( + "input_batching_dims in scatter op must not repeat; got: %s.", + StrJoin(dim_numbers.input_batching_dims(), ", ")); + } + for (int64_t input_batching_dim : dim_numbers.input_batching_dims()) { + if (input_batching_dim < 0 || + input_batching_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid input_batching_dims set in scatter op; valid range is [0, " + "%d), got: %d.", + operand_shape.dimensions_size(), input_batching_dim); + } + } + + // Validate scatter_indices_batching_dims in ScatterDimensionNumbers. + for (int64_t scatter_indices_batching_dim : + dim_numbers.scatter_indices_batching_dims()) { + if (scatter_indices_batching_dim < 0 || + scatter_indices_batching_dim >= scatter_indices_shape.size()) { + return InvalidArgument( + "Invalid scatter_indices_batching_dims set in scatter op; valid " + "range is [0, %d), got: %d.", + scatter_indices_shape.size(), scatter_indices_batching_dim); + } + } + std::vector sorted_scatter_indices_batching_dims( + dim_numbers.scatter_indices_batching_dims().begin(), + dim_numbers.scatter_indices_batching_dims().end()); + absl::c_sort(sorted_scatter_indices_batching_dims); + if (absl::c_adjacent_find(sorted_scatter_indices_batching_dims) != + sorted_scatter_indices_batching_dims.end()) { + return InvalidArgument( + "scatter_indices_batching_dims in scatter op must not repeat; got: %s.", + StrJoin(dim_numbers.scatter_indices_batching_dims(), ", ")); + } + return absl::OkStatus(); } @@ -4169,7 +4306,7 @@ absl::Status ValidateScatterDimensionNumbers( TF_RETURN_IF_ERROR(ExpectArray( updates_shape, absl::StrCat("updates ", operand_i, " of scatter op"))); - int64_t inserted_dims_seen = 0; + int64_t inserted_dims_seen = 0, input_batching_dims_seen = 0; std::vector max_update_slice_sizes; const auto dimensions_size = operand_shape.dimensions_size(); max_update_slice_sizes.reserve(dimensions_size); @@ -4178,6 +4315,11 @@ absl::Status ValidateScatterDimensionNumbers( scatter_dim_numbers.inserted_window_dims_size() && scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { ++inserted_dims_seen; + } else if (input_batching_dims_seen < + scatter_dim_numbers.input_batching_dims_size() && + scatter_dim_numbers.input_batching_dims( + input_batching_dims_seen) == i) { + ++input_batching_dims_seen; } else { max_update_slice_sizes.push_back(operand_shape.dimensions(i)); } diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 14c3e804563815..29ae32add358e3 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -2870,6 +2870,24 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { << ShapeUtil::HumanString(gather_shape); } +TEST_F(GatherShapeInferenceTest, TensorFlowGatherBatchingDims) { + TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, + ShapeInference::InferGatherShape( + ShapeUtil::MakeShape(F32, {100, 64, 5, 48}), + ShapeUtil::MakeShape(S64, {5, 100, 32}), + HloGatherInstruction::MakeGatherDimNumbers( + /*offset_dims=*/{3}, + /*collapsed_slice_dims=*/{1}, + /*start_index_map=*/{1}, + /*index_vector_dim=*/3, + /*operand_batching_dims=*/{0, 2}, + /*start_indices_batching_dims=*/{1, 0}), + /*slice_sizes=*/{1, 1, 1, 8})); + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {5, 100, 32, 8}))) + << ShapeUtil::HumanString(gather_shape); +} + TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { TF_ASSERT_OK_AND_ASSIGN(const Shape gather_shape, ShapeInference::InferGatherShape( @@ -3481,6 +3499,27 @@ TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { << statusor.status(); } +TEST_P(ScatterShapeInferenceTest, + TfScatterBatchingDimsWithUpdatesBiggerThanInput) { + const auto shapes = CreateShapes({100, 64, 48}, s64_tensor({100, 32}), + {100, 65, 32}, types()); + const absl::StatusOr statusor = ShapeInference::InferScatterShape( + shapes.ptrs, to_apply(types()), + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{2}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/2, + /*input_batching_dims=*/{0}, + /*scatter_indices_batching_dims=*/{0})); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + TEST_P(ScatterShapeInferenceTest, TfScatterWithUpdatesNotMatchingIndices) { const auto shapes = CreateShapes({64, 48}, s64_vector(32), {64, 31}, types()); const absl::StatusOr statusor = ShapeInference::InferScatterShape( diff --git a/third_party/xla/xla/service/sharding_propagation.cc b/third_party/xla/xla/service/sharding_propagation.cc index 0a8e3cf14a42f4..5239d6c7d30575 100644 --- a/third_party/xla/xla/service/sharding_propagation.cc +++ b/third_party/xla/xla/service/sharding_propagation.cc @@ -358,7 +358,7 @@ bool SupportSpatialPartitioning( computation_map.find(instruction->parent()) == computation_map.end() && !(is_entry_root && allow_spmd_sharding_propagation_to_output)) { // We don't support sharding the root instruction of a computation yet, - // unless the computation is a while body. + // unless the computation is in computation_map. return false; } @@ -2037,8 +2037,7 @@ bool InferDynamicUpdateSliceShardingFromOperand0( } bool ShardingPropagation::InferShardingFromShardGroup( - HloInstruction* instruction, const ComputationMap& computation_map, - int64_t aggressiveness, + HloInstruction* instruction, int64_t aggressiveness, const absl::flat_hash_set& shard_group) { if (!CanPropagateThroughAtAggressiveLevel(*instruction, aggressiveness)) { return false; @@ -2887,14 +2886,23 @@ absl::StatusOr ShardingPropagation::Run( return std::vector{inst, callee->root_instruction()}; } else if (inst->opcode() == HloOpcode::kParameter) { auto it = computation_map.find(inst->parent()); - if (it != computation_map.end() && - it->second->opcode() == HloOpcode::kConditional) { - HloInstruction* cond = it->second; - for (int64_t i = 1; i < cond->operand_count(); ++i) { - if (cond->called_computations()[i - 1] == inst->parent()) { - return std::vector{inst, cond->mutable_operand(i)}; + if (it != computation_map.end()) { + if (it->second->opcode() == HloOpcode::kConditional) { + HloInstruction* cond = it->second; + for (int64_t i = 1; i < cond->operand_count(); ++i) { + if (cond->called_computations()[i - 1] == inst->parent()) { + return std::vector{inst, + cond->mutable_operand(i)}; + } } } + if (it->second->opcode() == HloOpcode::kCall) { + HloInstruction* call = it->second; + int64_t operand_index = inst->parameter_number(); + CHECK_LT(operand_index, call->operand_count()); + return std::vector{ + inst, call->mutable_operand(operand_index)}; + } } return std::vector{}; } else { @@ -2937,9 +2945,11 @@ absl::StatusOr ShardingPropagation::Run( auto it = computation_map.find(instruction->parent()); if (it != computation_map.end()) { propagate_to_instruction(it->second); - // Propagate parameter shardings back to conditional's operands. + // Propagate parameter shardings back to conditional's and + // call's operands. if (instruction->opcode() == HloOpcode::kParameter && - it->second->opcode() == HloOpcode::kConditional) { + (it->second->opcode() == HloOpcode::kConditional || + it->second->opcode() == HloOpcode::kCall)) { propagate_to_instruction(instruction); } } @@ -2955,8 +2965,8 @@ absl::StatusOr ShardingPropagation::Run( } } - // Populate computation_map in order to associate while bodies to their - // while instructions. + // Populate computation_map in order to associate while bodies and conditions + // to their while instructions. for (auto computation : module->computations(execution_threads)) { for (auto instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile || @@ -2983,6 +2993,7 @@ absl::StatusOr ShardingPropagation::Run( } if (instruction->opcode() == HloOpcode::kWhile) { computation_map[instruction->while_body()] = instruction; + computation_map[instruction->while_condition()] = instruction; } else { for (HloComputation* c : instruction->called_computations()) { computation_map[c] = instruction; @@ -3134,8 +3145,8 @@ absl::StatusOr ShardingPropagation::Run( continue; } already_inferred_from_shard_group.insert(instruction); - if (InferShardingFromShardGroup(instruction, computation_map, - aggressiveness, shard_group)) { + if (InferShardingFromShardGroup(instruction, aggressiveness, + shard_group)) { ++inferred_from_shard_group_counter; any_changed = true; VLOG(2) << "Add sharding (shard group): " diff --git a/third_party/xla/xla/service/sharding_propagation.h b/third_party/xla/xla/service/sharding_propagation.h index 66be9e7e501e32..22cb7af042545d 100644 --- a/third_party/xla/xla/service/sharding_propagation.h +++ b/third_party/xla/xla/service/sharding_propagation.h @@ -140,8 +140,7 @@ class ShardingPropagation : public HloModulePass { private: bool InferShardingFromShardGroup( - HloInstruction* instruction, const ComputationMap& computation_map, - int64_t aggressiveness, + HloInstruction* instruction, int64_t aggressiveness, const absl::flat_hash_set& shard_group); bool InferShardingFromOperands( HloInstruction* instruction, const ComputationMap& computation_map, diff --git a/third_party/xla/xla/service/sharding_propagation_test.cc b/third_party/xla/xla/service/sharding_propagation_test.cc index ac04389c805878..072f43644ccd83 100644 --- a/third_party/xla/xla/service/sharding_propagation_test.cc +++ b/third_party/xla/xla/service/sharding_propagation_test.cc @@ -2757,6 +2757,60 @@ ENTRY %entry { } } +TEST_F(ShardingPropagationTest, PropagateShardingInWhileCondition) { + const char* const hlo_string = R"( +HloModule module + +%cond { + %vars.cond = (u32[], f32[]) parameter(0) + %count.cond = u32[] get-tuple-element(%vars.cond), index=0 + %limit = u32[] constant(10) + ROOT %lt = pred[] compare(%count.cond, %limit), direction=LT +} + +%body { + %vars = (u32[], f32[]) parameter(0) + %count = u32[] get-tuple-element(%vars), index=0 + %acc = f32[] get-tuple-element(%vars), index=1 + + %one = u32[] constant(1) + %count.1 = u32[] add(u32[] %count, u32[] %one) + %acc.1 = f32[] add(f32[] %acc, f32[] %acc) + ROOT %tuple = (u32[], f32[]) tuple(%count.1, %acc.1) +} + +ENTRY %entry { + %p0 = f32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} + %zero = u32[] constant(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} + %init = (u32[], f32[]) tuple(%zero, %p0) + ROOT %while = (u32[], f32[]) while(%init), body=%body, condition=%cond +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation(/*is_spmd=*/false, /*propagate_metadata=*/false, + /*allow_spmd_sharding_propagation_to_output=*/{true}) + .Run(module.get())); + EXPECT_TRUE(changed); + HloSharding single_sharding = + ParseSharding("{devices=[2,2]<=[4] last_tile_dims={manual, replicated}}") + .value(); + HloSharding tuple_sharding = HloSharding::SingleTuple( + module->entry_computation()->root_instruction()->shape(), + single_sharding); + + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + EXPECT_TRUE(instruction->has_sharding()); + EXPECT_EQ(instruction->sharding(), instruction->shape().IsTuple() + ? tuple_sharding + : single_sharding); + } + } +} + TEST_P(ParameterizedMetadataTest, WhileGetShardingFromRecvInBody) { const char* const hlo_string = R"( HloModule module @@ -12070,5 +12124,36 @@ ENTRY %elementwise { "last_tile_dim_replicate}}")); } +TEST_F(ShardingPropagationTest, CallPropagation) { + const absl::string_view hlo_string = R"( +HloModule module + +called_computation { + p0 = bf16[20,2,68096,8512] parameter(0) + %add_called_comp = bf16[20,2,68096,8512] add(p0, p0) + ROOT tuple = (bf16[20,2,68096,8512]) tuple(add_called_comp) +} + +ENTRY main { + %param0 = bf16[20,2,68096,8512] parameter(0) + %add = bf16[20,2,68096,8512] add(param0, param0) + ROOT %call = (bf16[20,2,68096,8512]) call(add), to_apply=%called_computation, sharding={{devices=[1,1,16,64]<=[64,16]T(1,0)}} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, + ShardingPropagation( + /*is_spmd=*/true, /*propagate_metadata=*/true, + /*allow_spmd_sharding_propagation_to_output=*/{false}, + /*allow_spmd_sharding_propagation_to_parameters=*/{false}) + .Run(module.get())); + XLA_VLOG_LINES(1, module->ToString()); + EXPECT_TRUE(changed); + auto* add = FindInstruction(module.get(), "add"); + ASSERT_NE(add, nullptr); + EXPECT_THAT(add, op::Sharding("{devices=[1,1,16,64]<=[64,16]T(1,0)}")); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/slice_sinker_test.cc b/third_party/xla/xla/service/slice_sinker_test.cc index cbbdafc877cda2..413710bd6a225b 100644 --- a/third_party/xla/xla/service/slice_sinker_test.cc +++ b/third_party/xla/xla/service/slice_sinker_test.cc @@ -30,8 +30,8 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/sort_simplifier_test.cc b/third_party/xla/xla/service/sort_simplifier_test.cc index ea8f208271a571..678ce7c37eb905 100644 --- a/third_party/xla/xla/service/sort_simplifier_test.cc +++ b/third_party/xla/xla/service/sort_simplifier_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/space_to_batch_converter.cc b/third_party/xla/xla/service/space_to_batch_converter.cc index 751b6d11dc979c..45f21136b9e13b 100644 --- a/third_party/xla/xla/service/space_to_batch_converter.cc +++ b/third_party/xla/xla/service/space_to_batch_converter.cc @@ -1734,6 +1734,10 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, } if (consumer->opcode() == HloOpcode::kReduce) { + // Do not propagate through tuple outputs. + if (consumer->shape().IsTuple()) { + return false; + } // Support only the trivial case where both batch and split spatial dim are // being reduced @@ -1741,8 +1745,13 @@ bool ConvolutionVisitor::SupportedOpForPropagation(HloInstruction* consumer, auto result = instr_to_dim_map_[consumer->mutable_operand(0)]; const int64_t batch_dim = result[DimMapper(SpaceToBatchDimMap::kBatch)]; const int64_t space_dim = result[DimMapper(SpaceToBatchDimMap::kSpace0)]; - VLOG(1) << "Checking if reduce is supported batch_dim " << batch_dim - << " space_dim " << space_dim << " reduce " << consumer->ToString(); + // Support the trivial case where none of the batch and split spatial dim + // are being reduced. + return !absl::c_linear_search(reduce_dims, batch_dim) && + !absl::c_linear_search(reduce_dims, space_dim); + + // Support only the trivial case where both batch and split spatial dim are + // being reduced return absl::c_linear_search(reduce_dims, batch_dim) && absl::c_linear_search(reduce_dims, space_dim); } @@ -2072,16 +2081,116 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, } if (consumer->opcode() == HloOpcode::kReduce) { - auto new_consumer = computation->AddInstruction(consumer->Clone()); + auto reduce_dims = consumer->dimensions(); + auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)]; auto first_operand = old_to_new_instrs_[consumer->mutable_operand(0)]; + auto permute_dims = instr_to_dim_permute_map_[first_operand]; - auto dim_map_val = instr_to_dim_map_[consumer->mutable_operand(0)]; const int64_t old_batch_dim = dim_map_val[DimMapper(SpaceToBatchDimMap::kBatch)]; + const int64_t space_dim = + dim_map_val[DimMapper(SpaceToBatchDimMap::kSpace0)]; - auto permute_dims = instr_to_dim_permute_map_[first_operand]; const int64_t new_batch_dim = DimLookUp(permute_dims, old_batch_dim); + const int64_t new_space_dim = DimLookUp(permute_dims, space_dim); + std::vector changed_dims(consumer->dimensions().size()); + + // Support the trivial case where none of the batch and split spatial dim + // are being reduced. + if (!absl::c_linear_search(reduce_dims, old_batch_dim) && + !absl::c_linear_search(reduce_dims, space_dim)) { + for (int64_t i = 0; i < consumer->dimensions().size(); ++i) { + changed_dims[i] = DimLookUp(permute_dims, consumer->dimensions(i)); + } + + // Decide where the new batch and space dims are in the output. + int64_t new_output_batch_dim = new_batch_dim; + int64_t new_output_space_dim = new_space_dim; + for (int64_t i = 0; i < consumer->dimensions().size(); ++i) { + if (changed_dims[i] < new_batch_dim) { + new_output_batch_dim--; + } + if (changed_dims[i] < new_space_dim) { + new_output_space_dim--; + } + } + + // Decide where the new batch and space dims are in the original reduce's + // output. + int64_t old_output_batch_dim = old_batch_dim; + int64_t old_output_space_dim = space_dim; + for (int64_t i = 0; i < consumer->dimensions().size(); ++i) { + if (reduce_dims[i] < old_batch_dim) { + old_output_batch_dim--; + } + if (reduce_dims[i] < space_dim) { + old_output_space_dim--; + } + } + + HloInstruction* new_consumer = nullptr; + TF_ASSIGN_OR_RETURN( + new_consumer, + MakeReduceHlo(first_operand, consumer->mutable_operand(1), + changed_dims, consumer->called_computations()[0])); + + VLOG(3) << " new_output_batch_dim " << new_output_batch_dim << " size " + << first_operand->shape().dimensions(new_batch_dim) + << " new_output_space_dim " << new_output_space_dim << " size " + << first_operand->shape().dimensions(new_space_dim); + + std::vector dim_map(kNumMappedDims); + dim_map[DimMapper(SpaceToBatchDimMap::kBatch)] = old_output_batch_dim; + dim_map[DimMapper(SpaceToBatchDimMap::kSpace0)] = old_output_space_dim; + // We don't know where the feature dim is, so set it to -1. + dim_map[DimMapper(SpaceToBatchDimMap::kFeature)] = -1; + + instr_to_dim_map_[consumer] = dim_map; + const int64_t rank = first_operand->shape().rank(); + + const int64_t output_rank = new_consumer->shape().rank(); + + // Make a map of each dim in original reduce output to input. + std::vector old_reduce_output_to_input(output_rank); + int dim_number_to_assign_old = 0; + for (int64_t i = 0; i < rank; ++i) { + if (auto it = absl::c_find(reduce_dims, i); it != reduce_dims.end()) { + continue; + } + old_reduce_output_to_input[i] = dim_number_to_assign_old++; + } + // Make a map of each dim in new reduce output to the new input. + std::vector new_reduce_output_to_input(output_rank); + int dim_number_to_assign_new = 0; + for (int64_t i = 0; i < rank; ++i) { + if (auto it = absl::c_find(changed_dims, i); it != changed_dims.end()) { + continue; + } + new_reduce_output_to_input[i] = dim_number_to_assign_new++; + } + + std::vector new_permute_dims(output_rank); + // From the output dims to input dims mapping, figure how the old output + // dims are mapped to the new output dims. + for (int64_t i = 0; i < output_rank; ++i) { + new_permute_dims[i] = std::distance( + new_reduce_output_to_input.begin(), + absl::c_find( + new_reduce_output_to_input, + DimLookUp(permute_dims, old_reduce_output_to_input[i]))); + } + + instr_to_dim_permute_map_[new_consumer] = new_permute_dims; + old_to_new_instrs_[consumer] = new_consumer; + + // Because batch and split spatial dims are not reduced, further + // propagation is needed. + return true; + } + + HloInstruction* new_consumer = + computation->AddInstruction(consumer->Clone()); auto retval = GetSpatialDimsToSplit(consumer->mutable_operand(0)); std::vector old_spatial_dims = retval.first; std::vector new_spatial_dims = retval.second; @@ -2092,7 +2201,6 @@ absl::StatusOr ConvolutionVisitor::Propagate(HloInstruction* consumer, consumer->mutable_operand(1), new_batch_dim, new_spatial_dims, old_batch_dim, old_spatial_dims)); - std::vector changed_dims(new_consumer->dimensions().size()); for (int64_t i = 0; i < new_consumer->dimensions().size(); ++i) { changed_dims[i] = DimLookUp(permute_dims, new_consumer->dimensions(i)); } @@ -3746,8 +3854,9 @@ bool ConvolutionVisitor::DoesConvolutionFeedUnpropagatableOp( } int64_t depth_to_use = depth; - // When we see a convolution, we reduce the depth to look further for. - if (user->opcode() == HloOpcode::kConvolution) { + // When we see a convolution/dot, we reduce the depth to look further for. + if (user->opcode() == HloOpcode::kConvolution || + user->opcode() == HloOpcode::kDot) { depth_to_use--; } diff --git a/third_party/xla/xla/service/space_to_batch_converter_test.cc b/third_party/xla/xla/service/space_to_batch_converter_test.cc index e2ed3314bc4f6f..dbc11e962ef23c 100644 --- a/third_party/xla/xla/service/space_to_batch_converter_test.cc +++ b/third_party/xla/xla/service/space_to_batch_converter_test.cc @@ -272,5 +272,84 @@ TEST_F(SpaceToBatchConverterTest, PropagateThroughDot) { ASSERT_TRUE(converter.Run(module.get()).value()); } +TEST_F(SpaceToBatchConverterTest, PropagateOnTrivialReduce) { + std::string hlo_string = R"( + HloModule module + + %region_1.37 (Arg_0.38: f32[], Arg_1.39: f32[]) -> f32[] { + %Arg_0.38 = f32[] parameter(0) + %Arg_1.39 = f32[] parameter(1) + ROOT %add.40 = f32[] add(f32[] %Arg_0.38, f32[] %Arg_1.39) + } + + ENTRY computation { + %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0) + %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1) + %c = f32[7,160,400,32]{3,2,1,0} convolution( %p0, %p1), + window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f + %constant.5 = f32[] constant(0) + ROOT %reduce.41 = f32[7,160,400]{2,1,0} reduce(%c, %constant.5), dimensions={3}, to_apply=%region_1.37 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto computation = module->entry_computation(); + SpaceToBatchConverter converter( + SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8}); + ASSERT_TRUE(converter.Run(module.get()).value()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Transpose()); + EXPECT_THAT(root->operand(0)->operand(0)->operand(0)->operand(0), + op::Reduce()); + auto new_reduce = root->operand(0)->operand(0)->operand(0)->operand(0); + // Make sure we propagated on the reduce with the larger batch size. + EXPECT_EQ(new_reduce->shape().dimensions(1), + // batch*number_of_splits + 7 * 8); +} + +TEST_F(SpaceToBatchConverterTest, DoNotPropagateOnTupleReduce) { + std::string hlo_string = R"( + HloModule module + +%minmax_func.2717 { + %lhs_value.2718 = f32[] parameter(0) + %rhs_value.2720 = f32[] parameter(2) + %compare.2722 = pred[] compare(f32[] %lhs_value.2718, f32[] %rhs_value.2720), direction=GE + %select.2723 = f32[] select(pred[] %compare.2722, f32[] %lhs_value.2718, f32[] %rhs_value.2720) + %compare.2725 = pred[] compare(f32[] %lhs_value.2718, f32[] %rhs_value.2720), direction=EQ + %lhs_index.2719 = f32[] parameter(1) + %rhs_index.2721 = f32[] parameter(3) + %minimum.2726 = f32[] minimum(f32[] %lhs_index.2719, f32[] %rhs_index.2721) + %select.2724 = f32[] select(pred[] %compare.2722, f32[] %lhs_index.2719, f32[] %rhs_index.2721) + %select.2727 = f32[] select(pred[] %compare.2725, f32[] %minimum.2726, f32[] %select.2724) + ROOT %tuple.4 = (f32[], f32[]) tuple(f32[] %select.2723, f32[] %select.2727) + } + + ENTRY computation { + %p0 = bf16[7,320,800,3]{3,2,1,0} parameter(0) + %p1 = bf16[3,3,3,32]{3,2,1,0} parameter(1) + %c = f32[7,160,400,32]{3,2,1,0} convolution( %p0, %p1), + window={size=3x3 stride=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f + %constant.5 = f32[] constant(0) + %constant.6 = f32[] constant(1) + ROOT %reduce.36 = (f32[7,160,400]{2,1,0}, f32[7,160,400]{2,1,0}) reduce(%c, %c, + %constant.5, %constant.6), dimensions={3}, to_apply=%minmax_func.2717 + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + + auto computation = module->entry_computation(); + SpaceToBatchConverter converter( + SpaceToBatchController{true, true, true, true, /*number_of_splits=*/8}); + ASSERT_TRUE(converter.Run(module.get()).value()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Reduce()); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/spmd/BUILD b/third_party/xla/xla/service/spmd/BUILD index 9d7dbcbca473cc..dcd18651e770d5 100644 --- a/third_party/xla/xla/service/spmd/BUILD +++ b/third_party/xla/xla/service/spmd/BUILD @@ -108,6 +108,7 @@ xla_cc_test( "//xla/service:sharding_format_picker", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -115,7 +116,6 @@ xla_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], diff --git a/third_party/xla/xla/service/spmd/convolution_handler.cc b/third_party/xla/xla/service/spmd/convolution_handler.cc index 3985a81c810b01..a084c2ec98fae6 100644 --- a/third_party/xla/xla/service/spmd/convolution_handler.cc +++ b/third_party/xla/xla/service/spmd/convolution_handler.cc @@ -1028,43 +1028,8 @@ absl::Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { if (hlo->sharding().HasUniqueDevice()) { return DefaultAction(hlo); } - auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo); - dot_as_convolution_util::DotConvolutionDimsInfo mapping; - for (const auto& dims : dims_info.batch_dims) { - mapping.batch_dims.emplace_back(); - mapping.batch_dims.back().lhs = dims.lhs; - mapping.batch_dims.back().rhs = dims.rhs; - mapping.batch_dims.back().output = dims.output; - mapping.batch_dims.back().spatial_dim = dims.spatial_dim; - } - for (const auto& dims : dims_info.contracting_dims) { - mapping.contracting_dims.emplace_back(); - mapping.contracting_dims.back().lhs = dims.lhs; - mapping.contracting_dims.back().rhs = dims.rhs; - mapping.contracting_dims.back().output = dims.output; - mapping.contracting_dims.back().spatial_dim = dims.spatial_dim; - } - for (const auto& dims : dims_info.lhs_non_contracting_dims) { - mapping.lhs_non_contracting_dims.emplace_back(); - mapping.lhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.lhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.lhs_non_contracting_dims.back().output = dims.output; - mapping.lhs_non_contracting_dims.back().spatial_dim = dims.spatial_dim; - } - for (const auto& dims : dims_info.rhs_non_contracting_dims) { - mapping.rhs_non_contracting_dims.emplace_back(); - mapping.rhs_non_contracting_dims.back().lhs = dims.lhs; - mapping.rhs_non_contracting_dims.back().rhs = dims.rhs; - mapping.rhs_non_contracting_dims.back().output = dims.output; - mapping.rhs_non_contracting_dims.back().spatial_dim = dims.spatial_dim; - } - for (const auto& dims : dims_info.conv_spatial_dims) { - mapping.conv_spatial_dims.emplace_back(); - mapping.conv_spatial_dims.back().lhs = dims.lhs; - mapping.conv_spatial_dims.back().rhs = dims.rhs; - mapping.conv_spatial_dims.back().output = dims.output; - mapping.conv_spatial_dims.back().spatial_dim = dims.spatial_dim; - } + const auto dims_info = dot_as_convolution_util::ParseConvolutionDimsInfo(hlo); + auto create_sharded_conv = [&](HloInstruction* lhs_hlo, HloInstruction* rhs_hlo, spmd::SpmdBuilder* b, @@ -1084,7 +1049,7 @@ absl::Status SpmdPartitioningVisitor::HandleConvolution(HloInstruction* hlo) { } }; - return HandleDotHelper(hlo, mapping, create_sharded_conv); + return HandleDotHelper(hlo, dims_info, create_sharded_conv); } } // namespace spmd diff --git a/third_party/xla/xla/service/spmd/dot_handler.cc b/third_party/xla/xla/service/spmd/dot_handler.cc index 123edb05a6b8ef..22f88cf0dad143 100644 --- a/third_party/xla/xla/service/spmd/dot_handler.cc +++ b/third_party/xla/xla/service/spmd/dot_handler.cc @@ -68,41 +68,8 @@ using hlo_sharding_util::GroupedSharding; } // namespace absl::Status SpmdPartitioningVisitor::HandleDot(HloInstruction* hlo) { - DotConvolutionDimsInfo mapping; - const auto& dnums = hlo->dot_dimension_numbers(); - int64_t next_output_dim = 0; - for (int64_t i = 0; i < dnums.lhs_batch_dimensions_size(); ++i) { - mapping.batch_dims.emplace_back(); - mapping.batch_dims.back().lhs = dnums.lhs_batch_dimensions(i); - mapping.batch_dims.back().rhs = dnums.rhs_batch_dimensions(i); - mapping.batch_dims.back().output = next_output_dim++; - } - for (int64_t i = 0; i < dnums.lhs_contracting_dimensions_size(); ++i) { - mapping.contracting_dims.emplace_back(); - mapping.contracting_dims.back().lhs = dnums.lhs_contracting_dimensions(i); - mapping.contracting_dims.back().rhs = dnums.rhs_contracting_dimensions(i); - mapping.contracting_dims.back().output = -1; - } - for (int64_t i = 0; i < hlo->operand(0)->shape().rank(); ++i) { - if (absl::c_linear_search(dnums.lhs_batch_dimensions(), i) || - absl::c_linear_search(dnums.lhs_contracting_dimensions(), i)) { - continue; - } - mapping.lhs_non_contracting_dims.emplace_back(); - mapping.lhs_non_contracting_dims.back().lhs = i; - mapping.lhs_non_contracting_dims.back().rhs = -1; - mapping.lhs_non_contracting_dims.back().output = next_output_dim++; - } - for (int64_t i = 0; i < hlo->operand(1)->shape().rank(); ++i) { - if (absl::c_linear_search(dnums.rhs_batch_dimensions(), i) || - absl::c_linear_search(dnums.rhs_contracting_dimensions(), i)) { - continue; - } - mapping.rhs_non_contracting_dims.emplace_back(); - mapping.rhs_non_contracting_dims.back().lhs = -1; - mapping.rhs_non_contracting_dims.back().rhs = i; - mapping.rhs_non_contracting_dims.back().output = next_output_dim++; - } + DotConvolutionDimsInfo mapping = + dot_as_convolution_util::ParseDotGeneralFromDot(hlo); HloDotInstruction* dot = Cast(hlo); std::vector sparsity(dot->sparsity().begin(), @@ -1932,7 +1899,7 @@ absl::StatusOr PartitionBaseCase( has_reshape_operand(lhs) ? lhs.hlo()->operand(0) : lhs.hlo(); auto rhs_operand = has_reshape_operand(rhs) ? rhs.hlo()->operand(0) : rhs.hlo(); - for (auto loop : *windowed_dot_general_loops) { + for (const auto& loop : *windowed_dot_general_loops) { if (loop.while_loop->while_body()->name().find( "windowed_dot_general_body_ag") == 0) { auto cm_lhs = loop.while_loop->operand(0)->operand(0); @@ -2575,19 +2542,40 @@ absl::StatusOr PartitionDotGroupOnNonContractingImpl( matching.sharding() != UngroupSharding(matching_grouped)) { return nullptr; } + + auto try_sharding_for_other_operand = [&](const HloSharding& sharding) { + PartitionedHlo other_reshard = other.Reshard(sharding); + std::optional grouped_sharding = + GetNonContractingPartitionGroupedShardingForOtherOperand( + lhs_matching, output_base_shape, other_reshard.hlo()->shape(), + other_contracting_partitions, other_non_contracting_partitions, + matching_contracting_partitions, + output_other_non_contracting_partitions, other_reshard.sharding(), + output_sharding, partitioned_non_contracting_dims, + lhs_matching ? dims_mapping.rhs_non_contracting_dims + : dims_mapping.lhs_non_contracting_dims, + dims_mapping.contracting_dims); + if (grouped_sharding) { + other = other_reshard; + } + return grouped_sharding; + }; std::optional other_grouped = - GetNonContractingPartitionGroupedShardingForOtherOperand( - lhs_matching, output_base_shape, other.hlo()->shape(), - other_contracting_partitions, other_non_contracting_partitions, - matching_contracting_partitions, - output_other_non_contracting_partitions, other.sharding(), - output_sharding, partitioned_non_contracting_dims, - lhs_matching ? dims_mapping.rhs_non_contracting_dims - : dims_mapping.lhs_non_contracting_dims, - dims_mapping.contracting_dims); + try_sharding_for_other_operand(other.sharding()); + if (!other_grouped && !other.sharding().IsReplicated() && + dims_mapping.conv_spatial_dims.empty()) { + const HloSharding expected_other_sharding = + hlo_sharding_util::InferDotOperandSharding( + &output_sharding, &matching.sharding(), lhs_matching ? 1 : 0, + dims_mapping, true, true); + // Try the expected sharding since it is no worse than the last resort + // (replicated sharding). + other_grouped = try_sharding_for_other_operand(expected_other_sharding); + } if (!other_grouped) { other = other.Replicate(); } + matching = matching.Reshard(UngroupSharding(matching_grouped)); auto per_group_partitioner_state = CreatePerGroupPartitioningState( matching.state(), matching_grouped.device_groups, b); @@ -2606,7 +2594,7 @@ absl::StatusOr PartitionDotGroupOnNonContractingImpl( partially_replicated_other = other.hlo(); top_level_sharding_to_reset.emplace_back(other.hlo(), other.sharding()); partially_replicated_other->set_sharding(other_grouped->sharding); - } else if (!other.sharding().IsReplicated()) { + } else if (other_grouped && !other.sharding().IsReplicated()) { HloSharding target_sharding = UngroupSharding(*other_grouped); GroupedSharding target_group_sharding = hlo_sharding_util::GroupShardingOnDims(target_sharding, @@ -2630,18 +2618,16 @@ absl::StatusOr PartitionDotGroupOnNonContractingImpl( partially_replicated_other, partially_replicated_other->sharding()); partially_replicated_other->set_sharding(other_grouped->sharding); } + auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(), per_group_partitioner_state); - TF_ASSIGN_OR_RETURN( - auto dot, - PartitionDot(lhs_matching ? matching_p : other_p, - lhs_matching ? other_p : matching_p, - GetPerGroupBaseShape(output_grouped, output_base_shape), - output_grouped.sharding, dims_mapping, - num_partitions / matching_grouped.device_groups.size(), - create_sharded_dot, conv_window, module, original_hlo, - options, b, windowed_dot_general_loops, visitor)); - return dot; + return PartitionDot(lhs_matching ? matching_p : other_p, + lhs_matching ? other_p : matching_p, + GetPerGroupBaseShape(output_grouped, output_base_shape), + output_grouped.sharding, dims_mapping, + num_partitions / matching_grouped.device_groups.size(), + create_sharded_dot, conv_window, module, original_hlo, + options, b, windowed_dot_general_loops, visitor); } std::pair @@ -3031,6 +3017,9 @@ DotConvolutionDimsInfo ConvertDimNumsWithFeatureGroupCount( const DotConvolutionDimsInfo& dims_mapping, HloInstruction* original_hlo) { const auto& dnums = original_hlo->convolution_dimension_numbers(); DotConvolutionDimsInfo new_dims_mapping; + new_dims_mapping.lhs_shape_rank = dims_mapping.lhs_shape_rank; + new_dims_mapping.rhs_shape_rank = dims_mapping.rhs_shape_rank; + new_dims_mapping.output_shape_rank = dims_mapping.output_shape_rank; new_dims_mapping.batch_dims = dims_mapping.batch_dims; new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims; // Append batch dims. @@ -3060,6 +3049,9 @@ DotConvolutionDimsInfo ConvertDimNumsWithBatchGroupCount( const DotConvolutionDimsInfo& dims_mapping, HloInstruction* original_hlo) { const auto& dnums = original_hlo->convolution_dimension_numbers(); DotConvolutionDimsInfo new_dims_mapping; + new_dims_mapping.lhs_shape_rank = dims_mapping.lhs_shape_rank; + new_dims_mapping.rhs_shape_rank = dims_mapping.rhs_shape_rank; + new_dims_mapping.output_shape_rank = dims_mapping.output_shape_rank; new_dims_mapping.batch_dims = dims_mapping.batch_dims; new_dims_mapping.conv_spatial_dims = dims_mapping.conv_spatial_dims; new_dims_mapping.contracting_dims = dims_mapping.contracting_dims; diff --git a/third_party/xla/xla/service/spmd/shardy/BUILD b/third_party/xla/xla/service/spmd/shardy/BUILD index 055673516a652b..bd15f2048ec50d 100644 --- a/third_party/xla/xla/service/spmd/shardy/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/BUILD @@ -37,7 +37,6 @@ cc_library( xla_cc_test( name = "shardy_call_inliner_test", srcs = ["shardy_call_inliner_test.cc"], - env = {"XLA_FLAGS": "--xla_use_shardy=true"}, deps = [ ":shardy_call_inliner", "//xla/hlo/ir:hlo", @@ -143,8 +142,8 @@ xla_cc_binary( "//xla/service/spmd/shardy/mhlo_round_trip:mhlo_import", "//xla/service/spmd/shardy/mhlo_round_trip:shard_map_export", "//xla/service/spmd/shardy/round_trip_common:convert_sharding_custom_calls", - "//xla/service/spmd/shardy/round_trip_common:identity_to_pass_through_while_args", "//xla/service/spmd/shardy/round_trip_common:import_constants", + "//xla/service/spmd/shardy/round_trip_common:open_while_free_vars_sharding", "//xla/service/spmd/shardy/round_trip_common:shard_map_import", "//xla/service/spmd/shardy/sdy_round_trip:export_ops", "//xla/service/spmd/shardy/sdy_round_trip:export_shardings", diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc index b6c6a990fdc714..0ffff7134c61a6 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/export_ops.cc @@ -66,7 +66,6 @@ using ::mlir::StringRef; using ::mlir::success; using ::mlir::sdy::ConstantOp; -using ::mlir::sdy::IdentityOp; using ::mlir::sdy::kShardingAttr; using ::mlir::sdy::ReshardOp; using ::mlir::sdy::ShardingConstraintOp; @@ -88,20 +87,6 @@ class ConstantPattern : public OpConversionPattern { } }; -// Removes `sdy::IdentityOp`. -class IdentityPattern : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - private: - LogicalResult matchAndRewrite( - IdentityOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - rewriter.replaceOp(op, adaptor.getInput()); - return success(); - } -}; - class ReshardPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -148,15 +133,14 @@ class ExportOpsPass // We do not expect to see ShardingConstraintOp in the input module. // ShardingConstraintOp should be replaced by ReshardOp before this pass. // Hence, we add ShardingConstraintOp as an illegal op. - target.addIllegalOp(); + target.addIllegalOp(); target.addLegalOp(); mlir::RewritePatternSet patterns(&context); // After converting `sdy.constant` into `mhlo.constant`, the constants // should not be deduped via folding. Fortunately, folding only happens in // greedy pattern rewriters. ExportHloShardingsPass does a simple walk, // which keeps the constants as is. - patterns.add(&context); + patterns.add(&context); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); @@ -166,8 +150,8 @@ class ExportOpsPass StringRef getArgument() const override { return "xla-sdy-export-ops"; } StringRef getDescription() const override { - return "Exports Shardy ops to MHLO ops. Processes sdy::IdentityOp, " - "sdy::ReshardOp, and sdy::ConstantOp."; + return "Exports Shardy ops to MHLO ops. Processes sdy::ReshardOp and " + "sdy::ConstantOp."; } void getDependentDialects(mlir::DialectRegistry& registry) const final { diff --git a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc index f30815c6416927..f72cc4a885c7b3 100644 --- a/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc +++ b/third_party/xla/xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.cc @@ -246,10 +246,10 @@ SmallVector getOrderedSubDimsFromIotaTileAssignment( tileDimIndex--; } subDims.push_back(SubDimInfo{ - .tileDimIndex = tileDimIndex, - .tileSubDimIndex = subDim++, - .reshapeDimIndex = iota.transpose_perm()[transPermIndex], - .size = axisSize, + /* .tileDimIndex = */ tileDimIndex, + /* .tileSubDimIndex = */ subDim++, + /* .reshapeDimIndex = */ iota.transpose_perm()[transPermIndex], + /* .size = */ axisSize, }); accTileSize *= axisSize; accDeviceSize *= axisSize; @@ -296,8 +296,10 @@ AnalyzeTileAssignmentResult analyzeTileAssignment( for (SubDimInfo subDimInfo : subDims) { mesh.push_back(subDimInfo.size); } - return AnalyzeTileAssignmentResult{.subDims = std::move(subDims), - .localMesh = std::move(mesh)}; + return AnalyzeTileAssignmentResult{ + /* .subDims = */ std::move(subDims), + /* .localMesh = */ std::move(mesh), + }; } // Collect shardings with the attr name kXlaShardingAttr in the `moduleOp`. diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD index f9fd53b2120329..e929f614006e81 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/BUILD @@ -32,9 +32,9 @@ cc_library( ) cc_library( - name = "identity_to_pass_through_while_args", - srcs = ["identity_to_pass_through_while_args.cc"], - hdrs = ["identity_to_pass_through_while_args.h"], + name = "import_constants", + srcs = ["import_constants.cc"], + hdrs = ["import_constants.h"], deps = [ "//xla/mlir_hlo", "@llvm-project//llvm:Support", @@ -48,9 +48,9 @@ cc_library( ) cc_library( - name = "import_constants", - srcs = ["import_constants.cc"], - hdrs = ["import_constants.h"], + name = "open_while_free_vars_sharding", + srcs = ["open_while_free_vars_sharding.cc"], + hdrs = ["open_while_free_vars_sharding.h"], deps = [ "//xla/mlir_hlo", "@llvm-project//llvm:Support", @@ -94,8 +94,8 @@ cc_library( hdrs = ["pipeline_passes.h"], deps = [ ":convert_sharding_custom_calls", - ":identity_to_pass_through_while_args", ":import_constants", + ":open_while_free_vars_sharding", ":shard_map_import", "//xla/mlir_hlo:mhlo_passes", "@llvm-project//mlir:FuncDialect", diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc deleted file mode 100644 index e1675c6a86b866..00000000000000 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h" - -#include - -#include "llvm/ADT/StringRef.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/TypeID.h" -#include "mlir/Transforms/DialectConversion.h" -#include "shardy/dialect/sdy/ir/dialect.h" -#include "shardy/dialect/sdy/ir/utils.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace xla { -namespace sdy { - -namespace { - -using ::mlir::StringRef; - -using ::mlir::func::FuncOp; - -// For every block argument of an `mhlo::WhileOp` that is directly returned by -// the body of the op (pass-through), add an `sdy::IdentityOp` between the block -// argument and the return op. -// -// This will prevent canonicalization from replacing these block arguments with -// the corresponding operands as free variables. -class AddIdentityToPassThroughWhileArgsPass - : public mlir::PassWrapper> { - public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - AddIdentityToPassThroughWhileArgsPass) - - void runOnOperation() final { - FuncOp funcOp = getOperation(); - mlir::IRRewriter rewriter(funcOp); - - funcOp.walk([&](mlir::mhlo::WhileOp op) { - mlir::Operation* returnOp = mlir::sdy::getBodyTerminator(op); - rewriter.setInsertionPoint(returnOp); - for (mlir::Value returnValue : returnOp->getOperands()) { - if (auto blockArg = mlir::dyn_cast(returnValue); - blockArg && blockArg.getOwner() == &op.getBody().front()) { - auto identityOp = rewriter.create( - returnValue.getLoc(), returnValue); - rewriter.replaceUsesWithIf(returnValue, identityOp, - [returnOp](mlir::OpOperand& use) { - return use.getOwner() == returnOp; - }); - } - } - }); - } - - StringRef getArgument() const override { - return "xla-sdy-add-identity-to-pass-through-while-args"; - } - - StringRef getDescription() const override { - return "Adds an identity op between pass-through block arguments of a " - "while op."; - } -}; - -} // namespace - -std::unique_ptr createAddIdentityToPassThroughWhileArgsPass() { - return std::make_unique(); -} - -void registerAddIdentityToPassThroughWhileArgsPass() { - mlir::registerPass(createAddIdentityToPassThroughWhileArgsPass); -} - -} // namespace sdy -} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc new file mode 100644 index 00000000000000..603b270eefa46f --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.cc @@ -0,0 +1,95 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" + +#include + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/TypeID.h" +#include "mlir/Transforms/RegionUtils.h" +#include "shardy/dialect/sdy/ir/dialect.h" +#include "shardy/dialect/sdy/ir/utils.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace xla { +namespace sdy { + +namespace { + +using ::mlir::StringRef; +using ::mlir::func::FuncOp; +using ::mlir::sdy::TensorShardingAttr; + +class OpenWhileFreeVarsShardingPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenWhileFreeVarsShardingPass) + + void runOnOperation() final { + FuncOp funcOp = getOperation(); + mlir::IRRewriter rewriter(funcOp); + + funcOp.walk([&](mlir::mhlo::WhileOp op) { + llvm::SetVector freeVars; + mlir::getUsedValuesDefinedAbove(op->getRegions(), freeVars); + rewriter.setInsertionPoint(op); + for (mlir::Value freeVar : freeVars) { + TensorShardingAttr sharding = mlir::sdy::getSharding(freeVar); + if (!sharding || sharding.getRank() == 0) { + continue; + } + auto shardingConstraint = + rewriter.create( + freeVar.getLoc(), freeVar, + TensorShardingAttr::getFullyOpenLike(sharding)); + // Only replace uses in the regions of the while op. + rewriter.replaceUsesWithIf( + freeVar, shardingConstraint, [op](mlir::OpOperand& use) { + return op->isProperAncestor(use.getOwner()); + }); + } + }); + } + + StringRef getArgument() const override { + return "xla-sdy-open-while-free-vars-sharding"; + } + + StringRef getDescription() const override { + return "Adds a fully open sharding constraint to free variables of while " + "op that already have a sharding."; + } +}; + +} // namespace + +std::unique_ptr createOpenWhileFreeVarsShardingPass() { + return std::make_unique(); +} + +void registerOpenWhileFreeVarsShardingPass() { + mlir::registerPass(createOpenWhileFreeVarsShardingPass); +} + +} // namespace sdy +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h new file mode 100644 index 00000000000000..c06776f3c368fc --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_ +#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_ + +#include + +#include "mlir/Pass/Pass.h" + +namespace xla { +namespace sdy { + +// Creates a pass that adds a fully open sharding constraint to free variables +// of while op that already have a user-defined sharding. +// +// This allows for their uses in the while op to be further sharded, which is +// important when converting to HLO as they will be lifted as passthrough while +// operands/results. +std::unique_ptr createOpenWhileFreeVarsShardingPass(); + +// Registers the xla-sdy-open-while-free-vars-sharding pass. +void registerOpenWhileFreeVarsShardingPass(); + +} // namespace sdy +} // namespace xla + +#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_OPEN_WHILE_FREE_VARS_SHARDING_H_ diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc index 5ddeb43ca76dd5..23960ab48aadca 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc +++ b/third_party/xla/xla/service/spmd/shardy/round_trip_common/pipeline_passes.cc @@ -20,8 +20,8 @@ limitations under the License. #include "mlir/Transforms/Passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" #include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h" -#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h" #include "xla/service/spmd/shardy/round_trip_common/import_constants.h" +#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" #include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h" namespace xla { @@ -36,11 +36,6 @@ void addCommonPreImportPasses(mlir::OpPassManager& pm) { // changes happen before shardings are added to operations, to ensure the // correct shardings are added and that they are not lost by this pass. pm.addNestedPass(mlir::mhlo::createPrepareForExportPass()); - // The prepare-for-export pass lifts `mhlo::WhileOp` free variables, and added - // them as additional operands of the op whose corresponding block arguments - // are directly returned by the body of the op (pass-through). To prevent - // canonicalization from undoing this, we add identity ops. - pm.addNestedPass(createAddIdentityToPassThroughWhileArgsPass()); // We import `mhlo.constant` ops to `sdy.constant` ops so that constants // aren't folded in greedy pattern rewriters, which would lift them outside of @@ -51,13 +46,15 @@ void addCommonPreImportPasses(mlir::OpPassManager& pm) { pm.addNestedPass(mlir::mhlo::createFlattenTuplePass()); // We need to canonicalize redundant mhlo::GetTupleElementOp and - // mhlo::GetTupleOp. + // mhlo::GetTupleOp. We also need to canonicalize mhlo::WhileOp before + // `createOpenWhileFreeVarsShardingPass`. pm.addPass(mlir::createCanonicalizerPass()); } void addCommonPostImportPasses(mlir::OpPassManager& pm) { pm.addPass(createShardMapImportPass()); pm.addPass(createConvertShardingCustomCallsPass()); + pm.addNestedPass(createOpenWhileFreeVarsShardingPass()); } } // namespace sdy diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc index c12587c287109d..b5670e78ace9b3 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_opt_main.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/service/spmd/shardy/mhlo_round_trip/mhlo_import.h" #include "xla/service/spmd/shardy/mhlo_round_trip/shard_map_export.h" #include "xla/service/spmd/shardy/round_trip_common/convert_sharding_custom_calls.h" -#include "xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h" #include "xla/service/spmd/shardy/round_trip_common/import_constants.h" +#include "xla/service/spmd/shardy/round_trip_common/open_while_free_vars_sharding.h" #include "xla/service/spmd/shardy/round_trip_common/shard_map_import.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" @@ -55,7 +55,7 @@ int main(int argc, char** argv) { xla::sdy::registerMhloImportShardingsPass(); xla::sdy::registerShardMapImportPass(); xla::sdy::registerConvertShardingCustomCallsPass(); - xla::sdy::registerAddIdentityToPassThroughWhileArgsPass(); + xla::sdy::registerOpenWhileFreeVarsShardingPass(); xla::sdy::registerImportConstantsPass(); xla::sdy::registerMhloExportPipeline(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD index 2cf8f1d2cd73f6..7b9844bc9878d8 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/BUILD @@ -11,7 +11,7 @@ package_group( packages = [ "//learning/deepmind/partir/compiler/shardonnay/...", "//third_party/openxla/shardy/tools/...", - "//xla/service/spmd/shardy/...", + "//xla/...", ], ) @@ -79,6 +79,7 @@ cc_library( ":export_shardings", ":import_shardings", "//xla/service:hlo_proto_cc", + "//xla/service/spmd/shardy/mhlo_round_trip:export_shardings", "//xla/service/spmd/shardy/round_trip_common:pipeline_passes", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc index d4e14da8d576a6..b5bd21fbeaa04f 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_ops.cc @@ -61,7 +61,6 @@ using ::mlir::StringRef; using ::mlir::success; using ::mlir::sdy::ConstantOp; -using ::mlir::sdy::IdentityOp; using ::mlir::sdy::ShardingConstraintOp; using ::mlir::sdy::TensorShardingAttr; using ::mlir::sdy::TensorShardingPerValueAttr; @@ -81,20 +80,6 @@ class ConstantPattern : public OpConversionPattern { } }; -// Removes `sdy::IdentityOp`. -class IdentityPattern : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - - private: - LogicalResult matchAndRewrite( - IdentityOp op, OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const override { - rewriter.replaceOp(op, adaptor.getInput()); - return success(); - } -}; - class ShardingConstraintPattern : public OpConversionPattern { public: @@ -130,11 +115,10 @@ class SdyRoundTripExportOpsPass void runOnOperation() final { mlir::MLIRContext& context = getContext(); mlir::ConversionTarget target(context); - target.addIllegalOp(); + target.addIllegalOp(); target.addLegalOp(); mlir::RewritePatternSet patterns(&context); - patterns.add( - &context); + patterns.add(&context); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { signalPassFailure(); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc index aec0a20775c73a..b076d5b215785c 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.cc @@ -92,7 +92,6 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { if (auto oldSharding = funcOp.getArgAttrOfType( argNum, kShardingAttr)) { addFrontendAttribute(funcOp, kShardingRoundTripAttr, oldSharding, argNum); - funcOp.removeArgAttr(argNum, kShardingAttr); } } @@ -122,7 +121,6 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { TensorShardingPerValueAttr::get(customCallOp.getContext(), sharding), builder); returnOperand.set(customCallOp.getResult(0)); - funcOp.removeResultAttr(resultNum, builder.getStringAttr(kShardingAttr)); } } @@ -130,7 +128,6 @@ LogicalResult exportFunc(FuncOp funcOp, OpBuilder& builder) { if (auto oldShardingPerValue = op->getAttrOfType(kShardingAttr)) { saveOpShardingPerValueAttr(op, oldShardingPerValue, builder); - op->removeAttr(kShardingAttr); } }); @@ -155,8 +152,6 @@ class SdyRoundTripExportShardingsPass } SmallVector mhloMeshes; - mlir::SymbolTableCollection symbolTableCollection; - SymbolTable& symbolTable = symbolTableCollection.getSymbolTable(moduleOp); // Saves the MeshOps for MHLO<->HLO round-trip and removes them from the // ModuleOp. for (MeshOp meshOp : @@ -164,7 +159,6 @@ class SdyRoundTripExportShardingsPass mhloMeshes.emplace_back( meshOp.getSymNameAttr(), getStringAttribute(meshOp.getMeshAttr(), builder)); - symbolTable.erase(meshOp); } addFrontendAttribute(moduleOp, kMeshesRoundTripAttr, DictionaryAttr::get(context, mhloMeshes)); diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h index dfbe7108694147..4b8ce6ab737419 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/export_shardings.h @@ -29,6 +29,10 @@ void registerSdyRoundTripExportShardingsPass(); // Creates the pass that converts the shardings from `kShardingAttr` to // `kShardingRoundTripAttr` in the HLO frontend attributes and saves the // mesh symbols as `kMeshesRoundTripAttr` in the module frontend attributes. +// +// NOTE: The `kShardingAttr`s are not removed from the ops. They are kept around +// because part of the `SdyRoundTripExportPipeline` it also converts the +// `kShardingAttr`s to `kXlaShardingAttr`s. std::unique_ptr createSdyRoundTripExportShardingsPass(); } // namespace sdy diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc index a5347a3b416c65..28cdc89c7c1125 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/import_shardings.cc @@ -96,6 +96,7 @@ void convertShardings(FuncOp funcOp) { // We need to wait until after we've converted all the Operations before // copying the result shardings. for (auto [argNum, argType] : llvm::enumerate(funcOp.getArgumentTypes())) { + funcOp.removeArgAttr(argNum, kXlaShardingAttr); // Attempt to extract the TensorShardingAttr from the frontend attributes of // the function argument/result. if (DictionaryAttr dictAttr = getFuncArgFrontendAttrs(funcOp, argNum)) { @@ -106,8 +107,16 @@ void convertShardings(FuncOp funcOp) { } } + // Due to `SdyRoundTripExportShardingsPass` keeping `mhlo.sharding`s, remove + // them purely for cleanliness of the module. + for (int64_t resNum = 0; resNum < funcOp.getNumResults(); ++resNum) { + funcOp.removeResultAttr( + resNum, StringAttr::get(funcOp.getContext(), kXlaShardingAttr)); + } + // Extract the round-tripped SDY shardings from the operations. funcOp.front().walk([&](Operation* op) { + op->removeAttr(kXlaShardingAttr); if (DictionaryAttr dictAttr = getFrontendAttrs(op)) { // NOTE: we are only setting the sharding on known custom-calls. For any // other op that has a `kShardingRoundTripAttr` we discard it. XLA diff --git a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc index f9eda62025d762..ee348edad68c17 100644 --- a/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc +++ b/third_party/xla/xla/service/spmd/shardy/sdy_round_trip/pipelines.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" #include "mlir/Support/LLVM.h" #include "xla/service/hlo.pb.h" +#include "xla/service/spmd/shardy/mhlo_round_trip/export_shardings.h" #include "xla/service/spmd/shardy/round_trip_common/pipeline_passes.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_ops.h" #include "xla/service/spmd/shardy/sdy_round_trip/export_shardings.h" @@ -38,7 +39,11 @@ void addSdyRoundTripExportPipeline(mlir::OpPassManager& pm) { // `createSdyRoundTripExportShardingsPass` and make use of // `createSdyRoundTripImportShardingsPass` to import them. pm.addPass(createSdyRoundTripExportOpsPass()); + // Preserve the SDY shardings for `createExportMhloShardingsPass` so that + // we have both `mhlo.sharding`s and hidden `sdy.sharding`s on the module. We + // want to have `mhlo.sharding`s for Pathways to read from. pm.addPass(createSdyRoundTripExportShardingsPass()); + pm.addPass(createExportMhloShardingsPass()); } void addSdyRoundTripImportPipeline(mlir::OpPassManager& pm) { diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc index 73a8479dcc4fc9..9f863e23a6715d 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner.cc @@ -24,10 +24,7 @@ namespace xla { bool ShardyCallInliner::IsInlineableCallOp(HloInstruction* instruction) const { return CallInliner::IsInlineableCallOp(instruction) && !instruction->has_backend_config() && - !(instruction->GetModule() - ->config() - .debug_options() - .xla_use_shardy() && + !(instruction->GetModule()->config().use_shardy_partitioner() && absl::StrContains(instruction->to_apply()->name(), "shmap_body")); } diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc index 861e934fee5779..00d952b3b80461 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_call_inliner_test.cc @@ -45,6 +45,7 @@ TEST_F(ShardyCallInlinerTest, MhloToHloShmapBodyNotInlined) { ROOT %custom-call.9 = f32[8,8]{1,0} custom-call(f32[1,8]{1,0} %custom-call.8), custom_call_target="SPMDShardToFullShape", sharding={devices=[8,1]<=[8]}, metadata={source_file="-" source_line=7} })"; TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hloString)); + module->mutable_config().set_use_shardy_partitioner(true); TF_ASSERT_OK_AND_ASSIGN(bool changed, ShardyCallInliner().Run(module.get())); VLOG(1) << module->ToString(); // The single call in the module is not inlined. diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc index 1735b3ccc30985..2514c46d91d3b2 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/spmd/shardy/shardy_xla_pass.h" #include -#include #include #include #include diff --git a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc index 59197cc4a38c1f..40463d0dc74fce 100644 --- a/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc +++ b/third_party/xla/xla/service/spmd/shardy/shardy_xla_pass_test.cc @@ -535,36 +535,42 @@ TEST_F(ShardyXLATest, RngBitGenerator) { TEST_F(ShardyXLATest, WhileWithFreeVariables) { const char* const hloString = R"( - HloModule main - - %region_0.6 (arg_tuple.7: (f32[32,96], s32[], s32[], s32[])) -> (f32[32,96], s32[], s32[], s32[]) { - %arg_tuple.7 = (f32[32,96]{1,0}, s32[], s32[], s32[]) parameter(0) - %get-tuple-element.8 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=0 - %add.13 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.8, f32[32,96]{1,0} %get-tuple-element.8) - %get-tuple-element.9 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=1 - %get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=3 - %add.12 = s32[] add(s32[] %get-tuple-element.9, s32[] %get-tuple-element.11) - %get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.7), index=2 - ROOT %tuple.14 = (f32[32,96]{1,0}, s32[], s32[], s32[]) tuple(f32[32,96]{1,0} %add.13, s32[] %add.12, s32[] %get-tuple-element.10, s32[] %get-tuple-element.11) + HloModule main, entry_computation_layout={(f32[32,96]{1,0}, f32[32,96]{1,0})->f32[32,96]{1,0}} + + %region_0.7 (arg_tuple.8: (f32[32,96], s32[], s32[], s32[], f32[32,96])) -> (f32[32,96], s32[], s32[], s32[], f32[32,96]) { + %arg_tuple.8 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0) + %get-tuple-element.9 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=0 + %get-tuple-element.13 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=4 + %add.15 = f32[32,96]{1,0} add(f32[32,96]{1,0} %get-tuple-element.9, f32[32,96]{1,0} %get-tuple-element.13), metadata={source_file="-" source_line=25} + %get-tuple-element.10 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=1 + %get-tuple-element.12 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=3 + %add.14 = s32[] add(s32[] %get-tuple-element.10, s32[] %get-tuple-element.12), metadata={source_file="-" source_line=24} + %get-tuple-element.11 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.8), index=2 + ROOT %tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %add.15, s32[] %add.14, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[32,96]{1,0} %get-tuple-element.13) } - %region_1.15 (arg_tuple.16: (f32[32,96], s32[], s32[], s32[])) -> pred[] { - %arg_tuple.16 = (f32[32,96]{1,0}, s32[], s32[], s32[]) parameter(0) - %get-tuple-element.17 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=0 - %get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=3 - %get-tuple-element.18 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=1 - %get-tuple-element.19 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %arg_tuple.16), index=2 - ROOT %compare.21 = pred[] compare(s32[] %get-tuple-element.18, s32[] %get-tuple-element.19), direction=LT + %region_1.17 (arg_tuple.18: (f32[32,96], s32[], s32[], s32[], f32[32,96])) -> pred[] { + %arg_tuple.18 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) parameter(0) + %get-tuple-element.19 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=0 + %get-tuple-element.22 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=3 + %get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=4 + %get-tuple-element.20 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=1 + %get-tuple-element.21 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %arg_tuple.18), index=2 + ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.20, s32[] %get-tuple-element.21), direction=LT, metadata={source_file="-" source_line=21} } - ENTRY %main.27 (Arg_0.1: f32[32,96]) -> f32[32,96] { + ENTRY %main.30 (Arg_0.1: f32[32,96], Arg_1.2: f32[32,96]) -> f32[32,96] { %Arg_0.1 = f32[32,96]{1,0} parameter(0), sharding={devices=[2,2]<=[4]} - %constant.2 = s32[] constant(0) - %constant.4 = s32[] constant(32) - %constant.3 = s32[] constant(1) - %tuple.5 = (f32[32,96]{1,0}, s32[], s32[], s32[]) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.2, s32[] %constant.4, s32[] %constant.3) - %while.22 = (f32[32,96]{1,0}, s32[], s32[], s32[]) while((f32[32,96]{1,0}, s32[], s32[], s32[]) %tuple.5), condition=%region_1.15, body=%region_0.6 - ROOT %get-tuple-element.23 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[]) %while.22), index=0 + %constant.3 = s32[] constant(0) + %constant.5 = s32[] constant(32) + %constant.4 = s32[] constant(1) + %Arg_1.2 = f32[32,96]{1,0} parameter(1), sharding={devices=[2,1,2]<=[4] last_tile_dim_replicate} + %tuple.6 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %Arg_0.1, s32[] %constant.3, s32[] %constant.5, s32[] %constant.4, f32[32,96]{1,0} %Arg_1.2), metadata={source_file="-" source_line=19} + %while.25 = (f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) while((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %tuple.6), condition=%region_1.17, body=%region_0.7, metadata={source_file="-" source_line=19} + %get-tuple-element.27 = s32[] get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=1, metadata={source_file="-" source_line=19} + %get-tuple-element.26 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}, s32[], s32[], s32[], f32[32,96]{1,0}) %while.25), index=0, metadata={source_file="-" source_line=19} + %tuple.28 = (f32[32,96]{1,0}) tuple(f32[32,96]{1,0} %get-tuple-element.26) + ROOT %get-tuple-element.29 = f32[32,96]{1,0} get-tuple-element((f32[32,96]{1,0}) %tuple.28), index=0 })"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hloString)); @@ -575,10 +581,14 @@ TEST_F(ShardyXLATest, WhileWithFreeVariables) { HloInstruction* whileInst = FindInstruction(module.get(), xla::HloOpcode::kWhile); EXPECT_NE(whileInst, nullptr); - EXPECT_THAT( - whileInst, - op::Sharding( - "{{devices=[2,2]<=[4]}, {replicated}, {replicated}, {replicated}}")); + // Verify that the sharding of parameter(1) hasn't changed. + EXPECT_THAT(module->entry_computation()->parameter_instruction(1), + op::Sharding("{devices=[2,1,2]<=[4] last_tile_dim_replicate}")); + // Verify the sharding of the while, and specifically that the sharding of the + // result that corresponds to parameter(1) is further sharded. + EXPECT_THAT(whileInst, + op::Sharding("{{devices=[2,2]<=[4]}, {replicated}, {replicated}, " + "{devices=[2,2]<=[4]}, {replicated}}")); } TEST_F(ShardyXLATest, ShardMap) { diff --git a/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir b/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir index a04734d8a1f667..f191acf1aaf687 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/import_shardings.mlir @@ -126,7 +126,7 @@ func.func @unknown_sharding(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4 // ----- // CHECK-LABEL: sdy.mesh @mesh = <> -// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = // CHECK-LABEL: func @one_maximal_mesh( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>} @@ -138,8 +138,8 @@ func.func @one_maximal_mesh(%arg0: tensor<8x8xf32> {mhlo.sharding = "{maximal de // ----- -// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = -// CHECK-LABEL: sdy.mesh @maximal_mesh_4 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_4 = // CHECK-LABEL: func @two_maximal_shardings_should_be_sorted( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_4, [{}, {}]>}, @@ -151,7 +151,7 @@ func.func @two_maximal_shardings_should_be_sorted(%arg0: tensor<8x8xf32> {mhlo.s } // ----- -// CHECK-COUNT-1: sdy.mesh @maximal_mesh_0 = +// CHECK-COUNT-1: sdy.mesh @maximal_mesh_0 = // CHECK-LABEL: func @duplicate_maximal_sharding_should_be_deduped( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@maximal_mesh_0, [{}, {}]>}, @@ -165,7 +165,7 @@ func.func @duplicate_maximal_sharding_should_be_deduped(%arg0: tensor<8x8xf32> { // ----- // CHECK-LABEL: sdy.mesh @mesh = <"axis_0"=8, "axis_1"=4> -// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = // CHECK-LABEL: func @two_meshes( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_1"}, {}]>}, @@ -180,7 +180,7 @@ func.func @two_meshes(%arg0: tensor<8x8xf32> {mhlo.sharding = "{devices=[4,1,8]< // ----- // CHECK-LABEL: sdy.mesh @mesh = <"axis_0"=8, "axis_1"=4> -// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = +// CHECK-LABEL: sdy.mesh @maximal_mesh_0 = // CHECK-LABEL: func @maximal_sharding_on_op( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"axis_1"}, {}]>}, diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir index c2707bae962cd3..f3e17fd2defac1 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_export_pipeline.mlir @@ -4,8 +4,8 @@ sdy.mesh @mesh_0 = <"axis_0"=2, "axis_1"=4, "axis_2"=4> sdy.mesh @mesh_1 = <"axis_0"=16> sdy.mesh @mesh_2 = <"x"=8, "y"=4> sdy.mesh @mesh_3 = <"a"=2, "b"=2, "c"=2, "d"=2> -sdy.mesh @maximal_mesh_0 = -sdy.mesh @maximal_mesh_1 = +sdy.mesh @maximal_mesh_0 = +sdy.mesh @maximal_mesh_1 = // CHECK-NOT: sdy.mesh diff --git a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir index fdc7efb82ff446..b022afcb921d43 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/mhlo_import_pipeline.mlir @@ -52,38 +52,44 @@ func.func @shmap_body(%arg0: tensor<1x8xf32>, %arg1: tensor<1x8xf32>) -> (tensor // ----- +// CHECK-LABEL: sdy.mesh @mesh = <"axis_0"=2, "axis_1"=2> + // CHECK-LABEL: func @while_with_free_variables -func.func @while_with_free_variables(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { +func.func @while_with_free_variables( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {mhlo.sharding = "{devices=[2,1,2]<=[4] last_tile_dim_replicate}"}) + -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> - // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> - // CHECK-NEXT: %[[WHILE:.*]]:4 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]], %iterArg_2 = %[[C1]]) + // CHECK-NEXT: %[[C32:.*]] = sdy.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, []>]>} dense<32> + // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1 + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] // CHECK-NEXT: mhlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %iterArg_2 - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg - // CHECK-NEXT: %[[IDENTITY_0:.*]] = sdy.identity %iterArg_1 - // CHECK-NEXT: %[[IDENTITY_1:.*]] = sdy.identity %iterArg_2 - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY_0]], %[[IDENTITY_1]] + // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] + // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<1> : tensor - %2 = mhlo.constant dense<32> : tensor + %2 = mhlo.constant {mhlo.sharding = "{replicated}"} dense<32> : tensor %3:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor cond { %4 = mhlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor mhlo.return %4 : tensor } do { %4 = mhlo.add %iterArg_0, %1 : tensor - %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32> + %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> mhlo.return %5, %4 : tensor<32x96xf32>, tensor } return %3#0 : tensor<32x96xf32> } +// ----- + // CHECK-LABEL: func @while_with_sinked_constants func.func @while_with_sinked_constants(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> diff --git a/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir new file mode 100644 index 00000000000000..b87048e4979a62 --- /dev/null +++ b/third_party/xla/xla/service/spmd/shardy/test/open_while_free_vars_sharding.mlir @@ -0,0 +1,93 @@ +// RUN: sdy_opt %s -xla-sdy-open-while-free-vars-sharding 2>&1 | FileCheck %s + +sdy.mesh @mesh1 = <"a"=2> +sdy.mesh @mesh2 = <"b"=2> + +// CHECK-LABEL: func @while_with_free_variables +func.func @while_with_free_variables( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}, + %arg2: tensor<32x96xf32>) + -> (tensor<32x96xf32>, tensor<32x96xf32>) { + // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0> + // CHECK-NEXT: %[[C1:.*]] = mhlo.constant dense<1> + // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> + // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} + // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> + // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %[[ADD_0]] <@mesh2, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: cond { + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: } do { + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_2:.*]] = mhlo.add %iterArg, %[[SC_0]] + // CHECK-NEXT: %[[ADD_3:.*]] = mhlo.add %[[ADD_2]], %arg2 + // CHECK-NEXT: %[[ADD_4:.*]] = mhlo.add %[[ADD_3]], %[[SC_1]] + // CHECK-NEXT: mhlo.return %[[ADD_4]], %[[ADD_1]] + // CHECK-NEXT: } + // CHECK-NEXT: return %[[ADD_0]], %[[WHILE]]#0 + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + %2 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor + %3 = mhlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh2, [{}, {"b"}]>]>} : tensor<32x96xf32> + %4:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + cond { + %5 = mhlo.compare LT, %iterArg_0, %2 : (tensor, tensor) -> tensor + mhlo.return %5 : tensor + } do { + %5 = mhlo.add %iterArg_0, %1 : tensor + %6 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> + %7 = mhlo.add %6, %arg2 : tensor<32x96xf32> + %8 = mhlo.add %7, %3 : tensor<32x96xf32> + mhlo.return %8, %5 : tensor<32x96xf32>, tensor + } + return %3, %4#0 : tensor<32x96xf32>, tensor<32x96xf32> +} + +// CHECK-LABEL: func @free_var_used_in_multiple_while_ops +func.func @free_var_used_in_multiple_while_ops( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {sdy.sharding = #sdy.sharding<@mesh1, [{"a"}, {}]>}) + -> tensor<32x96xf32> { + // CHECK-NEXT: %[[C0:.*]] = mhlo.constant dense<0> + // CHECK-NEXT: %[[C32:.*]] = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> + // CHECK-NEXT: %[[SC_0:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE_0:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: cond { + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: } do { + // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg, %[[SC_0]] + // CHECK-NEXT: mhlo.return %[[ADD_0]], %iterArg_0 + // CHECK-NEXT: } + // CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %arg1 <@mesh1, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE_1:.*]]:2 = mhlo.while(%iterArg = %[[WHILE_0]]#0, %iterArg_0 = %[[C0]]) + // CHECK-NEXT: cond { + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] + // CHECK-NEXT: mhlo.return %[[COND]] + // CHECK-NEXT: } do { + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC_1]] + // CHECK-NEXT: mhlo.return %[[ADD_1]], %iterArg_0 + // CHECK-NEXT: } + // CHECK-NEXT: return %[[WHILE_1]]#0 + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh1, []>]>} dense<32> : tensor + %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + cond { + %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor + mhlo.return %4 : tensor + } do { + %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> + mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor + } + %3:2 = mhlo.while(%iterArg = %2#0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor + cond { + %4 = mhlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor + mhlo.return %4 : tensor + } do { + %4 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> + mhlo.return %4, %iterArg_0 : tensor<32x96xf32>, tensor + } + return %3#0 : tensor<32x96xf32> +} diff --git a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir index dcd81b29ef6767..66b227dffd0f93 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/round_trip_pipeline.mlir @@ -13,6 +13,7 @@ // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2> sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2> +// CHECK-LABEL: func @main func.func @main( // CHECK: %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>}) %arg0: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b"}p4]>} @@ -35,6 +36,7 @@ func.func @main( // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2> sdy.mesh @mesh = <"a"=2, "b"=2> +// CHECK-LABEL: func @main func.func @main( // CHECK: %arg0: tensor<8x16xf32>) %arg0: tensor<8x16xf32> @@ -55,6 +57,7 @@ func.func @main( // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2> sdy.mesh @mesh = <"a"=2, "b"=2> +// CHECK-LABEL: func @main func.func @main( // CHECK: %arg0: tensor<8x16xf32>) %arg0: tensor<8x16xf32> @@ -78,6 +81,7 @@ func.func @main( // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2> sdy.mesh @mesh = <"a"=2, "b"=2> +// CHECK-LABEL: func @main func.func @main( // CHECK: %arg0: tensor<8x16xf32>) %arg0: tensor<8x16xf32> @@ -97,7 +101,7 @@ func.func @main( // CHECK: sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2> sdy.mesh @mesh = <"a"=2, "b"=2, "c"=2> -// CHECK: @main( +// CHECK-LABEL: @main( // CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}p4]>}, // CHECK-SAME: %arg1: tensor<8x8xf32>, %arg2: tensor<8x8xf32> // CHECK-SAME: ) -> tensor<8x8xf32> { @@ -122,6 +126,7 @@ func.func @main( // CHECK: sdy.mesh @mesh = <"data"=2> sdy.mesh @mesh = <"data"=2> +// CHECK-LABEL: func @main func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK: sdy.sharding_constraint %arg0 <@mesh, [{"data", ?}, {?}]> : tensor<8x8xf32> %0 = sdy.sharding_constraint %arg0 <@mesh, [{"data", ?}, {?}]> : tensor<8x8xf32> @@ -144,6 +149,7 @@ func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { // CHECK: sdy.mesh @mesh_2 = <"x"=8, "y"=4> sdy.mesh @mesh_2 = <"x"=8, "y"=4> +// CHECK-LABEL: func @main func.func @main( // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { @@ -157,36 +163,41 @@ func.func @main( // ----- +// CHECK: sdy.mesh @mesh = <"x"=2> +sdy.mesh @mesh = <"x"=2> + // Test WhileOp with lifted free variables and sinked constants. -// CHECK: func @main -func.func @main(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { +// CHECK-LABEL: func @main +func.func @main( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}}) + -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> - // CHECK-NEXT: %[[WHILE:.*]]:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]]) + // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1 + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] // CHECK-NEXT: mhlo.return %[[COND]] // CHECK-NEXT: } do { // CHECK-DAG: %[[C1:.*]] = sdy.constant dense<1> // CHECK-DAG: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] - // CHECK-DAG: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg - // CHECK-DAG: %[[IDENTITY:.*]] = sdy.identity %iterArg_1 - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY]] + // CHECK-DAG: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] + // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = sdy.constant dense<0> : tensor %1 = sdy.constant dense<32> : tensor - %2:3 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0, %iterArg_1 = %1) : tensor<32x96xf32>, tensor, tensor + %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %0) : tensor<32x96xf32>, tensor cond { - %3 = mhlo.compare LT, %iterArg_0, %iterArg_1 : (tensor, tensor) -> tensor + %3 = mhlo.compare LT, %iterArg_0, %1 : (tensor, tensor) -> tensor mhlo.return %3 : tensor } do { %3 = sdy.constant dense<1> : tensor %4 = mhlo.add %iterArg_0, %3 : tensor - %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32> - %6 = sdy.identity %iterArg_1 : tensor - mhlo.return %5, %4, %6 : tensor<32x96xf32>, tensor, tensor + %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> + mhlo.return %5, %4 : tensor<32x96xf32>, tensor } return %2#0 : tensor<32x96xf32> } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir index 10c44c89262ba5..89b8722748f6cf 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_export_pipeline.mlir @@ -13,15 +13,15 @@ sdy.mesh @mesh_2 = <"x"=8, "y"=4> // CHECK-SAME: mesh_2 = \22#sdy.mesh<\\22x\\22=8, \\22y\\22=4>\22}"}} { // CHECK-LABEL: func @multiple_shardings( -// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{\22axis_2\22}, {\22axis_0\22, \22axis_1\22}]>"}}, -// CHECK-SAME: %arg1: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_0\22, \22axis_2\22}]>"}}, -// CHECK-SAME: %arg2: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_1\22}]>"}}) +// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{\22axis_2\22}, {\22axis_0\22, \22axis_1\22}]>"}, mhlo.sharding = +// CHECK-SAME: %arg1: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_0\22, \22axis_2\22}]>"}, mhlo.sharding = +// CHECK-SAME: %arg2: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_0, [{}, {\22axis_1\22}]>"}, mhlo.sharding = // CHECK-SAME: -> tensor<8x16xf32> { func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{"axis_2"}, {"axis_0", "axis_1"}]>}, %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_0", "axis_2"}]>}, %arg2: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_0, [{}, {"axis_1"}]>}) -> tensor<8x16xf32> { // CHECK-NEXT: mhlo.add -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_0, [{\22axis_1\22, \22axis_0\22}, {}]>]>"}} +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_0, [{\22axis_1\22, \22axis_0\22}, {}]>]>"}, mhlo.sharding = %0 = mhlo.add %arg0, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh_0, [{"axis_1","axis_0"}, {}]>]>} : tensor<8x8xf32> %1 = "mhlo.dot" (%0, %arg2) : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> @@ -31,7 +31,7 @@ func.func @multiple_shardings(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.shardi func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) { %0 = mhlo.constant dense<0.000000e+00> : tensor // CHECK: mhlo.reduce -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {\22y\22}]>, <@mesh_2, [{\22y\22}, {}]>]>"}} +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{}, {\22y\22}]>, <@mesh_2, [{\22y\22}, {}]>]>"}, mhlo.sharding = %1:2 = mhlo.reduce(%arg0 init: %0), (%arg1 init: %0) across dimensions = [1] {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{}, {"y"}]>, <@mesh_2, [{"y"}, {}]>]>} : (tensor<4x64x8xf32>, tensor<4x64x8xf32>, tensor, tensor) -> (tensor<4x8xf32>, tensor<4x8xf32>) @@ -44,20 +44,20 @@ func.func @multi_result_op(%arg0: tensor<4x64x8xf32>, %arg1: tensor<4x64x8xf32>) } // CHECK-LABEL: func @split_axes( -// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22y\22}, {\22x\22:(2)2}]>"}}, -// CHECK-SAME: %arg1: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22x\22:(1)2}, {\22x\22:(2)4}]>"}}) +// CHECK-SAME: %arg0: tensor<8x8xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22y\22}, {\22x\22:(2)2}]>"}, mhlo.sharding = +// CHECK-SAME: %arg1: tensor<8x16xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh_2, [{\22x\22:(1)2}, {\22x\22:(2)4}]>"}, mhlo.sharding = // CHECK-SAME: -> tensor<8x16xf32> { func.func @split_axes(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"y"}, {"x":(2)2}]>}, %arg1: tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x":(1)2}, {"x":(2)4}]>}) -> tensor<8x16xf32> { // CHECK-NEXT: "mhlo.dot" -// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22:(1)2, \22x\22:(4)2}, {}]>]>"}} +// CHECK-SAME: {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22:(1)2, \22x\22:(4)2}, {}]>]>"}, mhlo.sharding = %1 = "mhlo.dot" (%arg0, %arg1) {sdy.sharding = #sdy.sharding_per_value<[<@mesh_2, [{"x":(1)2, "x":(4)2}, {}]>]>} : (tensor<8x8xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %1 : tensor<8x16xf32> } // CHECK-LABEL: func @func_result_sharding_returning_func_arg( func.func @func_result_sharding_returning_func_arg( - // CHECK: %arg0: tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {mhlo.sharding = %arg0: tensor<8x16xf32> ) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}) { // CHECK: %[[CUSTOM_CALL:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%arg0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> @@ -67,10 +67,11 @@ func.func @func_result_sharding_returning_func_arg( // CHECK-LABEL: func @func_result_sharding_returning_op_value( func.func @func_result_sharding_returning_op_value( - // CHECK: %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>, tensor<8x16xf32>) { + // CHECK: %arg0: tensor<8x16xf32>) + // CHECK-SAME: -> (tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}, tensor<8x16xf32> {mhlo.sharding = "{devices=[1,4,8]<=[8,4]T(1,0) last_tile_dim_replicate}"}, tensor<8x16xf32> {mhlo.sharding = "{devices=[8,4]<=[32]}"}) { %arg0: tensor<8x16xf32>) -> (tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x", ?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{?}, {"y"}p4]>}, tensor<8x16xf32> {sdy.sharding = #sdy.sharding<@mesh_2, [{"x"}, {"y"}p1]>}) { // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %arg0, %arg0 : tensor<8x16xf32> - // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, \22y\22}, {}]>, <@mesh_2, [{\22y\22, \22x\22}, {}]>]>"}} : (tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor<8x16xf32>) + // CHECK-NEXT: %[[TEST_ONLY:.*]]:2 = mhlo.custom_call @sdy_testonly(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, \22y\22}, {}]>, <@mesh_2, [{\22y\22, \22x\22}, {}]>]>"}, mhlo.sharding = // CHECK-NEXT: %[[ADD_RESULT_SHARDING:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[ADD]]) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_0:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#0) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{?}, {\22y\22}p4]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> // CHECK-NEXT: %[[TEST_ONLY_RES_SHARDING_1:.*]] = mhlo.custom_call @local_xla.sdy.FuncResultSharding(%[[TEST_ONLY]]#1) {has_side_effect = true, mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22}, {\22y\22}p1]>]>"}} : (tensor<8x16xf32>) -> tensor<8x16xf32> @@ -83,16 +84,15 @@ func.func @func_result_sharding_returning_op_value( // CHECK-LABEL: func @sharding_constraint // CHECK-SAME: %arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.func @sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { - // CHECK: mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {?}]>]>"}} : (tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK: mhlo.custom_call @Sharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh_2, [{\22x\22, ?}, {?}]>]>"}, mhlo.sharding = %0 = sdy.sharding_constraint %arg0 <@mesh_2, [{"x", ?}, {?}]> : tensor<8x8xf32> return %0 : tensor<8x8xf32> } -// CHECK-LABEL: func @identity_and_constant -func.func @identity_and_constant() -> tensor { +// CHECK-LABEL: func @constant +func.func @constant() -> tensor { // CHECK-NEXT: %[[CONST:.*]] = mhlo.constant dense<0> // CHECK-NEXT: return %[[CONST]] %0 = sdy.constant dense<0> : tensor - %1 = sdy.identity %0 : tensor - return %1 : tensor + return %0 : tensor } diff --git a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir index 2354d83b424555..e782de02815699 100644 --- a/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir +++ b/third_party/xla/xla/service/spmd/shardy/test/sdy_round_trip_import_pipeline.mlir @@ -30,20 +30,22 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x } // CHECK-LABEL: func @while_with_free_variables - func.func @while_with_free_variables(%arg0: tensor<32x96xf32>) -> tensor<32x96xf32> { + func.func @while_with_free_variables( + %arg0: tensor<32x96xf32>, + %arg1: tensor<32x96xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}, {}]>"}}) + -> tensor<32x96xf32> { // CHECK-NEXT: %[[C0:.*]] = sdy.constant dense<0> // CHECK-NEXT: %[[C1:.*]] = sdy.constant dense<1> // CHECK-NEXT: %[[C32:.*]] = sdy.constant dense<32> - // CHECK-NEXT: %[[WHILE:.*]]:4 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]], %iterArg_1 = %[[C32]], %iterArg_2 = %[[C1]]) + // CHECK-NEXT: %[[SC:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{?}, {?}]> + // CHECK-NEXT: %[[WHILE:.*]]:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %[[C0]]) // CHECK-NEXT: cond { - // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %iterArg_1 + // CHECK-NEXT: %[[COND:.*]] = mhlo.compare LT, %iterArg_0, %[[C32]] // CHECK-NEXT: mhlo.return %[[COND]] // CHECK-NEXT: } do { - // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %iterArg_2 - // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %iterArg - // CHECK-NEXT: %[[IDENTITY_0:.*]] = sdy.identity %iterArg_1 - // CHECK-NEXT: %[[IDENTITY_1:.*]] = sdy.identity %iterArg_2 - // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]], %[[IDENTITY_0]], %[[IDENTITY_1]] + // CHECK-NEXT: %[[ADD_0:.*]] = mhlo.add %iterArg_0, %[[C1]] + // CHECK-NEXT: %[[ADD_1:.*]] = mhlo.add %iterArg, %[[SC]] + // CHECK-NEXT: mhlo.return %[[ADD_1]], %[[ADD_0]] // CHECK-NEXT: } // CHECK-NEXT: return %[[WHILE]]#0 %0 = mhlo.constant dense<0> : tensor @@ -55,7 +57,7 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x mhlo.return %4 : tensor } do { %4 = mhlo.add %iterArg_0, %1 : tensor - %5 = mhlo.add %iterArg, %iterArg : tensor<32x96xf32> + %5 = mhlo.add %iterArg, %arg1 : tensor<32x96xf32> mhlo.return %5, %4 : tensor<32x96xf32>, tensor } return %3#0 : tensor<32x96xf32> diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index c3fc8b1ab31c0a..c02493e053f8b3 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -571,8 +571,11 @@ PartitionedHlo PartitionedHlo::ReshardNoCache( "not able to go from sharding " << sharding().ToString(/*include_metadata=*/true) << " to " << target.ToString(/*include_metadata=*/true) - << " without doing a full rematerialization of the tensor. You " - "probably want to enrich the sharding annotations to prevent " + << " without doing a full rematerialization of the tensor for HLO " + "operation: " + << hlo_->ToString() + << ". You probably want to enrich the sharding annotations to " + "prevent " "this from happening."; } return Replicate().Reshard(target); @@ -3316,6 +3319,7 @@ absl::Status SpmdPartitioningVisitor::HandleSingleDevice( auto param = true_b.AddInstruction(HloInstruction::CreateParameter( /*parameter_number=*/0, operand_shape, "true_branch_param")); std::vector new_operands; + new_operands.reserve(operands.size()); for (int64_t i = 0; i < operands.size(); ++i) { new_operands.push_back(true_b.AddInstruction( HloInstruction::CreateGetTupleElement(*operand_shapes[i], param, i))); @@ -4040,20 +4044,21 @@ absl::Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) { const HloSharding& sharding = hlo->sharding(); // Shardings for the body parameter, body root, and cond parameter must be - // the same, and the condition root must be replicated so that all partitions - // follow the same control flow. + // the same. hlo->while_condition()->parameter_instruction(0)->set_sharding(sharding); hlo->while_body()->parameter_instruction(0)->set_sharding(sharding); - const HloSharding& cond_root_sharding = - hlo->while_condition()->root_instruction()->sharding(); - TF_RETURN_IF_ERROR(partitioner_ - ->PartitionComputation(hlo->while_condition(), - cond_root_sharding.IsManual() - ? cond_root_sharding - : HloSharding::Replicate(), - next_channel_id_, logger_, - call_graph_) - .status()); + + // The condition root must be replicated so that all partitions follow the + // same control flow. + HloInstruction* cond_root = hlo->while_condition()->root_instruction(); + const HloSharding cond_root_sharding = + hlo_sharding_util::ReplicateAllDataDims(cond_root->sharding()); + cond_root->set_sharding(cond_root_sharding); + TF_RETURN_IF_ERROR( + partitioner_ + ->PartitionComputation(hlo->while_condition(), cond_root_sharding, + next_channel_id_, logger_, call_graph_) + .status()); TF_RETURN_IF_ERROR(partitioner_ ->PartitionComputation(hlo->while_body(), sharding, next_channel_id_, logger_, @@ -4129,6 +4134,7 @@ absl::Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) { if (hlo->sharding().IsManual()) { auto clone_from_original = [&](const HloSharding& shared_sharding) { std::vector new_operands; + new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { new_operands.push_back( GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo()); @@ -4310,6 +4316,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { } auto clone_from_original = [&](const HloSharding& shared_sharding) { std::vector new_operands; + new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { new_operands.push_back( GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo()); @@ -4340,6 +4347,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) { TF_RET_CHECK(!hlo->sharding().IsTileMaximal()); // Replicate the operands and run partitioned Rng on all devices. std::vector new_operands; + new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { new_operands.push_back(GetPartitionedHlo(hlo->operand(i)) .Reshard(HloSharding::Replicate()) @@ -4659,6 +4667,7 @@ absl::Status SpmdPartitioningVisitor::HandleSelectAndScatter( absl::Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) { std::vector new_operands; + new_operands.reserve(hlo->operand_count()); for (int64_t i = 0; i < hlo->operand_count(); ++i) { new_operands.push_back( GetPartitionedHlo(hlo->operand(i)) diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc index 99a1d1f92dc951..0d1d4a7be5c37b 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_test.cc @@ -47,9 +47,9 @@ limitations under the License. #include "xla/service/spmd/spmd_prepare.h" #include "xla/shape.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" @@ -4474,6 +4474,36 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]"))); } +TEST_P(SpmdPartitioningTest, WhilePartialManual) { + absl::string_view hlo_string = R"( +HloModule module + +LoopCond { + x = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} + const = s32[] constant(5), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} + ROOT lt = pred[] compare(x, const), direction=LT, sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} +} + +Inc { + x = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} + const = s32[] constant(1), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} + ROOT add = s32[] add(x, const), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} +} + +ENTRY entry { + zero = s32[] parameter(0), sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} + ROOT while = s32[] while(zero), body=Inc, condition=LoopCond, sharding={devices=[2,2]<=[4] last_tile_dims={manual, replicated}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto zero = AllOf(op::Parameter(0), op::Shape("s32[]")); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::While(zero), op::Shape("s32[]"))); +} + TEST_P(SpmdPartitioningTest, TestWhileFrontendAttributes) { absl::string_view hlo_string = R"( HloModule module @@ -9114,6 +9144,78 @@ ENTRY %main.7 { EXPECT_THAT(root, tuple); } +TEST_P(SpmdPartitioningTest, PartiallyReplicateRHS) { + const char* const hlo_string = R"( +HloModule module +ENTRY main { + lhs = bf16[16384,2048] parameter(0), sharding={devices=[16,8]<=[128]} + rhs = bf16[16384,256] parameter(1), sharding={devices=[128,1]<=[128]} + ROOT dot = bf16[2048,256] dot(lhs, rhs), lhs_contracting_dims={0}, rhs_contracting_dims={0}, sharding={devices=[8,1,16]<=[16,8]T(1,0) last_tile_dim_replicate} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, PartitionComputation(hlo_string, /*num_devices=*/128)); + VLOG(1) << module->ToString(); + + const auto lhs = AllOf(op::Shape("bf16[1024,256]"), op::Parameter(0)); + const auto rhs = AllOf(op::Shape("bf16[1024,256]"), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), op::Parameter(1), _, _))); + auto dot = AllOf(op::Shape("bf16[256,256]"), op::Dot(lhs, rhs)); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(dot)); +} + +TEST_P(SpmdPartitioningTest, AllToAllAndPartialReplicateRHS) { + const char* const hlo_string = R"( +HloModule module +ENTRY main { + lhs = bf16[64,64] parameter(0), sharding={devices=[2,2,2]<=[8] last_tile_dim_replicate} + rhs = bf16[64,64,64] parameter(1), sharding={devices=[1,2,4]<=[2,2,2]T(2,1,0)} + ROOT dot = bf16[64,64,64] dot(lhs, rhs), lhs_contracting_dims={1}, rhs_contracting_dims={2}, sharding={devices=[2,2,1,2]<=[2,2,2]T(0,2,1) last_tile_dim_replicate} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + const auto lhs = AllOf(op::Shape("bf16[32,32]"), op::Parameter(0)); + const auto all_to_all_p1 = AllOf( + op::Shape("bf16[32,64,16]"), + op::Reshape(op::Transpose(op::AllToAll(op::Reshape(op::Parameter(1)))))); + const auto rhs = AllOf(op::Shape("bf16[32,64,32]"), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), all_to_all_p1, _, _, _))); + auto dot = AllOf(op::Shape("bf16[32,32,64]"), op::Dot(lhs, rhs)); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(dot)); +} + +TEST_P(SpmdPartitioningTest, ReplicateLHSofConv) { + const char* const hlo_string = R"( +HloModule module +ENTRY main { + lhs = bf16[128,8,8,1280] parameter(0), sharding={devices=[128,1,1,1]<=[128]} + rhs = bf16[3,3,1280,1280] parameter(1), sharding={devices=[1,1,1,8,16]<=[16,8]T(1,0) last_tile_dim_replicate} + ROOT conv = bf16[128,8,8,1280] convolution(lhs, rhs), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, sharding={devices=[1,1,1,8,16]<=[16,8]T(1,0) last_tile_dim_replicate} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, PartitionComputation(hlo_string, /*num_devices=*/128)); + VLOG(1) << module->ToString(); + + const auto lhs = AllOf(op::Shape("bf16[128,8,8,1280]"), + op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), op::Parameter(0), _, _, _, _))); + const auto rhs = AllOf(op::Shape("bf16[3,3,1280,160]"), op::Parameter(1)); + const auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Shape("bf16[128,8,8,160]"), op::Convolution(lhs, rhs))); +} + TEST_P(SpmdPartitioningTest, ElementwiseTest_SubgroupSharding_TileToReplicate) { absl::string_view hlo_string = R"( HloModule module @@ -10862,6 +10964,53 @@ ENTRY entry { _)); } +TEST_P(SpmdPartitioningTest, ScatterRepsOnLastTileDimDontDivideGroups) { + absl::string_view hlo_string = R"( +HloModule module + +region.1 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT res.1 = f32[] add(lhs, rhs) +} + +ENTRY entry { + %add.1 = f32[8,96,2048,16]{3,2,1,0} parameter(0) + %concatenate.1 = s32[8,96,2048,2,4]{4,3,2,1,0} parameter(1) + %broadcast.1 = f32[8,96,2048,2]{3,2,1,0} parameter(2) + + %add.1.shard = f32[8,96,2048,16]{3,2,1,0} copy(%add.1), sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate} + %concatenate.1.shard = s32[8,96,2048,2,4]{4,3,2,1,0} copy(%concatenate.1), sharding={devices=[8,8,1,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate} + %broadcast.1.shard = f32[8,96,2048,2]{3,2,1,0} copy(%broadcast.1), sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate} + + ROOT %scatter.44 = f32[8,96,2048,16]{3,2,1,0} scatter( + %add.1.shard, + %concatenate.1.shard, + %broadcast.1.shard), + update_window_dims={}, + inserted_window_dims={0,1,2,3}, + scatter_dims_to_operand_dims={0,1,2,3}, + index_vector_dim=4, + to_apply=region.1, + sharding={devices=[8,8,1,1,24]<=[8,8,24]T(1,0,2) last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN( + auto module, PartitionComputation(hlo_string, /*num_devices=*/1536)); + VLOG(1) << module->ToString(); + // Verify scatter is partitioned properly. + { + const auto partitioned_scatter = + module->entry_computation()->root_instruction(); + auto operand = AllOf(op::Shape("f32[1,12,2048,16]")); + auto indices = AllOf(op::Shape("s32[8,96,2048,2,4]")); + auto update = AllOf(op::Shape("f32[8,96,2048,2]")); + auto scatter = AllOf(op::Shape("f32[1,12,2048,16]"), + op::Scatter(operand, indices, update)); + EXPECT_THAT(partitioned_scatter, scatter); + } +} + TEST_P(SpmdPartitioningTest, ParallelDimFromOutsideConditionalPositive) { absl::string_view hlo_string = R"( HloModule module diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner_util.h b/third_party/xla/xla/service/spmd/spmd_partitioner_util.h index 65b5d0134b4e39..a982c3edf1e8db 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner_util.h +++ b/third_party/xla/xla/service/spmd/spmd_partitioner_util.h @@ -84,6 +84,7 @@ HloInstruction* CreateConstantBase(const Shape& shape, Literal value, T* b, PrimitiveType)) { if (shape.IsTuple()) { std::vector elements; + elements.reserve(ShapeUtil::TupleElementCount(shape)); for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { elements.push_back( CreateConstantBase(ShapeUtil::GetTupleElementShape(shape, i), diff --git a/third_party/xla/xla/service/stable_sort_expander.cc b/third_party/xla/xla/service/stable_sort_expander.cc index 910ab5da82a01e..ca87dce4df65a7 100644 --- a/third_party/xla/xla/service/stable_sort_expander.cc +++ b/third_party/xla/xla/service/stable_sort_expander.cc @@ -55,7 +55,6 @@ absl::StatusOr StableSortExpander::ExpandInstruction( HloComputation* computation = sort->parent(); HloInstruction* expanded_sort = nullptr; - absl::flat_hash_set used_indices; int64_t iota_index = IotaOperandIndexForStableSort(*sort); // If there is currently no iota operand which we could use for making the diff --git a/third_party/xla/xla/service/stable_sort_expander_test.cc b/third_party/xla/xla/service/stable_sort_expander_test.cc index f2b5c41eee4f17..83ba193ede5aef 100644 --- a/third_party/xla/xla/service/stable_sort_expander_test.cc +++ b/third_party/xla/xla/service/stable_sort_expander_test.cc @@ -22,7 +22,7 @@ limitations under the License. #include "xla/service/pattern_matcher_gmock.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/stream_pool_test.cc b/third_party/xla/xla/service/stream_pool_test.cc index fd0a05e5d2f237..2bea4119a4d9b7 100644 --- a/third_party/xla/xla/service/stream_pool_test.cc +++ b/third_party/xla/xla/service/stream_pool_test.cc @@ -26,22 +26,21 @@ namespace { class StreamPoolTest : public ::testing::Test { protected: - std::unique_ptr NewStreamExecutor() { + se::StreamExecutor* NewStreamExecutor() { se::Platform* platform = se::PlatformManager::PlatformWithName("Host").value(); - se::StreamExecutorConfig config(/*ordinal=*/0); - return platform->GetUncachedExecutor(config).value(); + return platform->ExecutorForDevice(/*ordinal=*/0).value(); } }; TEST_F(StreamPoolTest, EmptyPool) { - std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool(executor.get()); + se::StreamExecutor* executor = NewStreamExecutor(); + StreamPool pool(executor); } TEST_F(StreamPoolTest, OneStreamPool) { - std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool(executor.get()); + se::StreamExecutor* executor = NewStreamExecutor(); + StreamPool pool(executor); // Borrow and return a stream. StreamPool::Ptr stream1 = pool.BorrowStream(); @@ -61,8 +60,8 @@ TEST_F(StreamPoolTest, OneStreamPool) { } TEST_F(StreamPoolTest, TwoStreamPool) { - std::unique_ptr executor = NewStreamExecutor(); - StreamPool pool(executor.get()); + se::StreamExecutor* executor = NewStreamExecutor(); + StreamPool pool(executor); // Borrow two streams. StreamPool::Ptr stream1 = pool.BorrowStream(); diff --git a/third_party/xla/xla/service/topk_rewriter_test.cc b/third_party/xla/xla/service/topk_rewriter_test.cc index a1dbb7a5f59b09..c678bef94e373f 100644 --- a/third_party/xla/xla/service/topk_rewriter_test.cc +++ b/third_party/xla/xla/service/topk_rewriter_test.cc @@ -32,7 +32,7 @@ limitations under the License. #include "xla/service/tuple_simplifier.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/service/transfer_manager.cc b/third_party/xla/xla/service/transfer_manager.cc index b5edfad998ad5e..c601a884919d1b 100644 --- a/third_party/xla/xla/service/transfer_manager.cc +++ b/third_party/xla/xla/service/transfer_manager.cc @@ -270,6 +270,8 @@ absl::Status TransferManager::WriteRootTupleIndexTable( device_memory.size()); std::vector elements; + elements.reserve( + ShapeUtil::TupleElementCount(device_buffer.on_device_shape())); for (int64_t i = 0; i < ShapeUtil::TupleElementCount(device_buffer.on_device_shape()); ++i) { elements.push_back(device_buffer.buffer({i})); @@ -290,6 +292,7 @@ absl::Status TransferManager::WriteRootTupleIndexTable( device_memory.size()); std::vector elements; + elements.reserve(ShapeUtil::TupleElementCount(buffer_tree.shape())); for (int64_t i = 0; i < ShapeUtil::TupleElementCount(buffer_tree.shape()); ++i) { elements.push_back(buffer_tree.element({i}).AsDeviceMemoryBase()); diff --git a/third_party/xla/xla/service/triangular_solve_expander_test.cc b/third_party/xla/xla/service/triangular_solve_expander_test.cc index 777f1258eb1ce1..fa382b24d0d9db 100644 --- a/third_party/xla/xla/service/triangular_solve_expander_test.cc +++ b/third_party/xla/xla/service/triangular_solve_expander_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "xla/reference_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/tuple_simplifier.cc b/third_party/xla/xla/service/tuple_simplifier.cc index ae033b79ba917a..3557b076df0ef6 100644 --- a/third_party/xla/xla/service/tuple_simplifier.cc +++ b/third_party/xla/xla/service/tuple_simplifier.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/shape_util.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" namespace xla { @@ -116,6 +117,11 @@ absl::StatusOr TupleSimplifier::Run( } } } + + if (module->has_schedule()) { + TF_RETURN_IF_ERROR(module->schedule().Update()); + } + return changed; } diff --git a/third_party/xla/xla/service/tuple_simplifier_test.cc b/third_party/xla/xla/service/tuple_simplifier_test.cc index f83a88f50709a1..33305afd7e0f71 100644 --- a/third_party/xla/xla/service/tuple_simplifier_test.cc +++ b/third_party/xla/xla/service/tuple_simplifier_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/while_loop_all_reduce_code_motion_test.cc b/third_party/xla/xla/service/while_loop_all_reduce_code_motion_test.cc index 22928f05f1cac2..271b04cd4e2643 100644 --- a/third_party/xla/xla/service/while_loop_all_reduce_code_motion_test.cc +++ b/third_party/xla/xla/service/while_loop_all_reduce_code_motion_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/while_loop_concat_code_motion_test.cc b/third_party/xla/xla/service/while_loop_concat_code_motion_test.cc index a4baa5bbe4c1e6..83a43f54f7dd05 100644 --- a/third_party/xla/xla/service/while_loop_concat_code_motion_test.cc +++ b/third_party/xla/xla/service/while_loop_concat_code_motion_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/service/hlo_verifier.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/while_loop_fusible_sinking.cc b/third_party/xla/xla/service/while_loop_fusible_sinking.cc index d1fd7acd8ca110..07b49dbafe45d1 100644 --- a/third_party/xla/xla/service/while_loop_fusible_sinking.cc +++ b/third_party/xla/xla/service/while_loop_fusible_sinking.cc @@ -136,10 +136,6 @@ absl::StatusOr WhileLoopFusibleSinking::TrySinkingFusiblesIntoWhileLoop( } bool changed = false; - - absl::flat_hash_map> - conditional_gte_index_to_insts = - WhileUtil::GetGTEsMapForWhileConditional(*while_cond); std::vector invariant_body_gtes = WhileUtil::GetInvariantGTEsForWhileBody(*while_body); std::vector tuple_indices; diff --git a/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc b/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc index 57f2768b458c0e..7d311df3546e65 100644 --- a/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc +++ b/third_party/xla/xla/service/while_loop_invariant_code_motion_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" namespace xla { diff --git a/third_party/xla/xla/service/while_loop_simplifier.cc b/third_party/xla/xla/service/while_loop_simplifier.cc index 2ca8d1884a80bb..4cc642c3994e18 100644 --- a/third_party/xla/xla/service/while_loop_simplifier.cc +++ b/third_party/xla/xla/service/while_loop_simplifier.cc @@ -954,8 +954,8 @@ static absl::StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // inline the call. const auto& attrs = while_op->frontend_attributes().map(); bool skip_trip_count_one_simplification = - attrs.contains("skip-simplify-while-loops/trip-count-one") && - (attrs.at("skip-simplify-while-loops/trip-count-one") == "true"); + attrs.contains("skip-simplify-while-loops_trip-count-one") && + (attrs.at("skip-simplify-while-loops_trip-count-one") == "true"); if (trip_count && *trip_count == 1 && !skip_trip_count_one_simplification) { // Do not simplify the loop away when there is a side-effectful op, // otherwise the infeed op may not inherit the data dependency from diff --git a/third_party/xla/xla/service/while_loop_simplifier_test.cc b/third_party/xla/xla/service/while_loop_simplifier_test.cc index c82e29c06728eb..494271c2023ceb 100644 --- a/third_party/xla/xla/service/while_loop_simplifier_test.cc +++ b/third_party/xla/xla/service/while_loop_simplifier_test.cc @@ -33,8 +33,8 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/service/while_loop_unroller.cc b/third_party/xla/xla/service/while_loop_unroller.cc index 0e1c3288c468df..7b22244fa9023f 100644 --- a/third_party/xla/xla/service/while_loop_unroller.cc +++ b/third_party/xla/xla/service/while_loop_unroller.cc @@ -98,8 +98,11 @@ std::unique_ptr MakeTrivialLoopCondition( absl::Status HandleDynamicGteOrTuple(HloInstruction* instr) { if (instr->IsCustomCall("DynamicGte")) { HloEvaluator evaluator(/*max_loop_iterations=*/0); - TF_ASSIGN_OR_RETURN(Literal index_lit, - evaluator.Evaluate(instr->mutable_operand(1), true)); + TF_ASSIGN_OR_RETURN( + Literal index_lit, + evaluator.Evaluate(instr->mutable_operand(1), + /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true)); auto index = LiteralUtil::LiteralAsScalarInt64(std::move(index_lit)); // The index must have a compile-time integer value at this point. TF_RET_CHECK(index.has_value()); @@ -109,8 +112,11 @@ absl::Status HandleDynamicGteOrTuple(HloInstruction* instr) { } else if (instr->IsCustomCall("DynamicTuple")) { HloEvaluator evaluator(/*max_loop_iterations=*/0); std::vector tuple_operands; - TF_ASSIGN_OR_RETURN(Literal index_lit, - evaluator.Evaluate(instr->mutable_operand(2), true)); + TF_ASSIGN_OR_RETURN( + Literal index_lit, + evaluator.Evaluate(instr->mutable_operand(2), + /*precomputed_analyses=*/{}, + /*recursively_evaluate_nonconstant_operands=*/true)); auto index = LiteralUtil::LiteralAsScalarInt64(std::move(index_lit)); // The index must have a compile-time integer value at this point. TF_RET_CHECK(index.has_value()); diff --git a/third_party/xla/xla/service/xla_aot_compile_cpu_test.cc b/third_party/xla/xla/service/xla_aot_compile_cpu_test.cc index 16ae69b5bf0e18..85b9b73098466b 100644 --- a/third_party/xla/xla/service/xla_aot_compile_cpu_test.cc +++ b/third_party/xla/xla/service/xla_aot_compile_cpu_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/platform_util.h" #include "xla/service/shaped_buffer.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/xla_aot_compile_gpu_test.cc b/third_party/xla/xla/service/xla_aot_compile_gpu_test.cc index a3720b3b39f15f..88c36c9f4daf47 100644 --- a/third_party/xla/xla/service/xla_aot_compile_gpu_test.cc +++ b/third_party/xla/xla/service/xla_aot_compile_gpu_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/platform_util.h" #include "xla/service/shaped_buffer.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/xla_aot_compile_stablehlo_cpu_test.cc b/third_party/xla/xla/service/xla_aot_compile_stablehlo_cpu_test.cc index 7526cd401c71ce..a85de68e143d08 100644 --- a/third_party/xla/xla/service/xla_aot_compile_stablehlo_cpu_test.cc +++ b/third_party/xla/xla/service/xla_aot_compile_stablehlo_cpu_test.cc @@ -26,7 +26,7 @@ limitations under the License. #include "xla/service/platform_util.h" #include "xla/service/shaped_buffer.h" #include "xla/tests/literal_test_util.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt index 1901cf2eecebd2..0cae8b48151a6a 100644 --- a/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt +++ b/third_party/xla/xla/service/xla_aot_compile_test_autotune_results.prototxt @@ -15,7 +15,7 @@ version: 3 results { device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB" - hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"gemm_backend_config\":{\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"alpha_imag\":0,\"precision_config\":{\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"],\"algorithm\":\"ALG_UNSET\"},\"epilogue\":\"DEFAULT\",\"lhs_stride\":\"9\",\"rhs_stride\":\"9\",\"grad_x\":false,\"grad_y\":false},\"force_earliest_schedule\":false}" + hlo: "(f32[3,3]{1,0}, s8[72]{0}) custom-call(f32[3,3]{1,0}, f32[3,3]{1,0}), custom_call_target=\"__cublas$gemm\", backend_config={\"force_earliest_schedule\":false,\"gemm_backend_config\":{\"alpha_imag\":0,\"alpha_real\":1,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_batch_dimensions\":[],\"lhs_contracting_dimensions\":[\"1\"],\"rhs_batch_dimensions\":[],\"rhs_contracting_dimensions\":[\"0\"]},\"epilogue\":\"DEFAULT\",\"grad_x\":false,\"grad_y\":false,\"lhs_stride\":\"9\",\"precision_config\":{\"algorithm\":\"ALG_UNSET\",\"operand_precision\":[\"DEFAULT\",\"DEFAULT\"]},\"rhs_stride\":\"9\"},\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}" result { gemm { algorithm: 13 @@ -24,7 +24,7 @@ results { } results { device: "CUDA: 6.0, Cores: 56, GPU clock: 1.4805 GHz, Memory bandwidth: 732 GB/s, L2 cache: 4 MB" - hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,2,4,4]{3,2,1,0}, f32[1,2,3,2]{3,2,1,0}), window={size=3x2}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convForward\", backend_config={\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[],\"cudnn_conv_backend_config\":{\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"side_input_scale\":0,\"leakyrelu_alpha\":0},\"force_earliest_schedule\":false}" + hlo: "(f32[1,1,2,3]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,2,4,4]{3,2,1,0}, f32[1,2,3,2]{3,2,1,0}), window={size=3x2}, dim_labels=bf01_oi01->bf01, custom_call_target=\"__cudnn$convForward\", backend_config={\"cudnn_conv_backend_config\":{\"activation_mode\":\"kNone\",\"conv_result_scale\":1,\"leakyrelu_alpha\":0,\"side_input_scale\":0},\"force_earliest_schedule\":false,\"operation_queue_id\":\"0\",\"wait_on_operation_queues\":[]}" result { run_time { nanos: 8192 diff --git a/third_party/xla/xla/service/xla_compile_main.cc b/third_party/xla/xla/service/xla_compile_main.cc index 1e607a0b674c3b..0e217d53cb7343 100644 --- a/third_party/xla/xla/service/xla_compile_main.cc +++ b/third_party/xla/xla/service/xla_compile_main.cc @@ -71,7 +71,7 @@ int main(int argc, char* argv[]) { "an attached GPU will be used."), tsl::Flag("autotune_results", &options.gpu_options.autotune_results_path, "The path to AutotuneResults, optional when compiling for" - " GPU"), + " GPU. Only used if autotuning is enabled in XLA_FLAGS."), tsl::Flag("symbol_repo", &options.repo_options.symbol_repo, "Which SymbolRepository to look up --symbol_reference in. If " "the repository contains a GpuTargetConfig, " @@ -83,7 +83,8 @@ int main(int argc, char* argv[]) { "optimized_symbol_reference", &options.repo_options.optimized_symbol_id, "Optimized symbol ID to look up in a SymbolRepository. Overrides " - "--autotune_results_path."), + "--autotune_results_path. Any autotuning results that are present " + "will be used as long as autotuning is enabled in XLA_FLAGS."), tsl::Flag("use_attached_device", &options.gpu_options.use_attached_device, "Whether to use the attached GPU or not. Overrides the " diff --git a/third_party/xla/xla/shape_util.cc b/third_party/xla/xla/shape_util.cc index 0deca085f6c94c..9581d6d673f6f0 100644 --- a/third_party/xla/xla/shape_util.cc +++ b/third_party/xla/xla/shape_util.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/shape_util.h" #include -#include #include #include #include @@ -1982,8 +1981,9 @@ struct ParallelState { // Returns the indices of the first elements of all consecutive subarrays of the // given array. For example: // ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4} -static std::vector ConsecutiveSegments(absl::Span xs) { - std::vector is = {0}; +static absl::InlinedVector ConsecutiveSegments( + absl::Span xs) { + absl::InlinedVector is = {0}; for (size_t i = 1; i < xs.size(); ++i) { if (1 != xs[i] - xs[i - 1]) { is.push_back(i); @@ -2010,83 +2010,74 @@ static Shape MergeDimensions(absl::Span segs, dimensions); } -static std::vector MajorToMinorLayout(const Shape& s) { +static absl::InlinedVector MajorToMinorLayout(const Shape& s) { absl::Span minor_to_major = LayoutUtil::MinorToMajor(s); - return std::vector{minor_to_major.rbegin(), minor_to_major.rend()}; -} - -static std::optional GetNormalizedTransposeShapeHelper( - const Shape& input_shape, absl::Span output_to_input, - const Vector3& permutation) { - // 'permutation' should not be the identity permutation. - if (permutation[0] == 0 && permutation[1] == 1 && permutation[2] == 2) { - return std::nullopt; - } - std::vector segments = ConsecutiveSegments(output_to_input); - if (segments.size() > 3) { + return absl::InlinedVector{minor_to_major.rbegin(), + minor_to_major.rend()}; +} + +static std::optional> +GetNormalizedTransposeShapeHelper( + const Shape& output_shape, absl::Span output_to_input, + absl::InlinedVector& permutation) { + absl::InlinedVector segments = + ConsecutiveSegments(output_to_input); + // This means that after normalization there is actually no transpose. + if (segments.size() == 1) { return std::nullopt; } - - Shape normalized_input_shape = - ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( - input_shape); - Shape normalized_shape = MergeDimensions(segments, normalized_input_shape); - std::vector normalized_dims{normalized_shape.dimensions().begin(), - normalized_shape.dimensions().end()}; + Shape normalized_shape = MergeDimensions(segments, output_shape); if (segments.size() == 2) { - // If we have two segments, we know that at least one transpose is - // happening, otherwise we would have only 1 segment. - int64_t untransposed = 0; - while (untransposed < permutation.size() && - permutation[untransposed] != untransposed) { - ++untransposed; - } - // The desired permutation may not contain any untransposed dimension. With - // just 2 segments, we cannot uniquely match that. - if (untransposed == permutation.size()) { - return std::nullopt; - } - // Insert a 1-dimension at the position of the untransposed dimension. - normalized_dims.insert(normalized_dims.begin() + untransposed, 1); - } else if (segments.size() == 3) { - // Derive the order from the segments. - Vector3 segment_order{output_to_input[segments[0]], - output_to_input[segments[1]], - output_to_input[segments[2]]}; - // We expect the same relative order. - for (int64_t i = 1; i < 3; ++i) { - if ((segment_order[i] > segment_order[i - 1]) != - (permutation[i] > permutation[i - 1])) { - return std::nullopt; - } + // If we have two segments, we know that exactly two dimensions are swapped. + // Insert a 1-dimension at the front and detect a 021 transpose. + // TODO(b/328656780): Don't insert the extra 1-dimension once the emitter + // supports any number of dimensions >= 2. + permutation = {0, 2, 1}; + return absl::InlinedVector{1, normalized_shape.dimensions(0), + normalized_shape.dimensions(1)}; + } + // We have at least 3 segments. Derive the permutation from the segments. + std::vector segment_to_normalized_dim(output_shape.rank(), -1); + for (size_t segment : segments) { + segment_to_normalized_dim[output_to_input[segment]] = 0; + } + int64_t normalized_dim = 0; + for (int64_t i = 0; i < segment_to_normalized_dim.size(); ++i) { + if (segment_to_normalized_dim[i] >= 0) { + segment_to_normalized_dim[i] = normalized_dim++; } } - if (normalized_dims.size() == 3) { - return Vector3{normalized_dims[permutation[0]], - normalized_dims[permutation[1]], - normalized_dims[permutation[2]]}; + permutation.reserve(segments.size()); + for (int64_t i = 0; i < segments.size(); ++i) { + permutation.push_back( + segment_to_normalized_dim[output_to_input[segments[i]]]); } - return std::nullopt; + absl::InlinedVector normalized_dims( + normalized_shape.dimensions().begin(), + normalized_shape.dimensions().end()); + return normalized_dims; } -/* static */ std::optional +/* static */ std::optional> ShapeUtil::GetNormalizedLogicalTransposeShape( - const Shape& input_shape, const Shape& output_shape, - absl::Span dimensions, const Vector3& permutation) { - if (!LayoutUtil::IsMonotonicWithDim0Major(input_shape.layout()) || - !LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) { + const Shape& output_shape, absl::Span dimensions, + absl::InlinedVector& permutation) { + permutation.clear(); + if (!LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout())) { // Only works on default layouts. return std::nullopt; } // Drop degenerate dimensions. - std::vector delta(input_shape.rank() + 1, 0); - for (int i = 0; i < input_shape.rank(); ++i) { + absl::InlinedVector delta(output_shape.rank() + 1, 0); + auto input_dimensions = ComposePermutations(output_shape.dimensions(), + InversePermutation(dimensions)); + for (int i = 0; i < output_shape.rank(); ++i) { delta[i + 1] = delta[i]; - if (input_shape.dimensions(i) == static_cast(1)) { + if (input_dimensions[i] == static_cast(1)) { ++delta[i + 1]; } } - std::vector new_dimensions; + absl::InlinedVector new_dimensions; for (int i = 0; i < dimensions.size(); i++) { if (output_shape.dimensions(i) != 1) { new_dimensions.push_back(dimensions[i] - delta[dimensions[i]]); @@ -2094,24 +2085,29 @@ ShapeUtil::GetNormalizedLogicalTransposeShape( } return GetNormalizedTransposeShapeHelper( - DropDegenerateDimensions(input_shape), InversePermutation(new_dimensions), - permutation); + DropDegenerateDimensions(output_shape), new_dimensions, permutation); } -/* static */ std::optional ShapeUtil::GetNormalizedTransposeShape( +/* static */ std::optional> +ShapeUtil::GetNormalizedTransposeShape( const Shape& input_shape, const Shape& output_shape, - const Vector3& permutation) { + absl::InlinedVector& permutation) { + permutation.clear(); if (!ShapeUtil::CompatibleIgnoringElementType(input_shape, output_shape)) { return std::nullopt; } - std::vector major_to_minor_input = MajorToMinorLayout(input_shape); - std::vector major_to_minor_output = MajorToMinorLayout(output_shape); + absl::InlinedVector major_to_minor_input = + MajorToMinorLayout(input_shape); + absl::InlinedVector major_to_minor_output = + MajorToMinorLayout(output_shape); std::vector output_to_input = ComposePermutations( - InversePermutation(major_to_minor_output), major_to_minor_input); + InversePermutation(major_to_minor_input), major_to_minor_output); - return GetNormalizedTransposeShapeHelper(input_shape, output_to_input, - permutation); + return GetNormalizedTransposeShapeHelper( + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + output_shape), + output_to_input, permutation); } Shape ShapeUtil::DeviceShapeToHostShape(Shape s) { diff --git a/third_party/xla/xla/shape_util.h b/third_party/xla/xla/shape_util.h index a773adeaf08d6b..aeb043ebeb4d13 100644 --- a/third_party/xla/xla/shape_util.h +++ b/third_party/xla/xla/shape_util.h @@ -44,7 +44,6 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/printer.h" #include "xla/shape.h" -#include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" #include "tsl/platform/logging.h" // IWYU pragma: keep @@ -1012,10 +1011,9 @@ class ShapeUtil { const Shape& shape, const ForEachParallelVisitorFunction& visitor_function); - // In this case, we care about transposes that swap two dimensions of a - // a shape that can be viewed as three logical components 0-1-2 in the order - // of major to minor. - // As an example, let's consider a 0-2-1 transpose: + // In this case, we care about transposes that permute dimensions of a shape + // that can be viewed as several logical components in the order of major to + // minor. As an example, let's consider a 0-2-1 transpose: // // If a shape can be viewed as three logical components 0-1-2 in the order of // major to minor, a 0-2-1-transpose changes the order of such logical @@ -1029,15 +1027,18 @@ class ShapeUtil { // should be set to {0, 2, 1}. // If `b` is a 0-2-1 transpose of `a` in 0-1-2, return the dimensions for the // normalized shape of `b` or the 0-2-1 shape. In general, the - // permutation[0]-permutation[1]-permutation[2] shape is returned. - static std::optional GetNormalizedTransposeShape( - const Shape& input_shape, const Shape& output_shape, - const Vector3& permutation); + // permutation[0]-permutation[1]-...-permutation[permutation.size()-1] shape + // is returned. + static std::optional> + GetNormalizedTransposeShape(const Shape& input_shape, + const Shape& output_shape, + absl::InlinedVector& permutation); // Entry point for physical + logical transposition. - static std::optional GetNormalizedLogicalTransposeShape( - const Shape& input_shape, const Shape& output_shape, - absl::Span dimensions, const Vector3& permutation); + static std::optional> + GetNormalizedLogicalTransposeShape( + const Shape& output_shape, absl::Span dimensions, + absl::InlinedVector& permutation); // Strips device-specific information, namely tiling and memory-space // information, from a shape. diff --git a/third_party/xla/xla/shape_util_test.cc b/third_party/xla/xla/shape_util_test.cc index e7c1beb972958d..c35464af6d55c5 100644 --- a/third_party/xla/xla/shape_util_test.cc +++ b/third_party/xla/xla/shape_util_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "absl/container/inlined_vector.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -1391,167 +1392,173 @@ TEST(ShapeUtilTest, DecomposeBitcastToTrt) { EXPECT_FALSE(decomposition_trt.IsTranspose2Identity()); } -TEST(Transpose021Test, NoTranspose) { +TEST(NormalizedTransposeShapeTest, NoTranspose) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {64, 128}, {0, 1}); + absl::InlinedVector permutation; EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - shape, transposed, Vector3{0, 2, 1})); + shape, transposed, permutation)); } -TEST(Transpose021Test, NoTranspose2) { +TEST(NormalizedTransposeShapeTest, NoTranspose2) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64, 32}, {2, 1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 64, 128}, {0, 1, 2}); + absl::InlinedVector permutation; EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - shape, transposed, Vector3{0, 1, 2})); + shape, transposed, permutation)); } -TEST(Transpose021Test, WrongTranspose) { - Shape input_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0}); - Shape output_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {0, 1, 2}); - EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - input_shape, output_shape, Vector3{0, 2, 1})); -} - -TEST(Transpose021Test, WrongTranspose2) { - Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {1, 0}); - Shape output_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1}); - EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - input_shape, output_shape, Vector3{0, 1, 2})); -} - -TEST(Transpose021Test, WrongTranspose3) { - Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {1, 0}); - Shape output_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1}); - EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - input_shape, output_shape, Vector3{1, 2, 0})); -} - -TEST(Transpose021Test, Simple) { +TEST(NormalizedTransposeShapeTest, Simple) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {128, 64}, {0, 1}); - EXPECT_EQ(std::make_optional(Vector3{1, 64, 128}), - ShapeUtil::GetNormalizedTransposeShape(shape, transposed, - Vector3{0, 2, 1})); + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{1, 64, 128}), + ShapeUtil::GetNormalizedTransposeShape(shape, transposed, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, Simple2) { +TEST(NormalizedTransposeShapeTest, Simple2) { Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0}); Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {1, 2, 0}); - EXPECT_EQ(std::make_optional(Vector3{8, 16, 32768}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{8, 16, 32768}), ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, - Vector3{0, 2, 1})); + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, Simple3) { +TEST(NormalizedTransposeShapeTest, Simple3) { Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {2, 1, 0}); Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 32768, 16}, {0, 1, 2}); - EXPECT_EQ(std::make_optional(Vector3{16, 32768, 8}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{16, 32768, 8}), ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, - Vector3{2, 1, 0})); + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{2, 1, 0})); } -TEST(Transpose021Test, Simple4) { - Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {1, 0}); +TEST(NormalizedTransposeShapeTest, NormalizedShapeRank4) { + Shape input_shape = + ShapeUtil::MakeShapeWithDenseLayout(F32, {16, 4, 8, 32768}, {2, 1, 3, 0}); Shape output_shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 16}, {0, 1}); - EXPECT_EQ(std::make_optional(Vector3{16, 1, 8}), - ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, - Vector3{2, 1, 0})); + ShapeUtil::MakeShapeWithDenseLayout(F32, {16, 4, 8, 32768}, {1, 0, 2, 3}); + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{32768, 8, 16, 4}), + ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{1, 3, 0, 2})); } -TEST(Transpose021Test, LargeView) { +TEST(NormalizedTransposeShapeTest, LargeView) { Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout( F32, {8, 32, 32, 32, 16}, {4, 3, 2, 1, 0}); Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout( F32, {8, 32, 32, 32, 16}, {3, 2, 1, 4, 0}); - EXPECT_EQ(std::make_optional(Vector3{8, 16, 32768}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{8, 16, 32768}), ShapeUtil::GetNormalizedTransposeShape(input_shape, output_shape, - Vector3{0, 2, 1})); + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, LargeSizeOverflowTest) { +TEST(NormalizedTransposeShapeTest, LargeSizeOverflowTest) { Shape input_shape = ShapeUtil::MakeShapeWithDenseLayout(BF16, {4096, 4096, 128}, {2, 1, 0}); Shape output_shape = ShapeUtil::MakeShapeWithDenseLayout(BF16, {4096, 4096, 128}, {2, 1, 0}); + absl::InlinedVector permutation; EXPECT_EQ(std::nullopt, ShapeUtil::GetNormalizedTransposeShape( - input_shape, output_shape, Vector3{0, 2, 1})); + input_shape, output_shape, permutation)); } -TEST(Transpose021Test, Batched) { +TEST(NormalizedTransposeShapeTest, Batched) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {2, 1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {1, 0, 2}); - EXPECT_EQ(std::make_optional(Vector3{1, 64, 96}), - ShapeUtil::GetNormalizedTransposeShape(shape, transposed, - Vector3{0, 2, 1})); + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{1, 64, 96}), + ShapeUtil::GetNormalizedTransposeShape(shape, transposed, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, BatchedLogical) { - Shape shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 3, 64}, {2, 1, 0}); +TEST(NormalizedTransposeShapeTest, BatchedLogical) { Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {64, 32, 3}, {2, 1, 0}); std::vector dimensions = {2, 0, 1}; - EXPECT_EQ(std::make_optional(Vector3{1, 64, 96}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{1, 64, 96}), ShapeUtil::GetNormalizedLogicalTransposeShape( - shape, transposed, dimensions, Vector3{0, 2, 1})); + transposed, dimensions, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, LogicalWithDegenerateDims) { - Shape shape = ShapeUtil::MakeShapeWithDenseLayout( - F32, {1, 32, 1, 3, 1, 64, 1}, {6, 5, 4, 3, 2, 1, 0}); +TEST(NormalizedTransposeShapeTest, LogicalWithDegenerateDims) { Shape transposed = ShapeUtil::MakeShapeWithDenseLayout( F32, {1, 32, 1, 64, 1, 3, 1}, {6, 5, 4, 3, 2, 1, 0}); std::vector dimensions = {6, 1, 4, 5, 2, 3, 0}; - EXPECT_EQ(std::make_optional(Vector3{32, 64, 3}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{32, 64, 3}), ShapeUtil::GetNormalizedLogicalTransposeShape( - shape, transposed, dimensions, Vector3{0, 2, 1})); + transposed, dimensions, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, LogicalWithDegenerateLastDim) { - Shape shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {1, 64, 32}, {2, 1, 0}); +TEST(NormalizedTransposeShapeTest, LogicalWithDegenerateLastDim) { Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {32, 64, 1}, {2, 1, 0}); std::vector dimensions = {2, 1, 0}; - EXPECT_EQ(std::make_optional(Vector3{1, 32, 64}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{1, 32, 64}), ShapeUtil::GetNormalizedLogicalTransposeShape( - shape, transposed, dimensions, Vector3{0, 2, 1})); + transposed, dimensions, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose021Test, Large) { +TEST(NormalizedTransposeShapeTest, Large) { Shape shape = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 31, 31, 65}, {3, 2, 1, 0}); Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {8, 31, 31, 65}, {2, 1, 3, 0}); - EXPECT_EQ(std::make_optional(Vector3{8, 65, 961}), - ShapeUtil::GetNormalizedTransposeShape(shape, transposed, - Vector3{0, 2, 1})); + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{8, 65, 961}), + ShapeUtil::GetNormalizedTransposeShape(shape, transposed, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{0, 2, 1})); } -TEST(Transpose210Test, LogicalTranspose) { - Shape shape = - ShapeUtil::MakeShapeWithDenseLayout(F32, {10, 11, 12, 13}, {3, 2, 1, 0}); +TEST(NormalizedLogicialTransposeShapeTest, LogicalTranspose) { Shape transposed = ShapeUtil::MakeShapeWithDenseLayout(F32, {13, 12, 10, 11}, {3, 2, 1, 0}); std::vector dimensions = {3, 2, 0, 1}; - EXPECT_EQ(std::make_optional(Vector3{13, 12, 110}), + absl::InlinedVector permutation; + EXPECT_EQ(std::make_optional(absl::InlinedVector{13, 12, 110}), ShapeUtil::GetNormalizedLogicalTransposeShape( - shape, transposed, dimensions, Vector3{2, 1, 0})); + transposed, dimensions, permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{2, 1, 0})); +} + +TEST(NormalizedLogicalTransposeShapeTest, NormalizedShapeRank4) { + Shape transposed = + ShapeUtil::MakeShapeWithDenseLayout(F32, {16, 32768, 8, 4}, {3, 2, 1, 0}); + std::vector dimensions = {2, 0, 3, 1}; + absl::InlinedVector permutation; + EXPECT_EQ( + std::make_optional(absl::InlinedVector{16, 32768, 8, 4}), + ShapeUtil::GetNormalizedLogicalTransposeShape(transposed, dimensions, + permutation)); + EXPECT_EQ(permutation, (absl::InlinedVector{2, 0, 3, 1})); } TEST(AlgebraicSimplifierTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) { diff --git a/third_party/xla/xla/sort_json.cc b/third_party/xla/xla/sort_json.cc new file mode 100644 index 00000000000000..aaa1e197a3fa26 --- /dev/null +++ b/third_party/xla/xla/sort_json.cc @@ -0,0 +1,257 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/sort_json.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" + +namespace { + +void SkipWhitespace(absl::string_view json, size_t& index) { + while (index < json.size() && std::isspace(json[index])) { + ++index; + } +} + +absl::Status CheckNotEndOfString(absl::string_view json, int index, + absl::string_view expected) { + return index < json.size() + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::StrCat( + "Prematurely reached end of JSON while looking for ", + expected, ".")); +} + +absl::Status Consume(absl::string_view json, size_t& index, char c, + bool optional = false) { + SkipWhitespace(json, index); + TF_RETURN_IF_ERROR(CheckNotEndOfString(json, index, std::string(1, c))); + if (json[index] == c) { + ++index; + SkipWhitespace(json, index); + } else if (!optional) { + return absl::InvalidArgumentError( + absl::StrCat("Expected '", std::string(1, c), "', but found '", + std::string(1, json[index]), "'.")); + } + return absl::OkStatus(); +} + +struct JsonArray; +struct JsonObject; + +using JsonValue = std::variant, + std::unique_ptr>; + +struct JsonField { + absl::string_view name; + JsonValue value; +}; + +template +struct JsonSequence { + std::vector elements; +}; + +struct JsonArray : public JsonSequence {}; +struct JsonObject : public JsonSequence {}; + +// This parses either an array or an object. +template +absl::StatusOr> ParseSequence(absl::string_view outer_json, + size_t& index, + ElemFn elem_fn) { + TF_RETURN_IF_ERROR(Consume(outer_json, index, begin)); + TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, name)); + + auto seq = std::make_unique(); + while (outer_json[index] != end) { + TF_ASSIGN_OR_RETURN(auto elem, elem_fn(outer_json, index)); + seq->elements.emplace_back(std::move(elem)); + TF_RETURN_IF_ERROR(Consume(outer_json, index, ',', /*optional=*/true)); + TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, name)); + } + TF_RETURN_IF_ERROR(Consume(outer_json, index, end)); + return seq; +} + +absl::Status EnsureValidLiteralStart(char c) { + if (c != '"' && c != '+' && c != '-' && c != 'f' && c != 't' && c != 'n' && + (c < '0' || c > '9')) { + return absl::InvalidArgumentError(absl::StrCat( + "Invalid first character of literal: '", std::string(1, c), "'.")); + } + return absl::OkStatus(); +} + +bool HandleEscape(absl::string_view outer_json, size_t& index, + bool& is_escaped) { + if (is_escaped) { + is_escaped = false; + ++index; + return true; + } + + if (outer_json[index] == '\\') { + is_escaped = true; + ++index; + return true; + } + return false; +} + +bool LiteralIsFinished(absl::string_view outer_json, size_t& index, + bool is_string_literal) { + char c = outer_json[index]; + if (is_string_literal) { + index += (c == '"' ? 1 : 0); + return c == '"'; + } + + return std::isspace(c) || c == ',' || c == '{' || c == '}' || c == '[' || + c == ']' || c == ':'; +} + +absl::StatusOr ParseLiteral(absl::string_view outer_json, + size_t& index) { + SkipWhitespace(outer_json, index); + TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, "literal")); + + auto c = outer_json[index]; + TF_RETURN_IF_ERROR(EnsureValidLiteralStart(c)); + bool is_string_literal = c == '"'; + size_t start_index = index; + bool is_escaped = false; + ++index; + + while (index < outer_json.size()) { + if (HandleEscape(outer_json, index, is_escaped)) { + continue; + } + if (LiteralIsFinished(outer_json, index, is_string_literal)) { + break; + } + ++index; + } + return outer_json.substr(start_index, index - start_index); +} + +absl::StatusOr ParseField(absl::string_view outer_json, + size_t& index); + +absl::StatusOr ParseValue(absl::string_view outer_json, + size_t& index) { + JsonValue value; + SkipWhitespace(outer_json, index); + TF_RETURN_IF_ERROR(CheckNotEndOfString(outer_json, index, "value")); + auto c = outer_json[index]; + if (c == '{') { + constexpr static char kObject[] = "object"; + auto seq = ParseSequence(outer_json, index, + ParseField); + TF_ASSIGN_OR_RETURN(value, std::move(seq)); + } else if (c == '[') { + constexpr static char kArray[] = "array"; + auto seq = ParseSequence(outer_json, index, + ParseValue); + TF_ASSIGN_OR_RETURN(value, std::move(seq)); + } else { + TF_ASSIGN_OR_RETURN(value, ParseLiteral(outer_json, index)); + } + return value; +} + +absl::StatusOr ParseField(absl::string_view outer_json, + size_t& index) { + JsonField field; + TF_ASSIGN_OR_RETURN(field.name, ParseLiteral(outer_json, index)); + TF_RETURN_IF_ERROR(Consume(outer_json, index, ':')); + TF_ASSIGN_OR_RETURN(field.value, ParseValue(outer_json, index)); + return field; +} + +template +std::vector SerializedElements(const JsonSequence& seq) { + std::vector result; + for (const auto& field : seq.elements) { + result.push_back(""); + Serialize(field, result.back()); + } + return result; +} + +template +void Serialize(const JsonSequence& object, std::string& result) { + auto elems = SerializedElements(object); + if constexpr (std::is_same_v) { + std::sort(elems.begin(), elems.end()); + } + + result += begin_brace; + bool has_preceeding = false; + for (const auto& elem : elems) { + if (has_preceeding) { + result += ','; + } + result += elem; + has_preceeding = true; + } + result += end_brace; +} + +void Serialize(const JsonValue& value, std::string& result) { + if (auto* lit = std::get_if(&value)) { + absl::StrAppend(&result, *lit); + } else if (auto* object = std::get_if>(&value)) { + Serialize(**object, result); + } else if (auto* array = std::get_if>(&value)) { + Serialize(**array, result); + } +} + +void Serialize(const JsonField& field, std::string& result) { + absl::StrAppend(&result, field.name, ":"); + Serialize(field.value, result); +} + +} // namespace + +namespace xla { +absl::StatusOr SortJson(absl::string_view json) { + size_t index = 0; + TF_ASSIGN_OR_RETURN(auto value, ParseValue(json, index)); + SkipWhitespace(json, index); + if (index < json.size()) { + return absl::InvalidArgumentError("Found trailing characters in JSON."); + } + std::string result; + Serialize(value, result); + return result; +} +} // namespace xla diff --git a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h b/third_party/xla/xla/sort_json.h similarity index 51% rename from third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h rename to third_party/xla/xla/sort_json.h index 5dcb51f282a635..b4283f556500ce 100644 --- a/third_party/xla/xla/service/spmd/shardy/round_trip_common/identity_to_pass_through_while_args.h +++ b/third_party/xla/xla/sort_json.h @@ -13,24 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_ -#define XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_ +#ifndef XLA_SORT_JSON_H_ +#define XLA_SORT_JSON_H_ -#include +#include -#include "mlir/Pass/Pass.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" namespace xla { -namespace sdy { -// Creates a pass that adds an identity op between pass-through block arguments -// of a while op. -std::unique_ptr createAddIdentityToPassThroughWhileArgsPass(); +// Sorts the given JSON string or returns an error if the JSON could not be +// parsed. Note that this function expects the input JSON to be valid and not +// all forms of invalid JSON are correctly recognized. This function completely +// ignores whitespace and the resulting JSON does not have any whitespace. +// Comments are not supported in the input JSON. +absl::StatusOr SortJson(absl::string_view json); -// Registers the xla-sdy-add-identity-to-pass-through-while-args pass. -void registerAddIdentityToPassThroughWhileArgsPass(); - -} // namespace sdy } // namespace xla -#endif // XLA_SERVICE_SPMD_SHARDY_ROUND_TRIP_COMMON_IDENTITY_TO_PASS_THROUGH_WHILE_ARGS_H_ +#endif // XLA_SORT_JSON_H_ diff --git a/third_party/xla/xla/sort_json_test.cc b/third_party/xla/xla/sort_json_test.cc new file mode 100644 index 00000000000000..f4ff0c1d785bc1 --- /dev/null +++ b/third_party/xla/xla/sort_json_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/sort_json.h" + +#include +#include +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/test.h" + +namespace xla { +namespace { + +using ::tsl::testing::IsOkAndHolds; + +TEST(SortJsonTest, SortsJson) { + EXPECT_THAT(SortJson(R"({"a": 1, "c": 3,"b": 2, "b": 1,})"), + IsOkAndHolds(R"({"a":1,"b":1,"b":2,"c":3})")); + + EXPECT_THAT(SortJson(R"({"a": 1 , "c": 1,"b": 1 })"), + IsOkAndHolds(R"({"a":1,"b":1,"c":1})")); + + EXPECT_THAT(SortJson(R"({"a": 1,"c": 3,"b": 2,"b": [3,2,1],})"), + IsOkAndHolds(R"({"a":1,"b":2,"b":[3,2,1],"c":3})")); + + EXPECT_THAT(SortJson(R"({"aa": 1, "a": {"c": "c", "b": "b"}})"), + IsOkAndHolds(R"({"a":{"b":"b","c":"c"},"aa":1})")); + + EXPECT_THAT( + SortJson( + R"({"x": true, "x": false, "x": null, "x": 0, "x": -0.5,"x": "a"})"), + IsOkAndHolds(R"({"x":"a","x":-0.5,"x":0,"x":false,"x":null,"x":true})")); + + EXPECT_THAT(SortJson(R"({"a": "a}", "a": "a"})"), + IsOkAndHolds(R"({"a":"a","a":"a}"})")); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/stream_executor/BUILD b/third_party/xla/xla/stream_executor/BUILD index 22cfd063ca6065..761f4bb0dfec73 100644 --- a/third_party/xla/xla/stream_executor/BUILD +++ b/third_party/xla/xla/stream_executor/BUILD @@ -201,16 +201,67 @@ cc_library( ], ) +cc_library( + name = "stream_finder", + srcs = ["stream_finder.cc"], + hdrs = ["stream_finder.h"], + deps = [ + ":platform", + ":stream", + ":stream_executor_h", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "mock_platform", + testonly = True, + hdrs = ["mock_platform.h"], + deps = [ + ":device_description", + ":platform", + ":stream_executor_h", + "//xla:test", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "mock_stream", + testonly = True, + hdrs = ["mock_stream.h"], + deps = [ + ":device_description", + ":device_memory", + ":event", + ":event_based_timer", + ":kernel", + ":launch_dim", + ":platform", + ":stream", + "//xla:test", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_library( name = "mock_stream_executor", testonly = True, hdrs = ["mock_stream_executor.h"], deps = [ ":allocator_stats", + ":blas", ":command_buffer", ":device_description", ":device_memory", + ":dnn", ":event", + ":fft", ":kernel", ":kernel_spec", ":launch_dim", @@ -220,7 +271,6 @@ cc_library( ":stream", ":stream_executor_h", "//xla:test", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -380,6 +430,7 @@ cc_library( ":device_memory", ":numeric_options", "//xla/stream_executor/platform", + "//xla/tsl/lib/strings:proto_serialization", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", @@ -391,7 +442,6 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", # buildcleaner: keep - "@local_tsl//tsl/lib/strings:proto_serialization", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:status", @@ -449,12 +499,10 @@ cc_library( ":fft", ":kernel", ":kernel_spec", - ":launch_dim", ":memory_allocation", ":module_spec", ":platform", ":stream", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -574,15 +622,15 @@ cc_library( srcs = ["executor_cache.cc"], hdrs = ["executor_cache.h"], deps = [ - ":platform", ":stream_executor_h", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:statusor", ], alwayslink = True, ) @@ -619,25 +667,11 @@ cc_library( alwayslink = True, ) -cc_library( - name = "kernel_factory", - hdrs = ["kernel_factory.h"], - deps = [ - ":kernel", - ":kernel_spec", - ":stream_executor_h", - "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - cc_library( name = "typed_kernel_factory", hdrs = ["typed_kernel_factory.h"], deps = [ ":kernel", - ":kernel_factory", ":kernel_spec", ":stream_executor_h", "@com_google_absl//absl/status:statusor", @@ -686,13 +720,10 @@ cc_library( ":blas", ":device_description", ":fft", - ":kernel", - ":launch_dim", ":platform", ":stream", ":stream_executor_h", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -828,6 +859,36 @@ xla_cc_test( ], ) +xla_cc_test( + name = "executor_cache_test", + srcs = ["executor_cache_test.cc"], + deps = [ + ":executor_cache", + ":mock_stream_executor", + ":stream", + "@com_google_absl//absl/log", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + +xla_cc_test( + name = "stream_finder_test", + srcs = ["stream_finder_test.cc"], + deps = [ + ":mock_platform", + ":mock_stream", + ":mock_stream_executor", + ":stream_finder", + "//xla:test", + "@com_google_absl//absl/status", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + #===--------------------------------------------------------------------------------------------===# # Aliases for StreamExecutor platforms #===--------------------------------------------------------------------------------------------===# diff --git a/third_party/xla/xla/stream_executor/command_buffer.h b/third_party/xla/xla/stream_executor/command_buffer.h index 5cb39e857f7fbb..2b92b504f2059a 100644 --- a/third_party/xla/xla/stream_executor/command_buffer.h +++ b/third_party/xla/xla/stream_executor/command_buffer.h @@ -327,6 +327,11 @@ class CommandBuffer { return While(kDefaulExecutionScope, pred, cond_builder, body_builder); } + // Submits the command buffer for execution. + virtual absl::Status Submit(Stream* stream) { + return absl::UnimplementedError("Not implemented for this command buffer."); + } + //--------------------------------------------------------------------------// // Command buffer state management API //--------------------------------------------------------------------------// diff --git a/third_party/xla/xla/stream_executor/cuda/BUILD b/third_party/xla/xla/stream_executor/cuda/BUILD index 77fc2e0f9a50b3..b0b0bf5e608d24 100644 --- a/third_party/xla/xla/stream_executor/cuda/BUILD +++ b/third_party/xla/xla/stream_executor/cuda/BUILD @@ -10,6 +10,7 @@ load( load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", + "if_cuda_newer_than", ) load( "//xla:xla.bzl", @@ -28,7 +29,14 @@ load( "tf_additional_gpu_compilation_copts", ) load("//xla/tests:build_defs.bzl", "xla_test") -load("//xla/tsl:tsl.bzl", "if_google", "if_nccl", "internal_visibility", "tsl_copts") +load( + "//xla/tsl:tsl.bzl", + "if_google", + "if_hermetic_cuda_tools", + "if_nccl", + "internal_visibility", + "tsl_copts", +) package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -133,9 +141,21 @@ cuda_only_cc_library( # Buildozer can not remove dependencies inside select guards, so we have to use # an intermediate target. -cc_library(name = "ptxas_wrapper") +cc_library( + name = "ptxas_wrapper", + data = if_hermetic_cuda_tools( + ["@cuda_nvcc//:ptxas"], + [], + ), +) -cc_library(name = "nvlink_wrapper") +cc_library( + name = "nvlink_wrapper", + data = if_hermetic_cuda_tools( + ["@cuda_nvcc//:nvlink"], + [], + ), +) cuda_only_cc_library( name = "cuda_driver", @@ -225,12 +245,9 @@ xla_test( name = "cuda_driver_test", srcs = ["cuda_driver_test.cc"], backends = ["gpu"], - tags = [ - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - "gpu", - "no_rocm", - ], + tags = ["no_rocm"], deps = [ + ":cuda_diagnostics", ":cuda_driver", ":cuda_status", "//xla/stream_executor/gpu:gpu_driver_header", @@ -428,7 +445,6 @@ gpu_kernel_library( "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_semaphore", - "//xla/stream_executor/gpu:gpu_stream", "@com_google_absl//absl/status:statusor", ], ) @@ -590,6 +606,7 @@ cc_library( cc_library( name = "ptx_compiler", hdrs = ["ptx_compiler.h"], + tags = ["no_rocm"], deps = select({ ":libnvptxcompiler_support_enabled": [":ptx_compiler_impl"], "//conditions:default": [":ptx_compiler_stub"], @@ -599,11 +616,31 @@ cc_library( ], ) +xla_test( + name = "cuda_platform_test", + srcs = ["cuda_platform_test.cc"], + backends = ["gpu"], + tags = ["no_rocm"], + deps = [ + ":cuda_platform", + "//xla/stream_executor:platform", + "//xla/stream_executor:platform_manager", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + ], +) + xla_cc_test( name = "ptx_compiler_test", srcs = ["ptx_compiler_test.cc"], - # TODO(b/343996893): Figure out whether msan reports a false positive or not. - tags = ["nomsan"], + tags = [ + "no_rocm", + # TODO(b/343996893): Figure out whether msan reports a false positive or not. + "nomsan", + ], deps = [ ":ptx_compiler", ":ptx_compiler_support", @@ -629,7 +666,11 @@ cc_library( "//conditions:default": [ "LIBNVJITLINK_SUPPORT=false", ], - }), + }) + if_cuda_newer_than( + "12_0", + ["CUDA_SUPPORTS_NVJITLINK=true"], + ["CUDA_SUPPORTS_NVJITLINK=false"], + ), ) cc_library( @@ -672,13 +713,25 @@ cc_library( ], ) +# Since select() can't be nested, we need to wrap the cuda_newer_than check in a separate +# library target. +cc_library( + name = "nvjitlink_cuda_supported", + # Even though the macro is called `*_newer_than`, it does a greater-than-or-equal-to comparison. + deps = if_cuda_newer_than( + "12_0", + [":nvjitlink_impl"], + [":nvjitlink_stub"], + ), +) + cc_library( name = "nvjitlink", hdrs = [ "nvjitlink.h", ], deps = select({ - ":libnvjitlink_support_enabled": [":nvjitlink_impl"], + ":libnvjitlink_support_enabled": [":nvjitlink_cuda_supported"], "//conditions:default": [":nvjitlink_stub"], }) + [ "//xla/stream_executor/gpu:gpu_asm_opts", @@ -691,7 +744,7 @@ xla_cc_test( name = "nvjitlink_test", srcs = ["nvjitlink_test.cc"], args = if_google([ - # nvjitlink allocates memory and only keeps a pointer past the usual offest of 1024 bytes; + # nvjitlink allocates memory and only keeps a pointer past the usual offset of 1024 bytes; # so we need to increase the max pointer offset. -1 means no limit. # This is only relevant for Google's HeapLeakChecker. The newer Leak sanitizer doesn't # have this issue. @@ -729,6 +782,13 @@ cuda_only_cc_library( # "@local_config_cuda//cuda:runtime_ptxas", # ], # copybara:uncomment_end + # copybara:comment_begin + data = if_hermetic_cuda_tools([ + "@cuda_nvcc//:fatbinary", + "@cuda_nvcc//:nvlink", + "@cuda_nvcc//:ptxas", + ]), + # copybara:comment_end visibility = internal_visibility([ "//third_party/py/jax:__subpackages__", "//tensorflow/compiler/mlir/tools/kernel_gen:__subpackages__", diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc index 01aa15313c2cd0..7f2183f85a0a95 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include #include -#include #include #include #include diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc index a4337dfe60e497..9d200e74dcada8 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.cc @@ -41,6 +41,7 @@ limitations under the License. #include "xla/stream_executor/blas.h" #include "xla/stream_executor/cuda/cuda_blas.h" #include "xla/stream_executor/cuda/cuda_blas_utils.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_activation.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" @@ -448,12 +449,15 @@ absl::Status BlasLt::MatmulPlan::DoMatmul( CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, b_scale.opaque())); } - if (c_scale != nullptr) { + auto isF8Input = [](const auto& desc) { + return desc.type() == CUDA_R_8F_E4M3 || desc.type() == CUDA_R_8F_E5M2; + }; + if (c_scale != nullptr && isF8Input(c_desc_)) { TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, c_scale.opaque())); } - if (d_scale != nullptr) { + if (d_scale != nullptr && isF8Input(d_desc_)) { TF_RETURN_IF_ERROR(SetAttr(op_desc_.get(), CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, d_scale.opaque())); diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h index 2fae670f87edca..3d61c816024af9 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_blas_lt.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc index 561ac0d401e2f2..155c3383e7d843 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_diagnostics.cc @@ -108,6 +108,8 @@ namespace gpu { #if !defined(PLATFORM_WINDOWS) static const char *kDriverVersionPath = "/proc/driver/nvidia/version"; +#else +static const char *kDriverVersionPath = "NO NVIDIA DRIVER VERSION FILE"; #endif // -- class Diagnostician @@ -223,7 +225,7 @@ absl::StatusOr Diagnostician::FindDsoVersion() { absl::StatusOr Diagnostician::FindKernelModuleVersion( const std::string &driver_version_file_contents) { - static const char *kDriverFilePrelude = "Kernel Module "; + static const char *kDriverFilePrelude = "Kernel Module"; size_t offset = driver_version_file_contents.find(kDriverFilePrelude); if (offset == std::string::npos) { return absl::NotFoundError( @@ -231,9 +233,17 @@ absl::StatusOr Diagnostician::FindKernelModuleVersion( "driver version file contents: \"", driver_version_file_contents, "\"")); } + static const char *kDriverVersionPrelude = " "; + offset = driver_version_file_contents.find(kDriverVersionPrelude, offset); + if (offset == std::string::npos) { + return absl::NotFoundError( + absl::StrCat("driver version not preceded by two spaces in " + "driver version file contents: \"", + driver_version_file_contents, "\"")); + } std::string version_and_rest = driver_version_file_contents.substr( - offset + strlen(kDriverFilePrelude), std::string::npos); + offset + strlen(kDriverVersionPrelude), std::string::npos); size_t space_index = version_and_rest.find(' '); auto kernel_version = version_and_rest.substr(0, space_index); // TODO(b/22689637): Eliminate the explicit namespace if possible. diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc index bbc6a6dc2cca79..440f647b84f1ce 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc @@ -40,7 +40,6 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" @@ -1749,8 +1748,8 @@ absl::Status CheckAndFetchProjectionWeights( int64_t size = dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); #endif // CUDNN_VERSION >= 8100 - dnn::RnnDescriptor::ParamsRegion region = { - reinterpret_cast(offset), size}; + dnn::RnnDescriptor::ParamsRegion region = {static_cast(offset), + size}; weights->push_back(region); } return absl::OkStatus(); @@ -1891,8 +1890,8 @@ absl::StatusOr CudnnRnnParamsDescriptor::Create( /*nbDims=*/&n_dims, /*filterDimA=*/dims)); int64_t size = dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type); - dnn::RnnDescriptor::ParamsRegion region = { - reinterpret_cast(offset), size}; + dnn::RnnDescriptor::ParamsRegion region = {static_cast(offset), + size}; (type == 0 ? weights : biases).push_back(region); } #endif // CUDNN_VERSION >= 8100 @@ -3762,32 +3761,6 @@ absl::StatusOr CreateCudnnTensor( } #if CUDNN_VERSION >= 8800 -enum CudnnfMHAUid { - Q_ID = 400, - K_ID, - V_ID, - P_ID, - O_ID, - dQ_ID, - dK_ID, - dV_ID, - dP_ID, - dO_ID, - dS_ID, - dBIAS_ID, - BIAS_ID, - MASK_ID, - ZERO_VAL_ID, - ONE_VAL_ID, - NEG_INFINITY_ID, - ALPHA_SCALE_ID, - DROPOUT_SCALE_ID, - Q_SEQLEN_ID, - K_SEQLEN_ID, - D_OFFSET_ID, - D_SEED_ID, - VIRTUAL_ID = 34857 -}; absl::StatusOr CreatePwDesc( dnn::DataType dtype, cudnnPointwiseMode_t mode) { @@ -3842,49 +3815,6 @@ absl::StatusOr CreateTernaryPwOp( RETURN_MSG_IF_CUDNN_ERROR(pw_op_created); return pw_op_created; } - -// Returns a cudnn tensor that's the output of the mask op -absl::StatusOr CreateCudnnMaskFwdTensor( - std::vector& ops, absl::Span dims, - absl::Span strides, dnn::DataType dtype, - cudnn_frontend::Tensor& input_tensor) { - std::vector mask_dim(dims.size(), 1); - std::vector mask_stride(strides.size(), 1); - - // Create the mask tensor - TF_ASSIGN_OR_RETURN( - auto mask_tensor, - CreateCudnnTensor(dims, strides, CudnnfMHAUid::MASK_ID, dtype, 1, -1, - /*is_virtual=*/false)); - // Create the mask output tensor - TF_ASSIGN_OR_RETURN( - auto mask_out_tensor, - CreateCudnnTensor(dims, strides, CudnnfMHAUid::VIRTUAL_ID + 400, - dnn::DataType::kFloat, 1, -1, - /*is_virtual=*/true)); - - auto mask_desc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_MUL) - .setComputeType(CUDNN_DATA_FLOAT) - .build(); - - // Create the mask op. - auto mask_op = cudnn_frontend::OperationBuilder( - CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) - .setxDesc(input_tensor) - .setbDesc(mask_tensor) - .setyDesc(mask_out_tensor) - .setpwDesc(mask_desc) - .build(); - - RETURN_MSG_IF_CUDNN_ERROR(mask_op); - - RETURN_MSG_IF_CUDNN_ERROR(mask_out_tensor); - // Add mask to op list - ops.push_back(std::move(mask_op)); - - return mask_out_tensor; -} #endif // CUDNN_VERSION >= 8800 absl::StatusOr> @@ -5047,7 +4977,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( const dnn::FMHAMaskKind mask_type) { using cudnn_frontend::graph::Tensor_attributes; -#if CUDNN_VERSION >= 8904 +#if CUDNN_VERSION >= 90000 if (VLOG_IS_ON(4)) { VLOG(4) << "\n bmm1_lhs(q): " << q_descriptor.ToString() << "\n bmm1_rhs(k): " << k_descriptor.ToString() @@ -5075,12 +5005,14 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_io_data_type(ioDataType) .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::shared_ptr q_tensor = graph.tensor(Tensor_attributes() .set_name("Q") .set_dim(q_descriptor.GetCudnnCompatibleDimensions(true)) .set_stride(q_descriptor.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::Q_ID)); + .set_uid(next_uid())); auto dim = k_descriptor.GetCudnnCompatibleDimensions(true); std::shared_ptr k_tensor = @@ -5088,13 +5020,13 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("K") .set_dim(k_descriptor.GetCudnnCompatibleDimensions(true)) .set_stride(k_descriptor.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::K_ID)); + .set_uid(next_uid())); std::shared_ptr v_tensor = graph.tensor( Tensor_attributes() .set_name("V") .set_dim(v_descriptor.GetCudnnCompatibleDimensions(false)) .set_stride(v_descriptor.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::V_ID)); + .set_uid(next_uid())); // Setting sdpa, and is_inference bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || @@ -5112,7 +5044,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("bias") .set_dim(bias_descriptor->dimensions()) .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::BIAS_ID)); + .set_uid(next_uid())); sdpa_options.set_bias(bias_tensor); } // Setting actual seqlen @@ -5126,37 +5058,38 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_name("seq_q") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::Q_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); auto seq_kv_tensor = graph.tensor(Tensor_attributes() .set_name("seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::K_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); sdpa_options.set_padding_mask(true); sdpa_options.set_seq_len_q(seq_q_tensor); sdpa_options.set_seq_len_kv(seq_kv_tensor); } // Setting seed and offset + std::shared_ptr seed_tensor; + std::shared_ptr offset_tensor; if (use_dropout) { - auto seed_tensor = + // Skip setting UIDs: pass by value tensors go at the end. + seed_tensor = graph.tensor(Tensor_attributes() .set_name("seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_SEED_ID)); - auto offset_tensor = + .set_is_pass_by_value(true)); + offset_tensor = graph.tensor(Tensor_attributes() .set_name("offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + .set_is_pass_by_value(true)); sdpa_options.set_dropout((float)dropout_rate.value(), seed_tensor, offset_tensor); } @@ -5170,7 +5103,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_output(true) .set_dim(o_descriptor.dimensions()) .set_stride(o_descriptor.GetLogicalStrides()) - .set_uid(CudnnfMHAUid::O_ID); + .set_uid(next_uid()); if (stats_descriptor.has_value()) { cudnn_frontend::DataType_t statsType = ToCudnnFrontendDataType(stats_descriptor->type()); @@ -5183,11 +5116,19 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( .set_data_type(statsType) .set_dim(stat_dims) .set_stride(stat_strides) - .set_uid(CudnnfMHAUid::P_ID); + .set_uid(next_uid()); + } + if (seed_tensor != nullptr) { + seed_tensor->set_uid(next_uid()); + } + if (offset_tensor != nullptr) { + offset_tensor->set_uid(next_uid()); } CudnnGraph cudnnGraph(std::move(graph)); - TF_RETURN_IF_ERROR(cudnnGraph.Prepare(dnn_support)); - TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, std::nullopt)); + TF_RETURN_IF_ERROR(cudnnGraph.Prepare( + dnn_support, NumericOptions{/*require_determinism=*/false, + /*allow_tf32=*/true})); + TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); if (VLOG_IS_ON(4)) { VLOG(4) << "\b flash attention operation graph: " << graph; @@ -5195,7 +5136,7 @@ absl::StatusOr GetCudnnFlashAttentionOperationGraph( return cudnnGraph; #else return absl::UnimplementedError( - "Cudnn flash attention only supported with Cudnn >= 8.9.4"); + "Cudnn flash attention only supported with Cudnn >= 9.0.0"); #endif } @@ -5211,7 +5152,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( std::optional dropout_rate, std::optional seed, double scale, bool use_dropout, bool use_bias, dnn::FMHAMaskKind mask_type, bool force_deterministic) { -#if CUDNN_VERSION >= 8904 +#if CUDNN_VERSION >= 90000 if (VLOG_IS_ON(4)) { VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString() << "\n bmm1_grad_gemm2_rhs(k): " << k_desc.ToString() @@ -5236,71 +5177,66 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_intermediate_data_type(cudnn_frontend::DataType_t::FLOAT) .set_io_data_type(ioDataType); + auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); + auto p_strides = p_desc.GetCudnnCompatibleStrides(false); + std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); + p_reduction_dims.push_back(1); + + // Divide every stride by the last dim value. + std::vector p_reduction_strides; + p_reduction_strides.reserve(p_strides.size()); + int64_t p_reduced_dim_len = p_dims.back(); + for (auto stride : p_strides) { + p_reduction_strides.push_back(stride / p_reduced_dim_len); + } + p_reduction_strides[3] = 1; + bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || + mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; + auto sdpa_backward_options = + cudnn_frontend::graph::SDPA_backward_attributes() + .set_name("flash_attention_backward") + .set_causal_mask(is_causal) + .set_attn_scale(scale) + .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; + std::shared_ptr q = graph.tensor(Tensor_attributes() .set_name("Q") .set_dim(q_desc.GetCudnnCompatibleDimensions(false)) .set_stride(q_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::Q_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr k = graph.tensor(Tensor_attributes() .set_name("K") .set_dim(k_desc.GetCudnnCompatibleDimensions(false)) .set_stride(k_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::K_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); std::shared_ptr v = graph.tensor(Tensor_attributes() .set_name("V") .set_dim(v_desc.GetCudnnCompatibleDimensions(true)) .set_stride(v_desc.GetCudnnCompatibleStrides(true)) - .set_uid(CudnnfMHAUid::V_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); - std::shared_ptr o = + std::shared_ptr stats = graph.tensor(Tensor_attributes() - .set_name("O") - .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) - .set_stride(do_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::O_ID) - .set_data_type(ioDataType)); + .set_name("stats") + .set_dim(p_reduction_dims) + .set_stride(p_reduction_strides) + .set_uid(next_uid()) + .set_data_type(cudnn_frontend::DataType_t::FLOAT)); std::shared_ptr dO = graph.tensor(Tensor_attributes() .set_name("dO") .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) .set_stride(do_desc.GetCudnnCompatibleStrides(false)) - .set_uid(CudnnfMHAUid::dO_ID) + .set_uid(next_uid()) .set_data_type(ioDataType)); - auto p_dims = p_desc.GetCudnnCompatibleDimensions(false); - auto p_strides = p_desc.GetCudnnCompatibleStrides(false); - std::vector p_reduction_dims(p_dims.begin(), p_dims.end() - 1); - p_reduction_dims.push_back(1); - - // Divide every stride by the last dim value. - std::vector p_reduction_strides; - p_reduction_strides.reserve(p_strides.size()); - int64_t p_reduced_dim_len = p_dims.back(); - for (auto stride : p_strides) { - p_reduction_strides.push_back(stride / p_reduced_dim_len); - } - p_reduction_strides[3] = 1; - std::shared_ptr stats = - graph.tensor(Tensor_attributes() - .set_name("stats") - .set_dim(p_reduction_dims) - .set_stride(p_reduction_strides) - .set_uid(CudnnfMHAUid::P_ID) - .set_data_type(cudnn_frontend::DataType_t::FLOAT)); - bool is_causal = mask_type == dnn::FMHAMaskKind::CAUSAL || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - auto sdpa_backward_options = - cudnn_frontend::graph::SDPA_backward_attributes() - .set_name("flash_attention_backward") - .set_causal_mask(is_causal) - .set_attn_scale(scale) - .set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); - - // Setting bias + std::shared_ptr d_bias_tensor; if (use_bias) { DCHECK(bias_descriptor != std::nullopt); auto bias_dim = bias_descriptor->dimensions(); @@ -5313,21 +5249,29 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_name("bias") .set_dim(bias_descriptor->dimensions()) .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::BIAS_ID)); + .set_uid(next_uid())); sdpa_backward_options.set_bias(bias_tensor); // shapes [1, 1, s, s], [b, 1, s, s], [b, h, s, s] are not supported for // dbias calculation but they are supported for forward bias calculation + // Set UID later: this is the last output tuple element. if (b == 1 && n == q_n) { - auto d_bias_tensor = + d_bias_tensor = graph.tensor(Tensor_attributes() .set_name("dBias") .set_dim(bias_descriptor->dimensions()) - .set_stride(bias_descriptor->GetLogicalStrides()) - .set_uid(CudnnfMHAUid::dBIAS_ID)); + .set_stride(bias_descriptor->GetLogicalStrides())); sdpa_backward_options.set_dbias(d_bias_tensor); } } + std::shared_ptr o = + graph.tensor(Tensor_attributes() + .set_name("O") + .set_dim(do_desc.GetCudnnCompatibleDimensions(false)) + .set_stride(do_desc.GetCudnnCompatibleStrides(false)) + .set_uid(next_uid()) + .set_data_type(ioDataType)); + // Setting actual seqlen bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; @@ -5339,38 +5283,39 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( .set_name("seq_q") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::Q_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); auto seq_kv_tensor = graph.tensor(Tensor_attributes() .set_name("seq_kv") .set_dim({b, 1, 1, 1}) .set_stride({1, 1, 1, 1}) - .set_uid(CudnnfMHAUid::K_SEQLEN_ID) + .set_uid(next_uid()) .set_data_type(cudnn_frontend::DataType_t::INT32)); sdpa_backward_options.set_padding_mask(true); sdpa_backward_options.set_seq_len_q(seq_q_tensor); sdpa_backward_options.set_seq_len_kv(seq_kv_tensor); } // Setting seed and offset + std::shared_ptr seed_tensor; + std::shared_ptr offset_tensor; if (use_dropout) { DCHECK(dropout_rate != std::nullopt); - auto seed_tensor = + // Skip setting UIDs: pass by value tensors go at the end. + seed_tensor = graph.tensor(Tensor_attributes() .set_name("seed") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_SEED_ID)); - auto offset_tensor = + .set_is_pass_by_value(true)); + offset_tensor = graph.tensor(Tensor_attributes() .set_name("offset") .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(cudnn_frontend::DataType_t::INT64) - .set_is_pass_by_value(true) - .set_uid(CudnnfMHAUid::D_OFFSET_ID)); + .set_is_pass_by_value(true)); sdpa_backward_options.set_dropout((float)dropout_rate.value(), seed_tensor, offset_tensor); } @@ -5385,25 +5330,36 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( dQ->set_output(true) .set_dim(dq_desc.dimensions()) .set_stride(dq_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dQ") - .set_uid(CudnnfMHAUid::dQ_ID) .set_data_type(ioDataType); dK->set_output(true) .set_dim(dk_desc.dimensions()) .set_stride(dk_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dK") - .set_uid(CudnnfMHAUid::dK_ID) .set_data_type(ioDataType); dV->set_output(true) .set_dim(dv_desc.dimensions()) .set_stride(dv_desc.GetLogicalStrides()) + .set_uid(next_uid()) .set_name("dV") - .set_uid(CudnnfMHAUid::dV_ID) .set_data_type(ioDataType); + if (d_bias_tensor != nullptr) { + d_bias_tensor->set_uid(next_uid()); + } + if (seed_tensor != nullptr) { + seed_tensor->set_uid(next_uid()); + } + if (offset_tensor != nullptr) { + offset_tensor->set_uid(next_uid()); + } CudnnGraph cudnnGraph(std::move(graph)); - TF_RETURN_IF_ERROR(cudnnGraph.Prepare(dnn_support)); - TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, std::nullopt)); + TF_RETURN_IF_ERROR( + cudnnGraph.Prepare(dnn_support, NumericOptions{force_deterministic, + /*allow_tf32=*/true})); + TF_RETURN_IF_ERROR(cudnnGraph.Build(dnn_support, /*plan_id=*/std::nullopt)); if (VLOG_IS_ON(4)) { VLOG(4) << "\b flash attention operation backward graph: " << graph; @@ -5412,7 +5368,7 @@ absl::StatusOr GetCudnnFlashAttentionBackwardOperationGraph( return cudnnGraph; #else return absl::UnimplementedError( - "Cudnn flash attention only supported with Cudnn >= 8.9.4"); + "Cudnn flash attention only supported with Cudnn >= 9.0.0"); #endif } @@ -5735,8 +5691,8 @@ absl::Status CudnnSupport::DoConvolve( } // Utility for dealing with CUDA's type-erased scaling parameters, where some -// sets of parameters expect a void* pointing at a float while others expect it -// to point at a double. +// sets of parameters expect a void* pointing at a float while others expect +// it to point at a double. // // This is rather ugly, but its purpose is to quarantine the corresponding // ugliness that already exists in the CUDA API. @@ -5760,9 +5716,9 @@ class ScalingParam { // // See // https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#scaling-parameters - // for more info; the behavior for int8 result tensors is not described there, - // but is maintained from the existing behavior (namely, using a float scaling - // parameter). + // for more info; the behavior for int8 result tensors is not described + // there, but is maintained from the existing behavior (namely, using a + // float scaling parameter). void* ToVoidPointer(dnn::DataType element_type) { if (element_type == dnn::DataType::kDouble) { return &as_double_; @@ -5834,10 +5790,11 @@ absl::StatusOr> GetDescriptorAttribute( absl::c_transform(result, std::back_inserter(raw_ptrs), [](const BackendDescriptor& ptr) { return ptr.get(); }); - // This API evidently does a deep copy of the descriptors into the pointers in - // the output array, rather than writing pointers to the descriptors into the - // output array. So, this writes the memory behind each BackendDescriptor in - // result, rather than writing the contents of raw_ptrs. + // This API evidently does a deep copy of the descriptors into the pointers + // in the output array, rather than writing pointers to the descriptors into + // the output array. So, this writes the memory behind each + // BackendDescriptor in result, rather than writing the contents of + // raw_ptrs. RETURN_IF_CUDNN_ERROR(cudnnBackendGetAttribute( desc, name, CUDNN_TYPE_BACKEND_DESCRIPTOR, n, &n, raw_ptrs.data())); @@ -5873,9 +5830,9 @@ absl::StatusOr ExecutionPlanToAlgorithmDesc( cudnnBackendGetAttribute(engines[0].get(), CUDNN_ATTR_ENGINE_GLOBAL_INDEX, CUDNN_TYPE_INT64, 1, &n, &engine_id)); - // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query the - // number of elements in the attribute by using an output limit value of 0 - // just returns 0; the only way to find out how many there are is to + // Apparently for CUDNN_ATTR_ENGINECFG_KNOB_CHOICES only, trying to query + // the number of elements in the attribute by using an output limit value of + // 0 just returns 0; the only way to find out how many there are is to // pre-allocate space for every existing knob type (as an upper bound on the // number of knob choices a config can have), and then look back at how many // were filled. @@ -6086,103 +6043,7 @@ class CudnnExecutionPlanRunner std::vector scalar_input_uids_; std::vector scalar_input_values_; }; -#endif // CUDNN_VERSION >= 8100 - -template -class CudnnGraphRunner; -// An OpRunner implemented by a cuDNN frontend graph. -// -// This is the class holding the implementation of ToString, GetWorkspaceSize, -// and operator() for use by the cudnn frontend op runners. -template -class CudnnGraphRunner : public dnn::OpRunner { - private: - using Graph = cudnn_frontend::graph::Graph; - using Tensor_attributes = cudnn_frontend::graph::Tensor_attributes; - - public: - std::string ToString() const override { return graph_.Graph().print(); } - - size_t GetWorkspaceSize() const override { - return graph_.Graph().get_workspace_size(); - } - absl::StatusOr ToAlgorithmDesc() const override { - return absl::InternalError( - "Unexpected call to CudnnGraphRunner::ToAlgorithmDesc"); - } - - absl::Status operator()(Stream* stream, dnn::ProfileResult* profile_result, - DeviceMemoryBase scratch_memory, - Args... inputs) const override { - if (parent_ != stream->parent()) { - return tsl::errors::Internal( - "CudnnExecutionPlanRunner cached across multiple StreamExecutors."); - } - CudnnHandle handle = cudnn_->GetHandle(parent_, stream); - std::unordered_map variant_pack; - std::vector vec = {inputs.opaque()...}; - - // add device buffers to the variant pack - for (int i = 0; i < uids_.size(); ++i) { - if (uids_[i].has_value()) { - variant_pack[*uids_[i]] = vec[i]; - } - } - if (dropout_rng_offset_increment_ > 0) { -#if CUDNN_VERSION >= 8800 - variant_pack[CudnnfMHAUid::D_SEED_ID] = (void*)&dropout_rng_seed_; - current_dropout_rng_offset_ += dropout_rng_offset_increment_; - variant_pack[CudnnfMHAUid::D_OFFSET_ID] = - (void*)¤t_dropout_rng_offset_; -#else - return absl::UnimplementedError( - "Cudnn dropout offset and seed are only supported with Cudnn >= " - "8.8.0"); -#endif // CUDNN_VERSION >= 8800 - } - int workspace = graph_.Graph().get_workspace_size(); - if (workspace > scratch_memory.size()) { - return tsl::errors::Internal( - absl::StrFormat("CuDNN FMHA requires %d workspace, got %d workspace.", - workspace, scratch_memory.size())); - } - RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.Graph().execute( - handle.handle(), variant_pack, scratch_memory.opaque())); - - return absl::OkStatus(); - } - - static absl::StatusOr Create( - GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - std::vector> uids) { - return CudnnGraphRunner(parent, cudnn, std::move(graph), dropout_rng_seed, - dropout_rng_offset, uids); - } - - private: - CudnnGraphRunner(GpuExecutor* parent, CudnnAccess* cudnn, CudnnGraph graph, - int64_t dropout_rng_seed, int64_t dropout_rng_offset, - std::vector> uids) - : parent_(parent), - cudnn_(cudnn), - graph_(std::move(graph)), - dropout_rng_seed_(dropout_rng_seed), - current_dropout_rng_offset_(0), - dropout_rng_offset_increment_(dropout_rng_offset), - uids_(uids) {} - GpuExecutor* parent_; - CudnnAccess* cudnn_; - Stream* stream_; - CudnnGraph graph_; - int64_t dropout_rng_seed_; - mutable int64_t current_dropout_rng_offset_; - int64_t dropout_rng_offset_increment_; - std::vector> uids_; -}; - -#if CUDNN_VERSION >= 8100 namespace { template @@ -6968,7 +6829,8 @@ absl::Status CudnnSupport::GetFusedMatmulRunners( use_fallback, out_exec_plans, /*need_side_input=*/true, numeric_options); #else return tsl::errors::Unimplemented( - "Cudnn execution plans for matmul are only supported with Cudnn >= 8.4."); + "Cudnn execution plans for matmul are only supported with Cudnn >= " + "8.4."); #endif // CUDNN_VERSION >= 8400 } @@ -7170,139 +7032,6 @@ int64_t GetDropoutRngOffset(std::vector& intermediate_shape) { return max_seq_len * max_seq_len / cudnn_mha_num_threads; } -absl::StatusOr> -CudnnSupport::FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) { -#if CUDNN_VERSION >= 8904 - auto cudnn = cudnn_->GetHandle(parent_, stream); - bool use_dropout = dropout_rate && *dropout_rate > 0.0; - std::vector intermediate_shape; - - TF_ASSIGN_OR_RETURN(auto graph, - GetCudnnFlashAttentionOperationGraph( - *this, /*q_descriptor=*/bmm1_lhs_descriptor, - /*k_descriptor=*/bmm1_rhs_descriptor, - /*v_descriptor=*/bmm2_rhs_descriptor, - /*o_descriptor=*/output_descriptor, bias_descriptor, - /*stats_descriptor=*/activation_descriptor, - /*scale=*/static_cast(scale), use_dropout, - dropout_rate, mask_type)); - - std::vector intermediate_bmm2_lhs_dims = - intermediate_bmm2_lhs_descriptor.GetCudnnCompatibleDimensions(true); - intermediate_shape = intermediate_bmm2_lhs_dims; - int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - int64_t dropout_rng_seed = seed.has_value() ? *seed : 0; - std::vector> uids = { - CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::V_ID, - CudnnfMHAUid::O_ID}; - uids.emplace_back(bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::BIAS_ID) - : std::nullopt); - uids.emplace_back(activation_descriptor.has_value() - ? std::optional(CudnnfMHAUid::P_ID) - : std::nullopt); - bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::Q_SEQLEN_ID) - : std::nullopt); - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::K_SEQLEN_ID) - : std::nullopt); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnGraphRunner::Create( - parent_, cudnn_.get(), std::move(graph), - dropout_rng_seed, dropout_rng_offset, uids)); - - return {std::make_unique>( - std::move(runner))}; -#else - return absl::UnimplementedError( - "Cudnn flash attention are only supported with Cudnn >= 8.9.4"); -#endif // CUDNN_VERSION >= 8904 -} - -absl::StatusOr> -CudnnSupport::FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& d_output_descriptor, - const dnn::TensorDescriptor& d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor& d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic) { -#if CUDNN_VERSION >= 8904 - auto cudnn = cudnn_->GetHandle(parent_, stream); - - bool use_dropout = dropout_rate && *dropout_rate > 0.0; - std::vector intermediate_shape; - - TF_ASSIGN_OR_RETURN( - auto graph, - GetCudnnFlashAttentionBackwardOperationGraph( - *this, bmm1_grad_gemm1_rhs_descriptor, bmm1_grad_gemm2_rhs_descriptor, - bmm2_grad_gemm1_lhs_descriptor, bmm2_grad_gemm2_rhs_descriptor, - d_output_descriptor, d_bmm1_lhs_descriptor, d_bmm1_rhs_descriptor, - d_bmm2_rhs_descriptor, bias_descriptor, dropout_rate, seed, scale, - use_dropout, bias_descriptor != std::nullopt, mask_type, - force_deterministic)); - - std::vector p_dims = - bmm2_grad_gemm1_lhs_descriptor.GetCudnnCompatibleDimensions(false); - intermediate_shape = p_dims; - int64_t dropout_rng_offset = GetDropoutRngOffset(intermediate_shape); - int64_t dropout_rng_seed = seed.has_value() ? *seed : 0; - - std::vector> uids; - uids = {CudnnfMHAUid::Q_ID, CudnnfMHAUid::K_ID, CudnnfMHAUid::P_ID, - CudnnfMHAUid::V_ID, CudnnfMHAUid::dO_ID, CudnnfMHAUid::dQ_ID, - CudnnfMHAUid::dK_ID, CudnnfMHAUid::dV_ID, std::nullopt}; - uids.emplace_back(d_bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::dBIAS_ID) - : std::nullopt); - uids.push_back(CudnnfMHAUid::O_ID); - uids.emplace_back(bias_descriptor.has_value() - ? std::optional(CudnnfMHAUid::BIAS_ID) - : std::nullopt); - bool is_padding = mask_type == dnn::FMHAMaskKind::PADDING || - mask_type == dnn::FMHAMaskKind::PADDING_CAUSAL; - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::Q_SEQLEN_ID) - : std::nullopt); - uids.emplace_back(is_padding - ? std::optional(CudnnfMHAUid::K_SEQLEN_ID) - : std::nullopt); - TF_ASSIGN_OR_RETURN(auto runner, - CudnnGraphRunner::Create( - parent_, cudnn_.get(), graph, dropout_rng_seed, - dropout_rng_offset, uids)); - return {std::make_unique>( - std::move(runner))}; -#else - return absl::UnimplementedError( - "Cudnn flash attention bwd are only " - "supported with Cudnn >= 8.9.4"); -#endif // CUDNN_VERSION >= 8904 -} - bool CudnnSupport::GetRnnAlgorithms( std::vector* out_algorithms) { PreloadCudnnSubLibs(PreloadCudnnType::Rnn); @@ -8353,11 +8082,16 @@ absl::StatusOr> CudnnSupport::DeserializeGraph( return std::make_unique(std::move(graph)); } -absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support) { +absl::Status CudnnGraph::Prepare(dnn::DnnSupport& dnn_support, + const NumericOptions& numeric_options) { const CudnnSupport& cudnn_support = static_cast(dnn_support); TF_ASSIGN_OR_RETURN(auto cudnn, cudnn_support.cudnn_->GetLocalHandle()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.validate()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.build_operation_graph(cudnn->handle())); + if (numeric_options.require_determinism) { + graph_.deselect_numeric_notes( + {cudnn_frontend::NumericalNote_t::NONDETERMINISTIC}); + } RETURN_IF_CUDNN_FRONTEND_ERROR( graph_.create_execution_plans({cudnn_frontend::HeurMode_t::A})); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.check_support(cudnn->handle())); @@ -8382,15 +8116,30 @@ absl::Status CudnnGraph::Execute(Stream& stream, std::unordered_map tensor_to_ptr_map; absl::Span operands_without_workspace = operands; DeviceMemoryBase workspace; - if (graph_.get_workspace_size() != 0) { + if (graph_.get_workspace_size() > 0) { workspace = operands.back(); CHECK_EQ(graph_.get_workspace_size(), workspace.size()); + } + if (graph_.get_workspace_size() > 0 || operands.back().size() == 0) { operands_without_workspace = operands.first(operands.size() - 1); } - int operand_number = 0; + auto next_uid = [uid = 0]() mutable -> int { return CuDnnTensorUID(uid++); }; for (DeviceMemoryBase operand : operands_without_workspace) { - tensor_to_ptr_map[CuDnnTensorUID(operand_number++)] = operand.opaque(); + tensor_to_ptr_map[next_uid()] = operand.opaque(); + } + + if (dropout_rng_offset_increment_ > 0) { +#if CUDNN_VERSION >= 8800 + tensor_to_ptr_map[next_uid()] = (void*)&dropout_rng_seed_; + current_dropout_rng_offset_ += dropout_rng_offset_increment_; + tensor_to_ptr_map[next_uid()] = (void*)¤t_dropout_rng_offset_; +#else + return absl::UnimplementedError( + "Cudnn dropout offset and seed are only supported with Cudnn >= " + "8.8.0"); +#endif // CUDNN_VERSION >= 8800 } + const CudnnSupport& dnn_support = static_cast(*stream.parent()->AsDnn()); RETURN_IF_CUDNN_FRONTEND_ERROR(graph_.execute( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h index 52086938d5a30f..24d84e369cb138 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_dnn.h @@ -60,7 +60,7 @@ class CudnnGraph : public dnn::DnnGraph { explicit CudnnGraph(cudnn_frontend::graph::Graph&& graph) : graph_(std::move(graph)) {} // Prepares a graph and checks whether it is generally supported. - absl::Status Prepare(dnn::DnnSupport&) override; + absl::Status Prepare(dnn::DnnSupport&, const NumericOptions&) override; // Builds single plan of the graph with given ID. absl::Status Build(dnn::DnnSupport&, std::optional plan_id) override; // Builds all the plans @@ -70,6 +70,9 @@ class CudnnGraph : public dnn::DnnGraph { private: cudnn_frontend::graph::Graph graph_; + int64_t dropout_rng_seed_; + mutable int64_t current_dropout_rng_offset_; + int64_t dropout_rng_offset_increment_ = 0; }; #endif // CUDNN_VERSION >= 8100 @@ -335,37 +338,6 @@ class CudnnSupport : public dnn::DnnSupport { std::optional dscale_descriptor, std::optional dbias_descriptor) override; - absl::StatusOr> - FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) override; - - absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& d_output_descriptor, - const dnn::TensorDescriptor& d_bmm1_lhs_descriptor, - const dnn::TensorDescriptor& d_bmm1_rhs_descriptor, - const dnn::TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic); - bool GetRnnAlgorithms( std::vector* out_algorithms) override; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc index 866c1ff7131462..e8e26e2c9de5ee 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -27,7 +28,6 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/base/const_init.h" -#include "absl/base/optimization.h" #include "absl/container/inlined_vector.h" #include "absl/debugging/leak_check.h" #include "absl/log/check.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h index 5c04ab6ccbee02..aefd89650fda0f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.h @@ -19,16 +19,13 @@ limitations under the License. #define XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_ #include -#include #include -#include #include #include #include "absl/container/node_hash_map.h" #include "absl/log/check.h" #include "absl/log/log.h" -#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "third_party/gpus/cuda/include/cuda.h" #include "xla/stream_executor/cuda/cuda_status.h" diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc index 7cb402a91ca43a..ba855635f3ecdb 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" +#include "xla/stream_executor/cuda/cuda_diagnostics.h" #include "xla/stream_executor/cuda/cuda_status.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "tsl/platform/status.h" @@ -49,7 +50,7 @@ TEST(CudaDriverTest, ScopedActivateContextTest) { CUcontext context0, context1; CHECK_CUDA(cuCtxCreate(&context0, 0, device)); CHECK_CUDA(cuCtxCreate(&context1, 0, device)); - GpuContext se_context1(context1, /*id=*/101); + GpuContext se_context1(context1, /*device_ordinal=*/101); { ScopedActivateContext scope(&se_context1); CUcontext c; @@ -68,4 +69,25 @@ TEST(CudaDriverTest, ScopedActivateContextTest) { } } // namespace gpu + +namespace cuda { + +TEST(CudaDriverTest, DriverVersionParsingTest) { + // Tests that the driver version can be right after 'Kernel Module', + // or later as well. + auto driver_version = Diagnostician::FindKernelModuleVersion( + "... NVIDIA UNIX Open Kernel Module for x86_64 570.00 Release Build " + "... Mon Aug 12 04:17:20 UTC 2024"); + TF_CHECK_OK(driver_version.status()); + EXPECT_EQ("570.0.0", cuda::DriverVersionToString(driver_version.value())); + + driver_version = Diagnostician::FindKernelModuleVersion( + "... NVIDIA UNIX Open Kernel Module 571.00 Release Build " + "... Mon Aug 12 04:17:20 UTC 2024"); + TF_CHECK_OK(driver_version.status()); + EXPECT_EQ("571.0.0", cuda::DriverVersionToString(driver_version.value())); +} + +} // namespace cuda + } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc index 63df37d3c037d9..8ae24775c7558f 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_executor.cc @@ -17,20 +17,17 @@ limitations under the License. #include #include #include -#include #include #include #include #include #include #include -#include -#include "absl/base/casts.h" #include "absl/numeric/int128.h" -#include "absl/strings/str_join.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" @@ -46,7 +43,6 @@ limitations under the License. #include #endif -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -54,7 +50,6 @@ limitations under the License. #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" @@ -75,14 +70,12 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/gpu/gpu_timer.h" #include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/integrations/device_mem_allocator.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/plugin_registry.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/fingerprint.h" #include "tsl/platform/logging.h" @@ -216,21 +209,19 @@ absl::Status GpuExecutor::LoadModuleFromHsaco(const char* hsaco, "Feature not supported on CUDA platform (LoadModuleFromHsaco)"); } -absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { - GpuKernel* cuda_kernel = AsGpuKernel(kernel); +absl::StatusOr> GpuExecutor::LoadKernel( + const MultiKernelLoaderSpec& spec) { + auto cuda_kernel = std::make_unique(this); CUmodule module; const std::string* kernel_name; - VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); - if (spec.has_cuda_cubin_in_memory()) { absl::MutexLock lock{&in_memory_modules_mu_}; kernel_name = &spec.cuda_cubin_in_memory().kernel_name(); const char* cubin = reinterpret_cast( spec.cuda_cubin_in_memory().cubin_bytes().data()); TF_RETURN_IF_ERROR(LoadModuleFromCuBin(cubin, &module)); - kernel_to_gpu_binary_[kernel] = cubin; + kernel_to_gpu_binary_[cuda_kernel.get()] = cubin; } else if (spec.has_cuda_ptx_in_memory()) { kernel_name = &spec.cuda_ptx_in_memory().kernel_name(); @@ -249,7 +240,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, absl::MutexLock lock{&in_memory_modules_mu_}; TF_RETURN_IF_ERROR(LoadModuleFromPtx(ptx, &module)); - kernel_to_gpu_binary_[kernel] = ptx; + kernel_to_gpu_binary_[cuda_kernel.get()] = ptx; } else if (spec.has_in_process_symbol()) { kernel_name = &spec.in_process_symbol().kernel_name(); @@ -260,35 +251,35 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, TF_ASSIGN_OR_RETURN( GpuFunctionHandle function, GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); - *cuda_kernel->gpu_function_ptr() = function; + cuda_kernel->set_gpu_function(function); } else { return absl::InternalError("No method of loading CUDA kernel provided"); } - + VLOG(3) << "LoadKernel on kernel : " << *kernel_name; // If we resolved kernel from a symbol pointer, there is no need to load it // from a module, as CUDA runtime did that automatically for us. if (!spec.has_in_process_symbol()) { VLOG(2) << "getting function " << *kernel_name << " from module " << module; - TF_RETURN_IF_ERROR( - GpuDriver::GetModuleFunction(context_, module, kernel_name->c_str(), - cuda_kernel->gpu_function_ptr())); + GpuFunctionHandle function; + TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction( + context_, module, kernel_name->c_str(), &function)); + cuda_kernel->set_gpu_function(function); } // Update CUDA kernel properties after it was loaded in the CUDA context. cuda_kernel->set_name(*kernel_name); - cuda_kernel->set_gpu_context(context_); // We have to trust the kernel loader spec arity because there doesn't appear // to be a way to reflect on the number of expected arguments w/the CUDA API. cuda_kernel->set_arity(spec.arity()); KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel, &kernel_metadata)); - kernel->set_metadata(kernel_metadata); - kernel->set_name(*kernel_name); - kernel->set_args_packing(spec.kernel_args_packing()); - return absl::OkStatus(); + TF_RETURN_IF_ERROR(GetKernelMetadata(cuda_kernel.get(), &kernel_metadata)); + cuda_kernel->set_metadata(kernel_metadata); + cuda_kernel->set_name(*kernel_name); + cuda_kernel->set_args_packing(spec.kernel_args_packing()); + return std::move(cuda_kernel); } absl::StatusOr> @@ -469,106 +460,16 @@ absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* cuda_kernel, KernelMetadata* kernel_metadata) { int value; TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( - CU_FUNC_ATTRIBUTE_NUM_REGS, *cuda_kernel->gpu_function_ptr(), &value)); + CU_FUNC_ATTRIBUTE_NUM_REGS, cuda_kernel->gpu_function(), &value)); kernel_metadata->set_registers_per_thread(value); TF_RETURN_IF_ERROR( GpuDriver::FuncGetAttribute(CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - *cuda_kernel->gpu_function_ptr(), &value)); + cuda_kernel->gpu_function(), &value)); kernel_metadata->set_shared_memory_bytes(value); return absl::OkStatus(); } -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, const KernelArgs& args) { - return Launch(stream, thread_dims, block_dims, std::nullopt, kernel, args); -} - -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - return Launch(stream, thread_dims, block_dims, - std::make_optional(cluster_dims), kernel, args); -} - -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - CUstream custream = AsGpuStreamValue(stream); - const GpuKernel* cuda_kernel = AsGpuKernel(&kernel); - CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle(); - - if (cuda_kernel->cache_config() != KernelCacheConfig::kNoPreference) { - TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( - cufunc, cuda_kernel->GetGpuCacheConfig())); - } - - // Launch CUDA kernels with packed arguments. - auto launch = [&](const KernelArgsPackedArrayBase& packed) { - int32_t expected_number_of_arguments = - kernel.Arity() + (packed.number_of_shared_bytes() > 0); - - CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) - << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() - << " arguments, but expected " << expected_number_of_arguments - << "; arity=" << kernel.Arity() - << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); - - void** params = const_cast(packed.argument_addresses().data()); - - if (cluster_dims.has_value()) { - return GpuDriver::LaunchKernel( - context_, kernel.name(), cufunc, cluster_dims->x, cluster_dims->y, - cluster_dims->z, block_dims.x, block_dims.y, block_dims.z, - thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), custream, params, - /*extra=*/nullptr); - } else { - return GpuDriver::LaunchKernel( - context_, kernel.name(), cufunc, block_dims.x, block_dims.y, - block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - packed.number_of_shared_bytes(), custream, params, - /*extra=*/nullptr); - } - }; - - // If arguments are already packed we can just launch the kernel. - if (auto* packed = DynCast(&args)) { - return launch(*packed); - } - - // For device memory array we rely on a custom kernel arguments packing. - if (auto* device_mem = DynCast(&args)) { - auto& pack = kernel.args_packing(); - if (!pack) { - return absl::InternalError( - "Kernel is missing a custom arguments packing function for device " - "memory arguments array"); - } - - TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); - return launch(*packed); - } - - return absl::InternalError("Unsupported kernel arguments type"); -} - -absl::Status GpuExecutor::Submit(Stream* stream, - const CommandBuffer& command_buffer) { - if (command_buffer.mode() != CommandBuffer::Mode::kPrimary) { - return absl::InvalidArgumentError( - "Can't submit non-primary command buffer for execution"); - } - - auto exec = GpuCommandBuffer::Cast(&command_buffer)->executable(); - VLOG(3) << "Launch command buffer executable graph " << exec - << " on a stream: " << stream; - return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); -} - DeviceMemoryBase GpuExecutor::Allocate(uint64_t size, int64_t memory_space) { if (memory_space == 1) { auto result = GpuCollectives::CollectiveMemoryAllocate(context_, size); @@ -617,16 +518,6 @@ absl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, AsCudaDevicePtr(gpu_src), size); } -absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) { - VLOG(2) << "enqueueing memset8 operation onto stream " << stream - << " at location " << location << " with size " << size - << " and pattern " << std::hex << pattern; - return GpuDriver::AsynchronousMemsetUint8(context_, AsCudaDevicePtr(location), - pattern, size, - AsGpuStreamValue(stream)); -} - void GpuExecutor::DeallocateStream(Stream* stream) { { absl::MutexLock lock(&mu_); @@ -634,13 +525,9 @@ void GpuExecutor::DeallocateStream(Stream* stream) { dnn_->NotifyStreamDestroyed(stream); } } - GpuStream* cuda_stream = AsGpuStream(stream); + GpuStream* gpu_stream = AsGpuStream(stream); absl::MutexLock l(&alive_gpu_streams_mu_); - alive_gpu_streams_.erase(cuda_stream->platform_specific_stream()); - if (!cuda_stream->IsIdle()) { - LOG(ERROR) << "Deallocating stream with pending work"; - } - cuda_stream->Destroy(); + alive_gpu_streams_.erase(gpu_stream->gpu_stream()); } absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { @@ -774,27 +661,13 @@ absl::StatusOr> GpuExecutor::CreateEvent() { absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { - auto gpu_stream = std::make_unique(this); - if (priority.has_value()) { - if (std::holds_alternative(*priority)) { - gpu_stream->SetPriority(std::get(*priority)); - } else { - gpu_stream->SetPriority(std::get(*priority)); - } - } + TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); + auto stream = std::make_unique(this, std::move(event), priority); absl::MutexLock l(&alive_gpu_streams_mu_); - bool init_worked = gpu_stream->Init(); - if (init_worked) { - auto platform_specific_stream = gpu_stream->platform_specific_stream(); - alive_gpu_streams_[platform_specific_stream] = gpu_stream.get(); - return std::move(gpu_stream); - } else { - return absl::InvalidArgumentError("Failed to initialize gpu stream"); - } -} - -absl::StatusOr> GpuExecutor::CreateKernel() { - return std::make_unique(this); + TF_RETURN_IF_ERROR(stream->Init()); + auto gpu_stream = stream->gpu_stream(); + alive_gpu_streams_[gpu_stream] = stream.get(); + return std::move(stream); } absl::StatusOr> GpuExecutor::CreateCommandBuffer( diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc index bdace571118435..83dab87a0c6c85 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.cc @@ -15,19 +15,14 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_platform.h" -#include -#include -#include #include #include #include -#include "absl/base/call_once.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "xla/stream_executor/cuda/cuda_platform_id.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/gpu/gpu_driver.h" @@ -35,65 +30,16 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" namespace stream_executor { namespace gpu { -CudaPlatform::CudaPlatform() - : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {} +CudaPlatform::CudaPlatform() : name_("CUDA") {} CudaPlatform::~CudaPlatform() {} -// Due to legacy issues in user code, we can't currently call InpectNumaNodes -// at module initialization time, because non-GPU programs still include this -// plugin via various methods, so instead, it has to be init-on-reference. -void CudaPlatform::InspectNumaNodes() { - // To get NUMA node information, we need to create all executors, so we can - // examine their device descriptions to see their bus assignments. - static absl::once_flag once; - absl::call_once(once, [&] { - for (int i = 0; i < VisibleDeviceCount(); i++) { - StreamExecutor* exec = *ExecutorForDevice(i); - if (i == 0) { - // NUMA nodes may not start at 0, so set the minimum node based on the - // first executor we see. - min_numa_node_ = exec->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - } else { - min_numa_node_ = - std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max( - limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1); - } - } - }); -} - -int CudaPlatform::BusCount() { - InspectNumaNodes(); - return limit_numa_node_ - min_numa_node_; -} - -int CudaPlatform::DeviceToBus(int device_ordinal) { - StreamExecutor* exec = *ExecutorForDevice(device_ordinal); - return exec->GetDeviceDescription().numa_node() - min_numa_node_; -} - -absl::StatusOr CudaPlatform::FirstExecutorForBus( - int bus_ordinal) { - InspectNumaNodes(); - CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range"; - for (int i = 0; i < VisibleDeviceCount(); i++) { - if (DeviceToBus(i) == bus_ordinal) { - return *ExecutorForDevice(i); - } - } - - return absl::NotFoundError( - absl::StrFormat("Executor for bus %d not found.", bus_ordinal)); -} - Platform::Id CudaPlatform::id() const { return cuda::kCudaPlatformId; } int CudaPlatform::VisibleDeviceCount() const { @@ -113,44 +59,26 @@ CudaPlatform::DescriptionForDevice(int ordinal) const { } absl::StatusOr CudaPlatform::ExecutorForDevice(int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); + return executor_cache_.GetOrCreate( + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } -absl::StatusOr CudaPlatform::GetExecutor( - const StreamExecutorConfig& config) { - if (config.gpu_stream) { - // If the GPU stream was provided, it's not possible to get-or-create a - // stream with a required pointer: so we are looking for previously - // allocated streams. - return executor_cache_.Get(config); - } - return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); +absl::StatusOr CudaPlatform::FindExisting(int ordinal) { + return executor_cache_.Get(ordinal); } absl::StatusOr> -CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = std::make_unique(this, config.ordinal); - auto init_status = executor->Init(); - if (!init_status.ok()) { - return absl::InternalError(absl::StrFormat( - "failed initializing StreamExecutor for CUDA device ordinal %d: %s", - config.ordinal, init_status.ToString())); - } - +CudaPlatform::GetUncachedExecutor(int ordinal) { + auto executor = std::make_unique(this, ordinal); + TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } } // namespace gpu static void InitializeCudaPlatform() { - // Disabling leak checking, PlatformManager does not destroy its - // registered platforms. - - std::unique_ptr platform(new gpu::CudaPlatform); - TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK( + PlatformManager::RegisterPlatform(std::make_unique())); } } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h index 153282b26507e6..e4ba806343f091 100644 --- a/third_party/xla/xla/stream_executor/cuda/cuda_platform.h +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform.h @@ -22,7 +22,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/executor_cache.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" namespace stream_executor { @@ -41,16 +40,6 @@ class CudaPlatform : public Platform { CudaPlatform(); ~CudaPlatform() override; - // CudaPlatform-specific functionality - // Returns the number of distinct buses / NUMA nodes on the machine. - int BusCount(); - - // Returns the bus/NUMA node for the specified device ordinal. - int DeviceToBus(int device_ordinal); - - // Returns the lowest-ordinal-number StreamExecutor on the specified bus. - absl::StatusOr FirstExecutorForBus(int bus_ordinal); - // Platform interface implementation: // Returns the same value as kCudaPlatform above. Platform::Id id() const override; @@ -64,32 +53,21 @@ class CudaPlatform : public Platform { int ordinal) const override; absl::StatusOr ExecutorForDevice(int ordinal) override; + absl::StatusOr FindExisting(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; - + // Returns a device constructed with the ordinal without + // looking in or storing to the Platform's executor cache. + // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config) override; + int ordinal); private: - // Determines the number of NUMA nodes and the assignment of executor to each. - void InspectNumaNodes(); - // This platform's name. std::string name_; // Cache of created executors. ExecutorCache executor_cache_; - // The smallest NUMA node value for any device managed by this machine - // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus - // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./ - int min_numa_node_; - - // Larger than the NUMA node value for any device managed by this machine - // manager. - int limit_numa_node_; - CudaPlatform(const CudaPlatform&) = delete; void operator=(const CudaPlatform&) = delete; }; diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc new file mode 100644 index 00000000000000..b9621f76aee349 --- /dev/null +++ b/third_party/xla/xla/stream_executor/cuda/cuda_platform_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/cuda/cuda_platform.h" + +#include +#include "absl/container/flat_hash_map.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/platform_manager.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor::gpu { +namespace { + +TEST(CudaPlatformTest, FindExistingWorks) { + TF_ASSERT_OK_AND_ASSIGN(Platform * platform, + PlatformManager::PlatformWithName("CUDA")); + CHECK_GT(platform->VisibleDeviceCount(), 0); + for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { + EXPECT_FALSE(platform->FindExisting(i).ok()); + } + absl::flat_hash_map executors; + for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { + TF_ASSERT_OK_AND_ASSIGN(auto executor, platform->ExecutorForDevice(i)); + executors[i] = executor; + } + EXPECT_EQ(executors.size(), platform->VisibleDeviceCount()); + for (int i = 0; i < platform->VisibleDeviceCount(); ++i) { + TF_ASSERT_OK_AND_ASSIGN(auto executor, platform->FindExisting(i)); + EXPECT_EQ(executor, executors[i]); + } +} + +} // namespace +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h index aa59af500ba7a3..0a30c1af59c0c4 100644 --- a/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h +++ b/third_party/xla/xla/stream_executor/cuda/cudnn_frontend_helpers.h @@ -29,6 +29,11 @@ namespace gpu { } \ } while (false) +// UIDs for cuDNN are unique identifiers of tensors within a graph. They are +// assigned during graph construction; then graph execution takes a {uid: +// buffer pointer} map defining the correspondance of buffers to tensors. +// UID assignment scheme can be arbitrary; at the moment for simplicity XLA uses +// a scheme UID = (HLO operand number + 1). int CuDnnTensorUID(int offset); } // namespace gpu diff --git a/third_party/xla/xla/stream_executor/cuda/delay_kernel.h b/third_party/xla/xla/stream_executor/cuda/delay_kernel.h index 09aad2f6e85a67..016639d0ba2136 100644 --- a/third_party/xla/xla/stream_executor/cuda/delay_kernel.h +++ b/third_party/xla/xla/stream_executor/cuda/delay_kernel.h @@ -18,7 +18,6 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" -#include "xla/stream_executor/gpu/gpu_stream.h" #include "xla/stream_executor/stream.h" namespace stream_executor::gpu { diff --git a/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc b/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc index 18036973e6d145..bedd416c3cf8d5 100644 --- a/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc +++ b/third_party/xla/xla/stream_executor/cuda/nvjitlink_support.cc @@ -16,5 +16,7 @@ limitations under the License. #include "xla/stream_executor/cuda/nvjitlink_support.h" namespace stream_executor { -bool IsLibNvJitLinkSupported() { return LIBNVJITLINK_SUPPORT; } +bool IsLibNvJitLinkSupported() { + return LIBNVJITLINK_SUPPORT && CUDA_SUPPORTS_NVJITLINK; +} } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h b/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h index d6e28e96b69d67..12d5ae5a4d9d7e 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compilation_method.h @@ -22,6 +22,7 @@ limitations under the License. namespace stream_executor { enum class PtxCompilationMethod { + kNvJitLink, kNvPtxCompiler, kPtxas, }; @@ -30,6 +31,9 @@ template static void AbslStringify(Sink& sink, const PtxCompilationMethod& compilation_method) { switch (compilation_method) { + case PtxCompilationMethod::kNvJitLink: + sink.Append("NvJitLink"); + break; case PtxCompilationMethod::kNvPtxCompiler: sink.Append("NvPtxCompiler"); break; diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc index c2958332c154c3..aae94067af0ceb 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc +++ b/third_party/xla/xla/stream_executor/cuda/ptx_compiler_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include #include #include diff --git a/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h b/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h index 56dcdf1fa53d54..aafc36d64fa117 100644 --- a/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h +++ b/third_party/xla/xla/stream_executor/cuda/ptx_linking_method.h @@ -26,11 +26,15 @@ enum class PtxLinkingMethod { kNone, kNvLink, kDriver, + kNvJitLink, }; template void AbslStringify(Sink& sink, const PtxLinkingMethod& method) { switch (method) { + case PtxLinkingMethod::kNvJitLink: + sink.Append("NvJitLink"); + break; case PtxLinkingMethod::kNvLink: sink.Append("NvLink"); break; diff --git a/third_party/xla/xla/stream_executor/dnn.cc b/third_party/xla/xla/stream_executor/dnn.cc index 5a674a05e175c2..951b2f6e147cd8 100644 --- a/third_party/xla/xla/stream_executor/dnn.cc +++ b/third_party/xla/xla/stream_executor/dnn.cc @@ -41,7 +41,7 @@ limitations under the License. #include "xla/stream_executor/data_type.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/numeric_options.h" -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/protobuf/dnn.pb.h" @@ -249,42 +249,6 @@ DnnSupport::NormRunnerFromDesc( return absl::UnimplementedError("NormRunnerFromDesc not implemented."); } -absl::StatusOr> -DnnSupport::FusedMHARunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const dnn::MatmulTensorDescriptor& bmm1_lhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm1_rhs_descriptor, - const dnn::MatmulTensorDescriptor& bmm2_rhs_descriptor, - const dnn::MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const dnn::TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type) { - return absl::UnimplementedError("FusedMHARunnerFromDesc not implemented."); -} - -absl::StatusOr> -DnnSupport::FusedMHABackwardRunnerFromDesc( - Stream* stream, const dnn::AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& d_output_descriptor, - const TensorDescriptor& d_bmm1_lhs_descriptor, - const TensorDescriptor& d_bmm1_rhs_descriptor, - const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic) { - return absl::UnimplementedError( - "FusedMHABackwardRunnerFromDesc not implemented."); -} - bool DnnSupport::GetMIOpenConvolveAlgorithms( dnn::ConvolutionKind /*kind*/, dnn::DataType /*element_type*/, dnn::DataType /*output_type*/, Stream* /*stream*/, diff --git a/third_party/xla/xla/stream_executor/dnn.h b/third_party/xla/xla/stream_executor/dnn.h index 72f4603b4d3a04..a2e1cd629dc2b4 100644 --- a/third_party/xla/xla/stream_executor/dnn.h +++ b/third_party/xla/xla/stream_executor/dnn.h @@ -993,30 +993,6 @@ using FusedMatmulRunner = OpRunner; using NormSignature = void(std::vector); using NormRunner = OpRunner; -using FusedMHASignature = void(DeviceMemoryBase /*BMM1_inputA_data*/, - DeviceMemoryBase /* BMM1_inputB_data */, - DeviceMemoryBase /* BMM2_inputA_data */, - DeviceMemoryBase /* output_data */, - DeviceMemoryBase /* bias_data */, - DeviceMemoryBase /* activation_data */, - DeviceMemoryBase /* seqlen_q_data */, - DeviceMemoryBase /* seqlen_k_data */); -using FusedMHARunner = OpRunner; - -using FusedMHABackwardSignature = void( - DeviceMemoryBase /* BMM1_GRAD_GEMM1_inputA_data */, - DeviceMemoryBase /* BMM1_GRAD_GEMM2_inputB_data */, - DeviceMemoryBase /* BMM2_GRAD_GEMM1_inputA_data */, - DeviceMemoryBase /* BMM2_GRAD_GEMM2_inputB_data */, - DeviceMemoryBase /* d_output_data */, - DeviceMemoryBase /* d_BMM1_inputA_data */, - DeviceMemoryBase /* d_BMM1_inputB_data */, - DeviceMemoryBase /* d_BMM2_inputB_data */, DeviceMemoryBase /* d_S_data */, - DeviceMemoryBase /* d_bias_data */, DeviceMemoryBase /* fwd_output_data */, - DeviceMemoryBase /* bias_data */, DeviceMemoryBase /* seqlen_q_data */, - DeviceMemoryBase /* seqlen_k_data */); -using FusedMHABackwardRunner = OpRunner; - // Describes the configuration for the algorithms that will used. // // Arguments: @@ -1257,11 +1233,7 @@ class DnnGraph { DnnGraph() = default; virtual ~DnnGraph() = default; - // Returns non-OK status on hard failures (incorrectly constructed graph, - // anything else unexpected), - // false on expected ones (graph is valid but not supported), - // true on success. - virtual absl::Status Prepare(DnnSupport&) = 0; + virtual absl::Status Prepare(DnnSupport&, const NumericOptions&) = 0; virtual absl::Status Build(DnnSupport&, std::optional plan_id) = 0; virtual absl::Status Execute(Stream& stream, absl::Span operands) const = 0; @@ -1735,37 +1707,6 @@ class DnnSupport { return absl::UnimplementedError("Graph support requires cuDNN >= 8.1."); }; - virtual absl::StatusOr> - FusedMHARunnerFromDesc( - Stream* stream, const AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_rhs_descriptor, - const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor, - const TensorDescriptor& output_descriptor, - std::optional activation_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type); - - virtual absl::StatusOr> - FusedMHABackwardRunnerFromDesc( - Stream* stream, const AlgorithmDesc& algorithm_desc, - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor, - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor, - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor, - const MatmulTensorDescriptor& d_output_descriptor, - const TensorDescriptor& d_bmm1_lhs_descriptor, - const TensorDescriptor& d_bmm1_rhs_descriptor, - const TensorDescriptor& d_bmm2_rhs_descriptor, - std::optional d_s_descriptor, - std::optional d_bias_descriptor, - std::optional fwd_output_descriptor, - std::optional bias_descriptor, double scale, - std::optional dropout_rate, std::optional seed, - dnn::FMHAMaskKind mask_type, bool force_deterministic); - virtual bool GetMIOpenConvolveAlgorithms( ConvolutionKind kind, DataType element_type, DataType output_type, Stream* stream, const BatchDescriptor& input_descriptor, diff --git a/third_party/xla/xla/stream_executor/executor_cache.cc b/third_party/xla/xla/stream_executor/executor_cache.cc index eae72060f0c04c..1fcfd6b847f907 100644 --- a/third_party/xla/xla/stream_executor/executor_cache.cc +++ b/third_party/xla/xla/stream_executor/executor_cache.cc @@ -22,103 +22,41 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" +#include "tsl/platform/statusor.h" namespace stream_executor { ExecutorCache::ExecutorCache() = default; -ExecutorCache::~ExecutorCache() { DestroyAllExecutors(); } +ExecutorCache::~ExecutorCache() = default; absl::StatusOr ExecutorCache::GetOrCreate( - const StreamExecutorConfig& config, const ExecutorFactory& factory) { + int ordinal, const ExecutorFactory& factory) { // In the fast path case, the cache already has an entry and we can just // return after Get() which only takes a shared lock and not a unique lock. // If we need to create, we take a unique lock on cache_. - if (auto fast_result = Get(config); fast_result.ok()) { + if (auto fast_result = Get(ordinal); fast_result.ok()) { return fast_result; } - Entry* entry = nullptr; - { - absl::MutexLock lock{&mutex_}; - entry = &cache_[config.ordinal]; - // Release the map lock; the address of 'entry' is stable because - // absl::node_hash_map guarantees reference stability. - } - - // Acquire the per-Entry mutex without holding the map mutex. Initializing - // an Executor may be expensive, so we want to allow concurrent - // initialization of different entries. - absl::MutexLock lock{&entry->configurations_mutex}; - for (const auto& iter : entry->configurations) { - VLOG(2) << "hit in cache"; - return iter.second.get(); - } - VLOG(2) << "building executor"; - absl::StatusOr> result = factory(); - if (!result.ok()) { - VLOG(2) << "failed to get build executor: " << result.status(); - // If construction failed, leave the cache Entry around, but with a null - // executor. - return result.status(); - } - entry->configurations.emplace_back(config, std::move(result.value())); - return entry->configurations.back().second.get(); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, factory()); + auto returned_executor = result.get(); + absl::MutexLock lock(&mutex_); + cache_.emplace(ordinal, std::move(result)); + return returned_executor; } -absl::StatusOr ExecutorCache::Get( - const StreamExecutorConfig& config) { - Entry* entry = nullptr; - { - absl::ReaderMutexLock lock{&mutex_}; - - // If gpu stream is not nullptr we have to find StreamExecutor that owns it, - // and return NOT_FOUND error if we can't find it. - if (config.gpu_stream) { - for (auto& [ordinal, e] : cache_) { - absl::ReaderMutexLock l{&e.configurations_mutex}; - for (auto& [c, executor] : e.configurations) { - if (executor->FindAllocatedStream(config.gpu_stream)) { - return executor.get(); - } - } - } - return absl::NotFoundError( - absl::StrFormat("No executors own stream %p", config.gpu_stream)); - } - - if (auto it = cache_.find(config.ordinal); it != cache_.end()) { - entry = &it->second; - } else { - return absl::NotFoundError(absl::StrFormat( - "No executors registered for ordinal %d", config.ordinal)); - } - } - - absl::ReaderMutexLock lock{&entry->configurations_mutex}; - if (entry->configurations.empty()) { - return absl::NotFoundError(absl::StrFormat( - "No executors registered for ordinal %d", config.ordinal)); - } +absl::StatusOr ExecutorCache::Get(int ordinal) { + absl::ReaderMutexLock lock{&mutex_}; - for (auto& [entry_config, entry_executor] : entry->configurations) { - return entry_executor.get(); + if (auto it = cache_.find(ordinal); it != cache_.end()) { + return it->second.get(); } - return absl::NotFoundError("No executor found with a matching config."); -} - -void ExecutorCache::DestroyAllExecutors() { - absl::MutexLock lock{&mutex_}; - cache_.clear(); -} - -ExecutorCache::Entry::~Entry() { - absl::MutexLock lock{&configurations_mutex}; - configurations.clear(); + return absl::NotFoundError( + absl::StrFormat("No executors registered for ordinal %d", ordinal)); } } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/executor_cache.h b/third_party/xla/xla/stream_executor/executor_cache.h index 6e7f32e487cd1a..d4cf4b5e31441d 100644 --- a/third_party/xla/xla/stream_executor/executor_cache.h +++ b/third_party/xla/xla/stream_executor/executor_cache.h @@ -18,20 +18,15 @@ limitations under the License. #include #include -#include -#include #include "absl/base/thread_annotations.h" -#include "absl/container/node_hash_map.h" +#include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "absl/synchronization/mutex.h" -#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" namespace stream_executor { -// Forward declare. -class StreamExecutor; - // Utility class to allow Platform objects to manage cached StreamExecutors. // Thread-safe. class ExecutorCache { @@ -42,43 +37,23 @@ class ExecutorCache { ExecutorCache(); ~ExecutorCache(); - // Looks up 'config' in the cache. Returns a pointer to the existing executor, - // if already present, or creates it using 'factory', if it does not. - // Factories may be executed concurrently for different device ordinals. - absl::StatusOr GetOrCreate( - const StreamExecutorConfig& config, const ExecutorFactory& factory); + // Looks up 'ordinal' in the cache. Returns a pointer to the existing + // executor, if already present, or creates it using 'factory', if it does + // not. Factories may be executed concurrently for different device ordinals. + absl::StatusOr GetOrCreate(int ordinal, + const ExecutorFactory& factory); - // Returns a pointer to the described executor (if one with a matching config + // Returns a pointer to the described executor (if one with a matching ordinal // has been created), or a NOT_FOUND status. - absl::StatusOr Get(const StreamExecutorConfig& config); - - // Destroys all Executors and clears the cache. - // Performs no synchronization with the executors - undefined behavior may - // occur if any executors are active! - void DestroyAllExecutors(); + absl::StatusOr Get(int ordinal); private: - // Each Entry contains zero or more cached executors for a device ordinal. - struct Entry { - ~Entry(); - - // Mutex that guards the contents of each entry. The 'mutex_' of the - // ExecutorCache class protects both the 'cache_' and the existence of each - // Entry, but not the Entry's contents. 'configurations_mutex' protects the - // contents of the entry after 'mutex_' has been dropped. - absl::Mutex configurations_mutex; - - // Vector of cached {config, executor} pairs. - std::vector< - std::pair>> - configurations ABSL_GUARDED_BY(configurations_mutex); - }; - - // Maps ordinal number to a list of cached executors for that ordinal. - // We key off of ordinal (instead of just looking up all fields in the - // StreamExecutorConfig) for a slight improvement in lookup time. + // Protects cache_. absl::Mutex mutex_; - absl::node_hash_map cache_ ABSL_GUARDED_BY(mutex_); + + // Maps ordinal number to a cached executor for that ordinal. + absl::flat_hash_map> cache_ + ABSL_GUARDED_BY(mutex_); ExecutorCache(const ExecutorCache&) = delete; void operator=(const ExecutorCache&) = delete; diff --git a/third_party/xla/xla/stream_executor/executor_cache_test.cc b/third_party/xla/xla/stream_executor/executor_cache_test.cc new file mode 100644 index 00000000000000..84bed1ecaf576b --- /dev/null +++ b/third_party/xla/xla/stream_executor/executor_cache_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/executor_cache.h" + +#include + +#include "absl/log/log.h" +#include "xla/stream_executor/mock_stream_executor.h" +#include "xla/stream_executor/stream.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +namespace stream_executor { +namespace { + +TEST(ExecutorCacheTest, GetOnEmptyCacheFails) { + ExecutorCache cache; + EXPECT_FALSE(cache.Get(0).ok()); +} + +TEST(ExecutorCacheTest, GetReturnsExpectedExecutor) { + ExecutorCache cache; + StreamExecutor *executor0 = nullptr; + StreamExecutor *executor1 = nullptr; + auto factory = [&executor0, &executor1]() { + auto executor = std::make_unique(); + if (executor0 == nullptr) { + executor0 = executor.get(); + } else if (executor1 == nullptr) { + executor1 = executor.get(); + } else { + LOG(FATAL) << "Bad call to factory."; + } + return executor; + }; + TF_ASSERT_OK_AND_ASSIGN(auto found, cache.GetOrCreate(0, factory)); + EXPECT_EQ(found, executor0); + TF_ASSERT_OK_AND_ASSIGN(found, cache.GetOrCreate(1, factory)); + EXPECT_EQ(found, executor1); + TF_ASSERT_OK_AND_ASSIGN(found, cache.GetOrCreate(0, factory)); + EXPECT_EQ(found, executor0); + TF_ASSERT_OK_AND_ASSIGN(found, cache.GetOrCreate(1, factory)); + EXPECT_EQ(found, executor1); + TF_ASSERT_OK_AND_ASSIGN(found, cache.Get(0)); + EXPECT_EQ(found, executor0); + TF_ASSERT_OK_AND_ASSIGN(found, cache.Get(1)); + EXPECT_EQ(found, executor1); +} + +} // namespace +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/gpu/BUILD b/third_party/xla/xla/stream_executor/gpu/BUILD index bd2b5b0cbcd8f2..f7447852db9d96 100644 --- a/third_party/xla/xla/stream_executor/gpu/BUILD +++ b/third_party/xla/xla/stream_executor/gpu/BUILD @@ -220,7 +220,6 @@ gpu_only_cc_library( "//xla/stream_executor:host_memory_allocation", "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", - "//xla/stream_executor:launch_dim", "//xla/stream_executor:memory_allocation", "//xla/stream_executor:module_spec", "//xla/stream_executor:platform", @@ -228,7 +227,6 @@ gpu_only_cc_library( "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -243,8 +241,7 @@ gpu_only_cc_library( name = "gpu_helpers_header", hdrs = ["gpu_helpers.h"], deps = [ - ":gpu_types_header", - "@local_tsl//tsl/platform:logging", + "//xla/stream_executor:device_memory", ], ) @@ -308,11 +305,14 @@ gpu_only_cc_library( name = "gpu_stream_header", hdrs = ["gpu_stream.h"], deps = [ + ":gpu_event_header", ":gpu_executor_header", ":gpu_types_header", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", @@ -330,10 +330,13 @@ gpu_only_cc_library( ":gpu_driver_header", ":gpu_event_header", ":gpu_executor_header", + ":gpu_kernel_header", ":gpu_types_header", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:platform", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", @@ -344,6 +347,7 @@ gpu_only_cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:string_view", "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/lib:nvtx_utils", ], ) @@ -376,7 +380,6 @@ gpu_only_cc_library( ":gpu_stream", ":gpu_types_header", "//xla/stream_executor", - "//xla/stream_executor:event", "//xla/stream_executor:event_based_timer", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", @@ -385,7 +388,6 @@ gpu_only_cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@com_google_absl//absl/utility", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", ], @@ -555,8 +557,6 @@ xla_test( name = "redzone_allocator_test", srcs = ["redzone_allocator_test.cc"], backends = ["gpu"], - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], deps = [ ":gpu_asm_opts", ":gpu_init", @@ -564,10 +564,10 @@ xla_test( "//xla/stream_executor", "//xla/stream_executor:device_memory_allocator", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -618,11 +618,7 @@ xla_test( name = "gpu_cudamallocasync_allocator_test", srcs = ["gpu_cudamallocasync_allocator_test.cc"], backends = ["gpu_any"], - tags = [ - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - "gpu", - "no_rocm", - ], + tags = ["no_rocm"], deps = [ ":gpu_cudamallocasync_allocator", ":gpu_stream", @@ -675,40 +671,89 @@ cc_library( gpu_kernel_library( name = "gpu_test_kernels", testonly = 1, - srcs = if_gpu_is_configured(["gpu_test_kernels.cu.cc"]), - hdrs = if_gpu_is_configured(["gpu_test_kernels.h"]), + srcs = ["gpu_test_kernels.cu.cc"], + hdrs = ["gpu_test_kernels.h"], + tags = ["gpu"], deps = if_cuda_is_configured([ "@local_config_cuda//cuda:cuda_headers", ]) + if_rocm_is_configured([ "@local_config_rocm//rocm:rocm_headers", - "//xla/stream_executor/rocm:add_i32_kernel", ]), ) +genrule( + name = "gpu_test_kernels_fatbin_extractor", + testonly = True, + srcs = [":gpu_test_kernels"], + outs = ["gpu_test_kernels.fatbin"], + cmd = """ + STATIC_LIBRARY="" + for src in $(SRCS); do + if [[ $$src == *.a ]]; then + STATIC_LIBRARY=$$src + break + fi + done + + if [[ -z $$STATIC_LIBRARY ]]; then + echo "No static library found in $(SRCS)" >&2 + exit 1 + fi + + $(OBJCOPY) "--dump-section=.nv_fatbin=$@" "$$STATIC_LIBRARY" || true + + if [ ! -f "$@" ]; then + # binutils' objcopy doesn't return a non-zero exit code if the + # section was not found, so we need to check for the file's existence instead. + $(OBJCOPY) "--dump-section=.hip_fatbin=$@" "$$STATIC_LIBRARY" + fi + """, + tags = ["gpu"], + toolchains = ["@bazel_tools//tools/cpp:current_cc_toolchain"], +) + +cc_library( + name = "gpu_test_kernels_fatbin", + testonly = True, + srcs = ["gpu_test_kernels_fatbin.cc"], + hdrs = ["gpu_test_kernels_fatbin.h"], + data = [":gpu_test_kernels_fatbin_extractor"], + local_defines = [ + "FATBIN_SRC=\\\"$(rootpath :gpu_test_kernels_fatbin_extractor)\\\"", + ], + tags = ["gpu"], + deps = [ + "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + ], +) + xla_test( name = "gpu_kernel_test", - srcs = if_gpu_is_configured(["gpu_kernel_test.cc"]), + srcs = ["gpu_kernel_test.cc"], backends = ["gpu"], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]), deps = [ ":gpu_test_kernels", + ":gpu_test_kernels_fatbin", "//xla/service:platform_util", "//xla/stream_executor", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:typed_kernel_factory", + "//xla/stream_executor/rocm:rocm_platform_id", + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", - ] + if_cuda([ - "//xla/stream_executor/cuda:cuda_platform", - ]) + if_rocm([ - "//xla/stream_executor/rocm:rocm_platform", - ]), + ], ) xla_test( @@ -729,11 +774,11 @@ xla_test( "//xla/stream_executor:platform_manager", "//xla/stream_executor:trace_command_buffer_factory", "//xla/stream_executor:typed_kernel_factory", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -754,13 +799,11 @@ xla_test( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], deps = [ "//xla/stream_executor", "//xla/stream_executor:device_memory", "//xla/stream_executor:platform_manager", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -779,10 +822,9 @@ xla_test( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], deps = [ "//xla/stream_executor", + "//xla/stream_executor:stream_finder", "//xla/stream_executor/host:host_platform", "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:statusor", @@ -803,8 +845,6 @@ xla_test( local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ "TENSORFLOW_USE_ROCM=1", ]), - # TODO(b/317293391) Remove once Bazel test_suite handles tags correctly - tags = ["gpu"], deps = [ "//xla/service:platform_util", "//xla/stream_executor:platform", @@ -833,6 +873,7 @@ xla_test( "//xla/tools/hlo_opt:gpu_specs/a6000.txtpb", "//xla/tools/hlo_opt:gpu_specs/h100_pcie.txtpb", "//xla/tools/hlo_opt:gpu_specs/h100_sxm.txtpb", + "//xla/tools/hlo_opt:gpu_specs/mi200.txtpb", "//xla/tools/hlo_opt:gpu_specs/p100.txtpb", "//xla/tools/hlo_opt:gpu_specs/v100.txtpb", ]) + if_rocm_is_configured([ @@ -847,14 +888,15 @@ xla_test( 'PLATFORM_NAME=\\"ROCM\\"' ]), deps = [ + "//xla/service:platform_util", "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", "//xla/stream_executor:stream_executor_h", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/container:flat_hash_map", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:protobuf", diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h index b4bf7ebd46d8dc..20caccbe18e62e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "xla/stream_executor/blas.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host_or_device_scalar.h" diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc index a0334695552915..c672b11a8d2760 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.cc @@ -340,7 +340,7 @@ absl::StatusOr GpuCommandBuffer::CreateBarrierNode( TF_RETURN_IF_ERROR(GpuDriver::GraphAddKernelNode( &barrier_handle, graph_, dependencies, "noop", - AsGpuKernel(&**noop)->AsGpuFunctionHandle(), 1, 1, 1, 1, 1, 1, 0, + AsGpuKernel(&**noop)->gpu_function(), 1, 1, 1, 1, 1, 1, 0, /*kernel_params=*/nullptr, /*extra=*/nullptr)); #else TF_RETURN_IF_ERROR( @@ -524,7 +524,7 @@ absl::Status GpuCommandBuffer::LaunchWithPackedArgs( packed_args.number_of_arguments()); const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); - GpuFunctionHandle gpu_func = gpu_kernel->AsGpuFunctionHandle(); + GpuFunctionHandle gpu_func = gpu_kernel->gpu_function(); void** kernel_params = const_cast(packed_args.argument_addresses().data()); @@ -1006,6 +1006,8 @@ absl::Status GpuCommandBuffer::Finalize() { } else { TF_RETURN_IF_ERROR(retry); } + } else { + TF_RETURN_IF_ERROR(instantiated); } uint64_t end_nanos = tsl::Env::Default()->NowNanos(); @@ -1073,4 +1075,15 @@ GpuCommandBuffer::barriers(ExecutionScopeId id) const { return {}; } +absl::Status GpuCommandBuffer::Submit(Stream* stream) { + if (mode_ != CommandBuffer::Mode::kPrimary) { + return absl::InvalidArgumentError( + "Can't submit non-primary command buffer for execution"); + } + + VLOG(3) << "Launch command buffer executable graph " << exec_ + << " on a stream: " << stream; + return GpuDriver::GraphLaunch(exec_, AsGpuStreamValue(stream)); +} + } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h index 2808fe6364c047..0b33d340363e24 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer.h @@ -123,6 +123,7 @@ class GpuCommandBuffer : public CommandBuffer { absl::Status Finalize() override; absl::Status Update() override; + absl::Status Submit(Stream* stream) override; GpuGraphExecHandle executable() const { return exec_; } GpuGraphHandle graph() const { return graph_; } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc index ef31559eefc5bd..6852cebf9a1014 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_command_buffer_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "xla/service/platform_util.h" #include "xla/stream_executor/command_buffer.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" #include "xla/stream_executor/gpu/gpu_types.h" // IWYU pragma: keep @@ -37,7 +38,7 @@ limitations under the License. #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/trace_command_buffer_factory.h" #include "xla/stream_executor/typed_kernel_factory.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -60,11 +61,7 @@ static Platform* GpuPlatform() { static MultiKernelLoaderSpec GetAddI32KernelSpec() { MultiKernelLoaderSpec spec(/*arity=*/3); -#if defined(GOOGLE_CUDA) - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); -#elif defined(TENSORFLOW_USE_ROCM) - spec.AddCudaCubinInMemory(internal::kAddI32KernelModule, "add"); -#endif + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); return spec; } @@ -113,7 +110,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; @@ -133,7 +130,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { TF_ASSERT_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), a, b, c)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `c` data back to host. std::vector dst(4, 42); @@ -151,7 +148,7 @@ TEST(GpuCommandBufferTest, LaunchSingleKernel) { TF_ASSERT_OK(cmd_buffer->Launch(add, ThreadDim(), BlockDim(4), a, b, d)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); @@ -183,7 +180,7 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { cast(bufs[2]), }); }); - spec.AddInProcessSymbol(internal::GetAddI32Ptrs3Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Ptrs3Kernel(), "AddI32Ptrs3"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Ptrs3::Create(executor, spec)); @@ -203,15 +200,16 @@ TEST(CudaCommandBufferTest, TraceSingleKernel) { KernelArgsDeviceMemoryArray args({a, b, c}, 0); // Create a command buffer by tracing kernel launch operations. - auto cmd_buffer = TraceCommandBufferFactory::Create( - executor, - [&](Stream* stream) { - return stream->Launch(ThreadDim(), BlockDim(4), *add, args); - }, - primary); + TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, TraceCommandBufferFactory::Create( + executor, + [&](Stream* stream) { + return stream->Launch( + ThreadDim(), BlockDim(4), + *add, args); + }, + primary)); - TF_ASSERT_OK(cmd_buffer.status()); - TF_ASSERT_OK(executor->Submit(stream.get(), **cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy data back to host. std::vector dst(4, 42); @@ -249,7 +247,7 @@ TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { TF_ASSERT_OK(primary_cmd->AddNestedCommandBuffer(*nested_cmd)); TF_ASSERT_OK(primary_cmd->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *primary_cmd)); + TF_ASSERT_OK(primary_cmd->Submit(stream.get())); // Copy `c` data back to host. std::vector dst(4, 42); @@ -270,7 +268,7 @@ TEST(GpuCommandBufferTest, LaunchNestedCommandBuffer) { TF_ASSERT_OK(primary_cmd->AddNestedCommandBuffer(*nested_cmd)); TF_ASSERT_OK(primary_cmd->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *primary_cmd)); + TF_ASSERT_OK(primary_cmd->Submit(stream.get())); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); @@ -298,7 +296,7 @@ TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { TF_ASSERT_OK(cmd_buffer->MemcpyDeviceToDevice(&b, a, byte_length)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 0); @@ -315,7 +313,7 @@ TEST(GpuCommandBufferTest, MemcpyDeviceToDevice) { // Clear destination to test that command buffer actually copied memory. TF_ASSERT_OK(stream->Memset32(&a, 0, byte_length)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `a` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -339,7 +337,7 @@ TEST(GpuCommandBufferTest, Memset) { TF_ASSERT_OK(cmd_buffer->Memset(&a, uint32_t{42}, length)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `a` data back to host. std::vector dst(4, 0); @@ -353,7 +351,7 @@ TEST(GpuCommandBufferTest, Memset) { TF_ASSERT_OK(cmd_buffer->Memset(&a, uint32_t{43}, length)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 0); @@ -408,7 +406,7 @@ TEST(GpuCommandBufferTest, Barriers) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44, 45, 46, 47}; ASSERT_EQ(transfer_buffers(), expected); @@ -445,7 +443,7 @@ TEST(GpuCommandBufferTest, Barriers) { // Update command buffer to use a new bit pattern. TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK(record(cmd_buffer.get(), 43)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {43, 44, 45, 46, 47, 48}; ASSERT_EQ(transfer_buffers(), expected); @@ -488,7 +486,7 @@ TEST(GpuCommandBufferTest, IndependentExecutionScopes) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44, 45}; ASSERT_EQ(transfer_buffers(), expected); @@ -515,7 +513,7 @@ TEST(GpuCommandBufferTest, IndependentExecutionScopes) { // Update command buffer to use a new bit pattern. TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK(record(cmd_buffer.get(), 43)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {43, 44, 45, 46}; ASSERT_EQ(transfer_buffers(), expected); @@ -562,7 +560,7 @@ TEST(GpuCommandBufferTest, ExecutionScopeBarriers) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44, 45, 46, 47, 48}; ASSERT_EQ(transfer_buffers(), expected); @@ -607,7 +605,7 @@ TEST(GpuCommandBufferTest, ExecutionScopeBarriers) { // Update command buffer to use a new bit pattern. TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK(record(cmd_buffer.get(), 43)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {43, 44, 45, 46, 47, 48, 49}; ASSERT_EQ(transfer_buffers(), expected); @@ -652,7 +650,7 @@ TEST(GpuCommandBufferTest, ExecutionScopeOneDirectionalBarriers) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44, 45, 46, 47}; ASSERT_EQ(transfer_buffers(), expected); @@ -683,7 +681,7 @@ TEST(GpuCommandBufferTest, ExecutionScopeOneDirectionalBarriers) { // Update command buffer to use a new bit pattern. TF_ASSERT_OK(cmd_buffer->Update()); TF_ASSERT_OK(record(cmd_buffer.get(), 43)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {43, 44, 45, 46, 47, 48}; ASSERT_EQ(transfer_buffers(), expected); @@ -700,7 +698,7 @@ TEST(GpuCommandBufferTest, ConditionalIf) { TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; @@ -728,7 +726,7 @@ TEST(GpuCommandBufferTest, ConditionalIf) { TF_ASSERT_OK(cmd_buffer->If(pred, then_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `c` data back to host. std::vector dst(4, 42); @@ -744,7 +742,7 @@ TEST(GpuCommandBufferTest, ConditionalIf) { // Submit the same command buffer, but this time it should not execute // conditional branch as conditional handle should be updated to false. - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); std::vector zeroes = {0, 0, 0, 0}; @@ -767,7 +765,7 @@ TEST(GpuCommandBufferTest, ConditionalIf) { TF_ASSERT_OK(cmd_buffer->If(pred, then_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `d` data back to host. std::fill(dst.begin(), dst.end(), 42); @@ -775,6 +773,69 @@ TEST(GpuCommandBufferTest, ConditionalIf) { ASSERT_EQ(dst, expected); } +TEST(GpuCommandBufferTest, ConditionalIfWithMemset) { +#if CUDA_VERSION < 12040 + GTEST_SKIP() << "ConditionalsWithMemset are not supported before 12.4.1."; +#endif + Platform* platform = GpuPlatform(); + + StreamExecutor* executor = platform->ExecutorForDevice(0).value(); + + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=0, pred=true + DeviceMemory pred = executor->AllocateArray(1, 0); + DeviceMemory a = executor->AllocateArray(length, 0); + + constexpr bool kTrue = true; + TF_ASSERT_OK(stream->Memcpy(&pred, &kTrue, 1)); + TF_ASSERT_OK(stream->Memset32(&a, 0, byte_length)); + + // if (pred == true) memset(&a, ...); + CommandBuffer::Builder then_builder = [&](CommandBuffer* then_cmd) { + return then_cmd->Memset(&a, uint8_t{1}, byte_length); + }; + + // Create a command buffer with a single conditional operation. + TF_ASSERT_OK_AND_ASSIGN(auto cmd_buffer, + executor->CreateCommandBuffer(primary)); + TF_ASSERT_OK(cmd_buffer->If(pred, then_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); + + // Copy `a` data back to host. + std::vector dst(length, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), a, byte_length)); + + std::vector expected(length, 1 << 24 | 1 << 16 | 1 << 8 | 1); + ASSERT_EQ(dst, expected); + + // Prepare argument for graph update: b = 0 + DeviceMemory b = executor->AllocateArray(length, 0); + TF_ASSERT_OK(stream->MemZero(&a, byte_length)); + + // if (pred == true) memset(&b, ...); + then_builder = [&](CommandBuffer* then_cmd) { + return then_cmd->Memset(&b, uint8_t{1}, byte_length); + }; + + // Update command buffer with a conditional to use new builder. + TF_ASSERT_OK(cmd_buffer->Update()); + TF_ASSERT_OK(cmd_buffer->If(pred, then_builder)); + TF_ASSERT_OK(cmd_buffer->Finalize()); + + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); + + // Copy `b` data back to host. + std::fill(dst.begin(), dst.end(), 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), b, byte_length)); + ASSERT_EQ(dst, expected); +} + TEST(GpuCommandBufferTest, ConditionalIfElse) { if (!IsAtLeastCuda12300()) { GTEST_SKIP() << "CUDA graph conditionals are not supported"; @@ -787,12 +848,12 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { // Load addition kernel. MultiKernelLoaderSpec add_spec(/*arity=*/3); - add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); // Load multiplication kernel. MultiKernelLoaderSpec mul_spec(/*arity=*/3); - mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); + mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "MulI32"); TF_ASSERT_OK_AND_ASSIGN(auto mul, MulI32Kernel::Create(executor, mul_spec)); int64_t length = 4; @@ -825,7 +886,7 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { TF_ASSERT_OK(cmd_buffer->IfElse(pred, then_builder, else_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); // Copy `c` data back to host. @@ -841,7 +902,7 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { // Submit the same command buffer, but this time it should execute `else` // branch and multiply inputs. - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); @@ -862,7 +923,7 @@ TEST(GpuCommandBufferTest, ConditionalIfElse) { TF_ASSERT_OK(cmd_buffer->IfElse(pred, then_builder, else_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); // Copy `d` data back to host. @@ -883,12 +944,12 @@ TEST(GpuCommandBufferTest, ConditionalCase) { // Load addition kernel. MultiKernelLoaderSpec add_spec(/*arity=*/3); - add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); // Load multiplication kernel. MultiKernelLoaderSpec mul_spec(/*arity=*/3); - mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "mul"); + mul_spec.AddInProcessSymbol(internal::GetMulI32Kernel(), "MulI32"); TF_ASSERT_OK_AND_ASSIGN(auto mul, MulI32Kernel::Create(executor, mul_spec)); int64_t length = 4; @@ -920,7 +981,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) { TF_ASSERT_OK(cmd_buffer->Case(index, {branch0, branch1})); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); // Copy `c` data back to host. @@ -934,7 +995,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) { TF_ASSERT_OK(stream->Memset32(&index, 1, sizeof(int32_t))); // Submit the same command buffer, but this time it should multiply inputs. - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); @@ -944,7 +1005,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) { // Set index to `-1` (out of bound index value). TF_ASSERT_OK(stream->Memset32(&index, -1, sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); @@ -953,7 +1014,7 @@ TEST(GpuCommandBufferTest, ConditionalCase) { // Set index to `2` (out of bound index value). TF_ASSERT_OK(stream->Memset32(&index, 2, sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->BlockHostUntilDone()); TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); @@ -971,7 +1032,7 @@ TEST(GpuCommandBufferTest, ConditionalFor) { TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); int64_t length = 4; @@ -999,7 +1060,7 @@ TEST(GpuCommandBufferTest, ConditionalFor) { TF_ASSERT_OK(cmd_buffer->For(num_iters, loop_counter, body_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 42); @@ -1021,12 +1082,12 @@ TEST(GpuCommandBufferTest, ConditionalWhile) { // Load addition kernel. MultiKernelLoaderSpec add_spec(/*arity=*/3); - add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); // Load inc_and_cmp kernel. MultiKernelLoaderSpec icmp_spec(/*arity=*/3); - icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); + icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "IncAndCmp"); TF_ASSERT_OK_AND_ASSIGN(auto inc_and_cmp, IncAndCmpKernel::Create(executor, icmp_spec)); @@ -1066,7 +1127,7 @@ TEST(GpuCommandBufferTest, ConditionalWhile) { TF_ASSERT_OK(cmd_buffer->While(pred, cond_builder, body_builder)); TF_ASSERT_OK(cmd_buffer->Finalize()); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `b` data back to host. std::vector dst(4, 42); @@ -1131,7 +1192,7 @@ TEST(GpuCommandBufferTest, ConditionalIfInExecutionScope) { // Create a command buffer with a DAG of memset commands. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); std::vector expected = {42, 43, 44}; ASSERT_EQ(transfer_buffers(), expected); @@ -1165,7 +1226,7 @@ TEST(GpuCommandBufferTest, ConditionalIfInExecutionScope) { constexpr bool kFalse = false; TF_ASSERT_OK(stream->Memcpy(&pred, &kFalse, 1)); TF_ASSERT_OK(stream->MemZero(&buffers[2], sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); expected = {42, 43, 0}; ASSERT_EQ(transfer_buffers(), expected); @@ -1186,12 +1247,12 @@ TEST(GpuCommandBufferTest, ConditionalWhileInExecutionScope) { // Load addition kernel. MultiKernelLoaderSpec add_spec(/*arity=*/3); - add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + add_spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, add_spec)); // Load inc_and_cmp kernel. MultiKernelLoaderSpec icmp_spec(/*arity=*/3); - icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "inc_and_cmp"); + icmp_spec.AddInProcessSymbol(internal::GetIncAndCmpKernel(), "IncAndCmp"); TF_ASSERT_OK_AND_ASSIGN(auto inc_and_cmp, IncAndCmpKernel::Create(executor, icmp_spec)); @@ -1232,7 +1293,7 @@ TEST(GpuCommandBufferTest, ConditionalWhileInExecutionScope) { // Create a command buffer with a single conditional operation. auto cmd_buffer = executor->CreateCommandBuffer(primary).value(); TF_ASSERT_OK(record(cmd_buffer.get(), 42, 10)); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); // Copy `b` and `c` data back to host. int32_t b_dst, c_dst; @@ -1265,7 +1326,7 @@ TEST(GpuCommandBufferTest, ConditionalWhileInExecutionScope) { TF_ASSERT_OK(stream->MemZero(&loop_counter, sizeof(int32_t))); TF_ASSERT_OK(stream->MemZero(&b, sizeof(int32_t))); - TF_ASSERT_OK(executor->Submit(stream.get(), *cmd_buffer)); + TF_ASSERT_OK(cmd_buffer->Submit(stream.get())); TF_ASSERT_OK(stream->Memcpy(&b_dst, b, sizeof(int32_t))); TF_ASSERT_OK(stream->Memcpy(&c_dst, c, sizeof(int32_t))); @@ -1288,7 +1349,7 @@ static void BM_CreateCommandBuffer(benchmark::State& state) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); @@ -1311,7 +1372,7 @@ static void BM_TraceCommandBuffer(benchmark::State& state) { TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); @@ -1336,7 +1397,7 @@ static void BM_UpdateCommandBuffer(benchmark::State& state) { StreamExecutor* executor = platform->ExecutorForDevice(0).value(); MultiKernelLoaderSpec spec(/*arity=*/3); - spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "add"); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); DeviceMemory b = executor->AllocateArray(1, 0); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc index 2dd2f48f0db41c..d77ba42f2d497f 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_device_info_test.cc @@ -14,12 +14,13 @@ limitations under the License. ==============================================================================*/ #include "absl/container/flat_hash_map.h" +#include "xla/service/platform_util.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/protobuf.h" @@ -44,7 +45,8 @@ TEST(DeviceInfoTest, DeviceInfoMatches) { tsl::protobuf::TextFormat::ParseFromString(spec_string, &proto)); gpu_specs[proto.device_description_str()] = proto.gpu_device_info(); } - + auto name = absl::AsciiStrToUpper( + xla::PlatformUtil::CanonicalPlatformName("gpu").value()); TF_ASSERT_OK_AND_ASSIGN(Platform * platform, PlatformManager::PlatformWithName(PLATFORM_NAME)); bool all_skipped = false; diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h index 599480c13e92da..94cff4632638e1 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_driver.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_driver.h @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" namespace stream_executor { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h index c19fa1cceeba0c..f7eab3beb9f626 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_executor.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_executor.h @@ -1,4 +1,3 @@ -#include "xla/stream_executor/event_based_timer.h" /* Copyright 2019 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -30,10 +29,11 @@ limitations under the License. #include #include #include +#include +#include #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" -#include "absl/functional/any_invocable.h" #include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -45,6 +45,7 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/dnn.h" #include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/fft.h" #include "xla/stream_executor/gpu/gpu_collectives.h" #include "xla/stream_executor/gpu/gpu_driver.h" @@ -52,7 +53,6 @@ limitations under the License. #include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" @@ -122,11 +122,11 @@ class GpuExecutor : public StreamExecutorCommon { int device_ordinal() const override { return device_ordinal_; }; - absl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) override; + absl::StatusOr> LoadKernel( + const MultiKernelLoaderSpec& spec) override; - // (supported on CUDA only) - void UnloadKernel(const Kernel* kernel) override; + // Releases any state associated with the previously loaded kernel. + void UnloadKernel(const Kernel* kernel); absl::Status LoadModule(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle) override; bool UnloadModule(ModuleHandle module_handle) override; @@ -137,18 +137,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::StatusOr> CreateOrShareConstant( Stream* stream, absl::Span content) override; - absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args) override; - - absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, const Kernel& kernel, - const KernelArgs& args) override; - - absl::Status Submit(Stream* stream, - const CommandBuffer& command_buffer) override; - DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; void Deallocate(DeviceMemoryBase* mem) override; @@ -204,9 +192,6 @@ class GpuExecutor : public StreamExecutorCommon { const DeviceMemoryBase& gpu_src, uint64_t size) override; - absl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) override; - void DeallocateStream(Stream* stream) override; absl::Status BlockHostUntilDone(Stream* stream) override; @@ -237,10 +222,7 @@ class GpuExecutor : public StreamExecutorCommon { absl::StatusOr> CreateEvent() override; absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) override; - - absl::StatusOr> CreateKernel() override; + std::optional> priority) override; absl::StatusOr> CreateCommandBuffer( CommandBuffer::Mode mode) override; @@ -324,11 +306,6 @@ class GpuExecutor : public StreamExecutorCommon { absl::Status LoadModuleFromHsaco(const char* hsaco, GpuModuleHandle* module) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); - absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const std::optional& cluster_dims, - const Kernel& kernel, const KernelArgs& args); - bool UnloadGpuBinary(const void* gpu_binary) TF_EXCLUSIVE_LOCKS_REQUIRED(in_memory_modules_mu_); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h index 62db12705491bc..187d882c78369c 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_helpers.h @@ -23,17 +23,10 @@ limitations under the License. #include -#include -#include - -#include "xla/stream_executor/gpu/gpu_types.h" -#include "tsl/platform/logging.h" +#include "xla/stream_executor/device_memory.h" namespace stream_executor { -template -class DeviceMemory; - namespace gpu { // Converts a const DeviceMemory reference to its underlying typed pointer in diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h index ea027f4dac22fb..d17b974fe44b7a 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel.h @@ -39,7 +39,9 @@ namespace stream_executor::gpu { class GpuKernel : public Kernel { public: - explicit GpuKernel(GpuExecutor* gpu_executor) : gpu_executor_(gpu_executor) {} + explicit GpuKernel(GpuExecutor* gpu_executor) + : gpu_executor_(gpu_executor), + gpu_context_(gpu_executor->gpu_context()) {} // Note that the function is unloaded when the module is unloaded, and the // module that the function is contained in is owned by the GpuExecutor. @@ -51,17 +53,6 @@ class GpuKernel : public Kernel { unsigned Arity() const override { return arity_; } void set_name(std::string name) { name_ = std::move(name); } - void set_gpu_context(GpuContext* gpu_context) { gpu_context_ = gpu_context; } - - // Returns the GpuFunctionHandle value for passing to the CUDA API. - GpuFunctionHandle AsGpuFunctionHandle() const { - DCHECK(gpu_function_ != nullptr); - return const_cast(gpu_function_); - } - - // Returns the slot that the GpuFunctionHandle is stored within for this - // object, for the CUDA API which wants to load into a GpuFunctionHandle*. - GpuFunctionHandle* gpu_function_ptr() { return &gpu_function_; } // Returns the current kernel cache configuration preference as a // GpuFuncCachePreference. @@ -70,6 +61,12 @@ class GpuKernel : public Kernel { absl::StatusOr GetMaxOccupiedBlocksPerCore( ThreadDim threads, size_t dynamic_shared_memory_bytes) const override; + // Simple accessor methods. + GpuFunctionHandle gpu_function() const { return gpu_function_; } + void set_gpu_function(GpuFunctionHandle gpu_function) { + gpu_function_ = gpu_function; + } + private: GpuExecutor* gpu_executor_ = nullptr; GpuContext* gpu_context_ = nullptr; // context where kernel is loaded diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc index 507fbfa477520f..fcb97ca7e790c7 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_kernel_test.cc @@ -14,66 +14,102 @@ limitations under the License. ==============================================================================*/ #include +#include +#include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/ascii.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "xla/service/platform_util.h" +#include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_test_kernels.h" +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/rocm/rocm_platform_id.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" #include "xla/stream_executor/typed_kernel_factory.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace stream_executor::gpu { - -TEST(GpuKernelTest, Add) { - using AddI32Kernel = - TypedKernelFactory, DeviceMemory, - DeviceMemory>; - auto name = absl::AsciiStrToUpper( - xla::PlatformUtil::CanonicalPlatformName("gpu").value()); - Platform* platform = PlatformManager::PlatformWithName(name).value(); - StreamExecutor* executor = platform->ExecutorForDevice(0).value(); - - TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); +namespace { + +using AddI32Kernel = + TypedKernelFactory, DeviceMemory, + DeviceMemory>; + +class GpuKernelTest : public ::testing::Test { + public: + void SetUp() override { + auto name = absl::AsciiStrToUpper( + xla::PlatformUtil::CanonicalPlatformName("gpu").value()); + Platform* platform = PlatformManager::PlatformWithName(name).value(); + executor_ = platform->ExecutorForDevice(0).value(); + } + + void RunAddI32Kernel(const MultiKernelLoaderSpec& spec) { + TF_ASSERT_OK_AND_ASSIGN(auto stream, executor_->CreateStream()); + TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor_, spec)); + + int64_t length = 4; + int64_t byte_length = sizeof(int32_t) * length; + + // Prepare arguments: a=1, b=2, c=0 + DeviceMemory a = executor_->AllocateArray(length, 0); + DeviceMemory b = executor_->AllocateArray(length, 0); + DeviceMemory c = executor_->AllocateArray(length, 0); + + TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); + TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); + TF_ASSERT_OK(stream->MemZero(&c, byte_length)); + + // Launch kernel. + ASSERT_TRUE( + stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c).ok()); + + // Copy data back to host. + std::vector dst(4, 42); + TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); + + std::vector expected = {3, 3, 3, 3}; + ASSERT_EQ(dst, expected); + } + + StreamExecutor* executor_; +}; + +TEST_F(GpuKernelTest, LoadAndRunKernelFromPtx) { + if (executor_->GetPlatform()->id() == + stream_executor::rocm::kROCmPlatformId) { + GTEST_SKIP() << "There is no PTX or any equivalent abstraction for ROCm."; + } MultiKernelLoaderSpec spec(/*arity=*/3); -#if defined(GOOGLE_CUDA) - spec.AddCudaPtxInMemory(internal::kAddI32Kernel, "add"); -#elif defined(TENSORFLOW_USE_ROCM) - spec.AddCudaCubinInMemory(internal::kAddI32KernelModule, "add"); -#endif - - TF_ASSERT_OK_AND_ASSIGN(auto add, AddI32Kernel::Create(executor, spec)); - - int64_t length = 4; - int64_t byte_length = sizeof(int32_t) * length; - - // Prepare arguments: a=1, b=2, c=0 - DeviceMemory a = executor->AllocateArray(length, 0); - DeviceMemory b = executor->AllocateArray(length, 0); - DeviceMemory c = executor->AllocateArray(length, 0); - - TF_ASSERT_OK(stream->Memset32(&a, 1, byte_length)); - TF_ASSERT_OK(stream->Memset32(&b, 2, byte_length)); - TF_ASSERT_OK(stream->MemZero(&c, byte_length)); - - // Launch kernel. - ASSERT_TRUE(stream->ThenLaunch(ThreadDim(), BlockDim(4), add, a, b, c).ok()); + spec.AddCudaPtxInMemory(internal::kAddI32KernelPtx, "AddI32"); + RunAddI32Kernel(spec); +} - // Copy data back to host. - std::vector dst(4, 42); - TF_ASSERT_OK(stream->Memcpy(dst.data(), c, byte_length)); +TEST_F(GpuKernelTest, LoadAndRunKernelFromCubin) { + MultiKernelLoaderSpec spec(/*arity=*/3); + TF_ASSERT_OK_AND_ASSIGN(auto binary, GetGpuTestKernelsFatbin()); + spec.AddCudaCubinInMemory(binary, "AddI32"); + RunAddI32Kernel(spec); +} - std::vector expected = {3, 3, 3, 3}; - ASSERT_EQ(dst, expected); +TEST_F(GpuKernelTest, LoadAndRunKernelFromSymbol) { + MultiKernelLoaderSpec spec(/*arity=*/3); + spec.AddInProcessSymbol(internal::GetAddI32Kernel(), "AddI32"); + RunAddI32Kernel(spec); } +} // namespace } // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc index e1c943a1171d95..b257ffa0b675ec 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -31,10 +32,14 @@ limitations under the License. #include "xla/stream_executor/gpu/gpu_driver.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" +#include "xla/stream_executor/gpu/gpu_kernel.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/lib/nvtx_utils.h" namespace stream_executor { @@ -48,7 +53,7 @@ void InternalHostCallback(void* data) { } } // namespace -bool GpuStream::Init() { +absl::Status GpuStream::Init() { int priority = [&]() { if (std::holds_alternative(stream_priority_)) { return std::get(stream_priority_); @@ -58,11 +63,9 @@ bool GpuStream::Init() { }(); if (!GpuDriver::CreateStream(parent_->gpu_context(), &gpu_stream_, priority)) { - return false; + return absl::InternalError("Failed to CreateStream"); } - return GpuDriver::InitEvent(parent_->gpu_context(), &completed_event_, - GpuDriver::EventFlags::kDisableTiming) - .ok(); + return absl::OkStatus(); } Stream::PlatformSpecificHandle GpuStream::platform_specific_handle() const { @@ -86,7 +89,10 @@ absl::Status GpuStream::MemZero(DeviceMemoryBase* location, uint64_t size) { size % 4 == 0) { return Memset32(location, 0x0, size); } else { - return parent_->Memset(this, location, 0x0, size); + return GpuDriver::AsynchronousMemsetUint8( + parent_->gpu_context(), + reinterpret_cast(location->opaque()), 0x0, size, + gpu_stream()); } } @@ -130,14 +136,12 @@ absl::Status GpuStream::Memcpy(void* host_dst, const DeviceMemoryBase& gpu_src, absl::Status GpuStream::WaitFor(Stream* other) { GpuStream* other_gpu = AsGpuStream(other); - GpuEventHandle other_completed_event = *(other_gpu->completed_event()); - TF_RETURN_IF_ERROR(GpuDriver::RecordEvent(parent_->gpu_context(), - other_completed_event, - AsGpuStreamValue(other_gpu))); - - if (GpuDriver::WaitStreamOnEvent(parent_->gpu_context(), - AsGpuStreamValue(this), - other_completed_event)) { + + GpuEvent* other_completed_event = other_gpu->completed_event(); + TF_RETURN_IF_ERROR(other_completed_event->Record(other_gpu->gpu_stream())); + + if (GpuDriver::WaitStreamOnEvent(parent_->gpu_context(), gpu_stream(), + other_completed_event->gpu_event())) { return absl::OkStatus(); } return absl::InternalError("Couldn't wait for stream."); @@ -173,22 +177,18 @@ absl::Status GpuStream::DoHostCallbackWithStatus( return absl::InternalError("Failed to host callback."); } -void GpuStream::Destroy() { - if (completed_event_ != nullptr) { - absl::Status status = - GpuDriver::DestroyEvent(parent_->gpu_context(), &completed_event_); - if (!status.ok()) { - LOG(ERROR) << status.message(); - } +GpuStream::~GpuStream() { + BlockHostUntilDone().IgnoreError(); + parent()->DeallocateStream(this); + + if (!GpuDriver::IsStreamIdle(parent_->gpu_context(), gpu_stream_)) { + LOG(ERROR) << "Deallocating stream with pending work"; } + completed_event_.reset(); GpuDriver::DestroyStream(parent_->gpu_context(), &gpu_stream_); } -bool GpuStream::IsIdle() const { - return GpuDriver::IsStreamIdle(parent_->gpu_context(), gpu_stream_); -} - void GpuStream::set_name(absl::string_view name) { name_ = name; tsl::profiler::NameStream( @@ -200,6 +200,83 @@ GpuStream::CreateEventBasedTimer(bool use_delay_kernel) { return parent_->CreateEventBasedTimer(this, use_delay_kernel); } +absl::Status GpuStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, const Kernel& kernel, + const KernelArgs& args) { + return Launch(thread_dims, block_dims, std::nullopt, kernel, args); +} + +absl::Status GpuStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const ClusterDim& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { + return Launch(thread_dims, block_dims, std::make_optional(cluster_dims), + kernel, args); +} + +absl::Status GpuStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args) { + const GpuKernel* gpu_kernel = AsGpuKernel(&kernel); + GpuFunctionHandle function = gpu_kernel->gpu_function(); + + if (gpu_kernel->cache_config() != KernelCacheConfig::kNoPreference) { + TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( + function, gpu_kernel->GetGpuCacheConfig())); + } + + // Launch kernels with packed arguments. + auto launch = [this, &kernel, &cluster_dims, &thread_dims, &block_dims, + &function](const KernelArgsPackedArrayBase& packed) { + int32_t expected_number_of_arguments = + kernel.Arity() + (packed.number_of_shared_bytes() > 0); + + CHECK_EQ(expected_number_of_arguments, packed.number_of_arguments()) + << "Kernel " << kernel.name() << " has " << packed.number_of_arguments() + << " arguments, but expected " << expected_number_of_arguments + << "; arity=" << kernel.Arity() + << "; number_of_shared_bytes=" << packed.number_of_shared_bytes(); + + void** params = const_cast(packed.argument_addresses().data()); + + if (cluster_dims.has_value()) { + return GpuDriver::LaunchKernel( + parent_->gpu_context(), kernel.name(), function, cluster_dims->x, + cluster_dims->y, cluster_dims->z, block_dims.x, block_dims.y, + block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, + packed.number_of_shared_bytes(), gpu_stream(), params, + /*extra=*/nullptr); + } else { + return GpuDriver::LaunchKernel( + parent_->gpu_context(), kernel.name(), function, block_dims.x, + block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, + thread_dims.z, packed.number_of_shared_bytes(), gpu_stream(), params, + /*extra=*/nullptr); + } + }; + + // If arguments are already packed we can just launch the kernel. + if (auto* packed = DynCast(&args)) { + return launch(*packed); + } + + // For device memory array we rely on a custom kernel arguments packing. + if (auto* device_mem = DynCast(&args)) { + auto& pack = kernel.args_packing(); + if (!pack) { + return absl::InternalError( + "Kernel is missing a custom arguments packing function for device " + "memory arguments array"); + } + + TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem)); + return launch(*packed); + } + + return absl::InternalError("Unsupported kernel arguments type"); +} + GpuStream* AsGpuStream(Stream* stream) { DCHECK(stream != nullptr); return static_cast(stream); diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h index 18b77fb888481b..249fbf78877a4e 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_stream.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_stream.h @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include +#include #include #include "absl/functional/any_invocable.h" @@ -29,8 +31,11 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_executor.h" #include "xla/stream_executor/gpu/gpu_types.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" @@ -46,47 +51,32 @@ class GpuExecutor; // Thread-safe post-initialization. class GpuStream : public StreamCommon { public: - explicit GpuStream(GpuExecutor* parent) + GpuStream(GpuExecutor* parent, std::unique_ptr completed_event, + std::optional> priority) : StreamCommon(parent), parent_(parent), gpu_stream_(nullptr), - completed_event_(nullptr) {} - - // Note: teardown is handled by a parent's call to DeallocateStream. - ~GpuStream() override { - BlockHostUntilDone().IgnoreError(); - parent()->DeallocateStream(this); + completed_event_(std::move(completed_event)) { + if (priority.has_value()) { + stream_priority_ = priority.value(); + } } - // Returns a pointer to a platform specific stream associated with this object - // if it exists, or nullptr otherwise. This is available via Stream public API - // as Stream::PlatformSpecificHandle, and should not be accessed directly - // outside of a StreamExecutor package. - void* platform_specific_stream() const { return gpu_stream_; } + // Note: teardown is handled by a parent's call to DeallocateStream. + ~GpuStream() override; // Explicitly initialize the CUDA resources associated with this stream. - bool Init(); - - // Sets the priority of this stream. - void SetPriority(StreamPriority priority) { stream_priority_ = priority; } - void SetPriority(int priority) { stream_priority_ = priority; } + absl::Status Init(); std::variant priority() const override { return stream_priority_; } PlatformSpecificHandle platform_specific_handle() const override; - // Explicitly destroy the CUDA resources associated with this stream, used by - // StreamExecutor::DeallocateStream(). - void Destroy(); - - // Returns true if no work is pending or executing on the stream. - bool IsIdle() const; - // Retrieves an event which indicates that all work enqueued into the stream // has completed. Ownership of the event is not transferred to the caller, the // event is owned by this stream. - GpuEventHandle* completed_event() { return &completed_event_; } + GpuEvent* completed_event() { return completed_event_.get(); } // Returns the GpuStreamHandle value for passing to the CUDA API. // @@ -115,14 +105,22 @@ class GpuStream : public StreamCommon { void set_name(absl::string_view name) override; absl::StatusOr> CreateEventBasedTimer( bool use_delay_kernel) override; + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const Kernel& k, const KernelArgs& args) override; + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const ClusterDim& cluster_dims, const Kernel& k, + const KernelArgs& args) override; private: + // Helper method to launch a kernel with optional cluster dimensions. + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const std::optional& cluster_dims, + const Kernel& kernel, const KernelArgs& args); + GpuExecutor* parent_; // Executor that spawned this stream. GpuStreamHandle gpu_stream_; // Wrapped CUDA stream handle. std::variant stream_priority_; - - // Event that indicates this stream has completed. - GpuEventHandle completed_event_ = nullptr; + std::unique_ptr completed_event_; }; // Helper functions to simplify extremely common flows. diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc index cab05701159ad9..b97771724d0ad6 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.cu.cc @@ -17,8 +17,16 @@ limitations under the License. #include +#ifdef TENSORFLOW_USE_ROCM +#include "rocm/include/hip/hip_runtime.h" +#endif + namespace stream_executor::gpu::internal { +// We want to be able to load those kernels by symbol name, so let's make them +// C functions. +extern "C" { + __global__ void AddI32(int32_t* a, int32_t* b, int32_t* c) { int index = threadIdx.x + blockIdx.x * blockDim.x; c[index] = a[index] + b[index]; @@ -39,6 +47,7 @@ __global__ void AddI32Ptrs3(Ptrs3 ptrs) { int index = threadIdx.x + blockIdx.x * blockDim.x; ptrs.c[index] = ptrs.a[index] + ptrs.b[index]; } +} void* GetAddI32Kernel() { return reinterpret_cast(&AddI32); } diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h index 74931452bb6624..dc143779389f56 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels.h @@ -23,11 +23,10 @@ namespace stream_executor::gpu::internal { // This is a collection of gpu kernels for writing simple StreamExecutor tests. // // Some of the kernels available as pre-compiled PTX blobs (can be loaded with -// CUDA driver API) / HSACO modules (can be loaded with ROCM driver api), and +// CUDA driver API), and // some of the kernels are written directly in CUDA C++ and can be loaded from a // symbol pointer (to test StreamExecutor CUDA runtime integration). -#if !defined(TENSORFLOW_USE_ROCM) // PTX kernel compiled from: // // __global__ void add(int* a, int* b, int* c) { @@ -36,24 +35,24 @@ namespace stream_executor::gpu::internal { // } // // Easiest way to get PTX from C++ is to use https://godbolt.org. -inline constexpr std::string_view kAddI32Kernel = R"( +inline constexpr std::string_view kAddI32KernelPtx = R"( .version 4.0 .target sm_50 .address_size 64 -.visible .entry add( - .param .u64 add_param_0, - .param .u64 add_param_1, - .param .u64 add_param_2 +.visible .entry AddI32( + .param .u64 AddI32_param_0, + .param .u64 AddI32_param_1, + .param .u64 AddI32_param_2 ) { .reg .b32 %r<8>; .reg .b64 %rd<11>; .loc 1 1 0 - ld.param.u64 %rd1, [add_param_0]; - ld.param.u64 %rd2, [add_param_1]; - ld.param.u64 %rd3, [add_param_2]; + ld.param.u64 %rd1, [AddI32_param_0]; + ld.param.u64 %rd2, [AddI32_param_1]; + ld.param.u64 %rd3, [AddI32_param_2]; .loc 1 3 3 cvta.to.global.u64 %rd4, %rd3; cvta.to.global.u64 %rd5, %rd2; @@ -75,9 +74,6 @@ inline constexpr std::string_view kAddI32Kernel = R"( ret; })"; -#else -#include "xla/stream_executor/rocm/add_i32_kernel.h" -#endif // !defined(TENSORFLOW_USE_ROCM) template struct Ptrs3 { diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc new file mode 100644 index 00000000000000..da638565540cb2 --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.cc @@ -0,0 +1,34 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/gpu/gpu_test_kernels_fatbin.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "tsl/platform/env.h" +#include "tsl/platform/errors.h" + +namespace stream_executor::gpu { + +absl::StatusOr> GetGpuTestKernelsFatbin() { + tsl::Env* env = tsl::Env::Default(); + std::string file_contents; + TF_RETURN_IF_ERROR(tsl::ReadFileToString(env, FATBIN_SRC, &file_contents)); + return std::vector(file_contents.begin(), file_contents.end()); +} +} // namespace stream_executor::gpu diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.h b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.h new file mode 100644 index 00000000000000..803b8b3cab4b4f --- /dev/null +++ b/third_party/xla/xla/stream_executor/gpu/gpu_test_kernels_fatbin.h @@ -0,0 +1,35 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_FATBIN_H_ +#define XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_FATBIN_H_ + +#include +#include + +#include "absl/status/statusor.h" + +namespace stream_executor::gpu { + +// Returns the NVIDIA or HIP fatbin for the :gpu_test_kernels target. +// The fatbin is being extracted at compile time from the compilation artifact. +// Note that this function will read the extracted fatbin from the file system +// at runtime and will only be able to succeed when the test is being invoked by +// `bazel test`. +absl::StatusOr> GetGpuTestKernelsFatbin(); + +} // namespace stream_executor::gpu + +#endif // XLA_STREAM_EXECUTOR_GPU_GPU_TEST_KERNELS_FATBIN_H_ diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h index be0f9a54a2af98..656dd1e9809490 100644 --- a/third_party/xla/xla/stream_executor/gpu/gpu_timer.h +++ b/third_party/xla/xla/stream_executor/gpu/gpu_timer.h @@ -21,12 +21,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/time/time.h" -#include "xla/stream_executor/event.h" #include "xla/stream_executor/event_based_timer.h" #include "xla/stream_executor/gpu/gpu_event.h" #include "xla/stream_executor/gpu/gpu_semaphore.h" -#include "xla/stream_executor/gpu/gpu_types.h" -#include "xla/stream_executor/stream.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/stream_executor/gpu/memcpy_test.cc b/third_party/xla/xla/stream_executor/gpu/memcpy_test.cc index 96b7700ce33538..1fbe79ce0ec4b2 100644 --- a/third_party/xla/xla/stream_executor/gpu/memcpy_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/memcpy_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc index 1ab7dea3030050..abf94db2519ee4 100644 --- a/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/redzone_allocator_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor_memory_allocator.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/stream_executor/gpu/stream_search_test.cc b/third_party/xla/xla/stream_executor/gpu/stream_search_test.cc index c0f66159400039..c1e053e7914942 100644 --- a/third_party/xla/xla/stream_executor/gpu/stream_search_test.cc +++ b/third_party/xla/xla/stream_executor/gpu/stream_search_test.cc @@ -16,7 +16,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream_finder.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" @@ -37,19 +39,15 @@ class StreamSearchTest : public ::testing::Test { TEST_F(StreamSearchTest, NoMatchBadPtr) { void* bad_ptr = reinterpret_cast(0xdeadbeef); - StreamExecutorConfig config; - config.gpu_stream = bad_ptr; - - absl::StatusOr found_executor = - GetPlatform()->GetExecutor(config); - - // No executor found. - EXPECT_FALSE(found_executor.ok()); + EXPECT_FALSE(FindStream(GetPlatform(), bad_ptr).ok()); } TEST_F(StreamSearchTest, FoundPrevExecutor) { - TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * executor, - GetPlatform()->ExecutorForDevice(0)); + int number_devices = GetPlatform()->VisibleDeviceCount(); + EXPECT_GT(number_devices, 0); + TF_ASSERT_OK_AND_ASSIGN( + StreamExecutor * executor, + GetPlatform()->ExecutorForDevice(number_devices > 1 ? 1 : 0)); TF_ASSERT_OK_AND_ASSIGN(auto s, executor->CreateStream()); TF_ASSERT_OK_AND_ASSIGN(auto s2, executor->CreateStream()); @@ -57,17 +55,10 @@ TEST_F(StreamSearchTest, FoundPrevExecutor) { void* gpu_ptr = s->platform_specific_handle().stream; void* gpu_ptr_2 = s2->platform_specific_handle().stream; - StreamExecutorConfig c; - c.gpu_stream = gpu_ptr; - - TF_ASSERT_OK_AND_ASSIGN(StreamExecutor * found_executor, - GetPlatform()->GetExecutor(c)); - EXPECT_EQ(found_executor, executor); - - Stream* found1 = found_executor->FindAllocatedStream(gpu_ptr); + TF_ASSERT_OK_AND_ASSIGN(Stream * found1, FindStream(GetPlatform(), gpu_ptr)); EXPECT_EQ(found1, s.get()); - - Stream* found2 = found_executor->FindAllocatedStream(gpu_ptr_2); + TF_ASSERT_OK_AND_ASSIGN(Stream * found2, + FindStream(GetPlatform(), gpu_ptr_2)); EXPECT_EQ(found2, s2.get()); } diff --git a/third_party/xla/xla/stream_executor/host/BUILD b/third_party/xla/xla/stream_executor/host/BUILD index b1b2f89f71072a..a03a21ceb5592b 100644 --- a/third_party/xla/xla/stream_executor/host/BUILD +++ b/third_party/xla/xla/stream_executor/host/BUILD @@ -81,10 +81,14 @@ cc_library( ], deps = [ ":host_event", + ":host_kernel", "//xla/stream_executor:device_memory", "//xla/stream_executor:event", + "//xla/stream_executor:kernel", + "//xla/stream_executor:launch_dim", "//xla/stream_executor:stream", "//xla/stream_executor:stream_common", + "//xla/stream_executor:stream_executor_h", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log:check", @@ -138,14 +142,12 @@ xla_cc_test( ":ptr_host_kernel_function", "//xla/stream_executor", "//xla/stream_executor:device_memory", - "//xla/stream_executor:kernel_factory", "//xla/stream_executor:kernel_spec", "//xla/tsl/concurrency:async_value", - "@com_google_absl//absl/functional:any_invocable", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:statusor", @@ -167,20 +169,21 @@ cc_library( ":host_event", ":host_kernel", ":host_stream", - "//xla/stream_executor", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_memory", "//xla/stream_executor:event", "//xla/stream_executor:host_memory_allocation", + "//xla/stream_executor:kernel", "//xla/stream_executor:kernel_spec", "//xla/stream_executor:memory_allocation", + "//xla/stream_executor:platform", + "//xla/stream_executor:stream", "//xla/stream_executor:stream_executor_common", - "//xla/stream_executor:stream_executor_h", - "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:statusor", @@ -197,9 +200,9 @@ xla_cc_test( "//xla/stream_executor", "//xla/stream_executor:platform", "//xla/stream_executor:platform_manager", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/stream_executor/host/host_executor.cc b/third_party/xla/xla/stream_executor/host/host_executor.cc index ac1d22583d0fde..7e2ac758903eed 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.cc +++ b/third_party/xla/xla/stream_executor/host/host_executor.cc @@ -22,27 +22,27 @@ limitations under the License. #include #include +#include #include #include +#include #include -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" -#include "absl/synchronization/notification.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_event.h" #include "xla/stream_executor/host/host_kernel.h" #include "xla/stream_executor/host/host_stream.h" +#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" -#include "xla/stream_executor/stream_executor.h" +#include "xla/stream_executor/stream.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/mem.h" @@ -74,49 +74,23 @@ absl::Status HostExecutor::Init() { return absl::OkStatus(); } -absl::StatusOr> HostExecutor::CreateKernel() { - return std::make_unique(thread_pool_); -} - -absl::Status HostExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { - HostKernel* host_kernel = AsHostKernel(kernel); +absl::StatusOr> HostExecutor::LoadKernel( + const MultiKernelLoaderSpec& spec) { + auto host_kernel = std::make_unique(thread_pool_); host_kernel->SetArity(spec.arity()); - VLOG(3) << "GetKernel on kernel " << kernel << " : " << kernel->name(); - for (auto& loader : KernelFunctionLoaderRegistry()) { auto loaded = loader(spec); if (!loaded.has_value()) continue; TF_ASSIGN_OR_RETURN(auto kernel_function, *std::move(loaded)); host_kernel->SetKernelFunction(std::move(kernel_function)); - return absl::OkStatus(); + return std::move(host_kernel); } return absl::InternalError("No method of loading host kernel provided"); } -absl::Status HostExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, - const KernelArgs& args) { - const HostKernel* host_kernel = AsHostKernel(&kernel); - - const KernelArgsDeviceMemoryArray* device_mem = - DynCast(&args); - - absl::Status result; - if (device_mem != nullptr) { - result = host_kernel->Launch(thread_dims, device_mem->device_memory_args()); - } else { - result = absl::UnimplementedError( - "Host kernel implements Launch method only for DeviceMemoryArray " - "arguments."); - } - return result; -} - bool HostExecutor::DeviceMemoryUsage(int64_t* free, int64_t* total) const { tsl::port::MemoryInfo mem_info = tsl::port::GetMemoryInfo(); *free = (mem_info.free != INT64_MAX) ? mem_info.free : -1; @@ -143,16 +117,6 @@ absl::Status HostExecutor::SynchronousMemZero(DeviceMemoryBase* location, return absl::OkStatus(); } -absl::Status HostExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8 pattern, uint64_t size) { - void* gpu_mem = location->opaque(); - // Enqueue the [asynchronous] memzero on the stream (HostStream) associated - // with the HostExecutor. - AsHostStream(stream)->EnqueueTask( - [gpu_mem, size, pattern]() { memset(gpu_mem, pattern, size); }); - return absl::OkStatus(); -} - absl::Status HostExecutor::SynchronousMemcpy(DeviceMemoryBase* gpu_dst, const void* host_src, uint64_t size) { diff --git a/third_party/xla/xla/stream_executor/host/host_executor.h b/third_party/xla/xla/stream_executor/host/host_executor.h index 18ec5a739faca5..55eacc5fff4851 100644 --- a/third_party/xla/xla/stream_executor/host/host_executor.h +++ b/third_party/xla/xla/stream_executor/host/host_executor.h @@ -13,20 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Declares the HostExecutor class, which is a CPU-only implementation of -// the StreamExecutor interface. For now, this is used for testing and to -// examine the performance of host-based StreamExecutor code. #ifndef XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ #define XLA_STREAM_EXECUTOR_HOST_HOST_EXECUTOR_H_ -#include #include #include #include #include #include -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/stream_executor/device_description.h" @@ -36,24 +31,21 @@ limitations under the License. #include "xla/stream_executor/host_memory_allocation.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor_common.h" #include "tsl/platform/threadpool.h" namespace stream_executor { namespace host { -// An implementation of StreamExecutor that does no communication or interaction -// with a device, but DOES perform memory operations backed by the host. -// Kernel invocations will fail, but host callbacks may be enqueued on this -// executor and its associated stream, and should follow standard ordering -// semantics. +// Declares the HostExecutor class, which is a CPU-only implementation of +// the StreamExecutor interface. For now, this is used for testing and to +// examine the performance of host-based StreamExecutor code. // // This is useful for evaluating the performance of host-based or fallback // routines executed under the context of a GPU executor. -// See stream_executor.h for description of the below operations. class HostExecutor : public StreamExecutorCommon { public: // A function that loads a kernel function from a given spec. If spec is not @@ -70,14 +62,8 @@ class HostExecutor : public StreamExecutorCommon { absl::Status Init() override; - absl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) override; - - absl::StatusOr> CreateKernel() override; - - absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& kernel, - const KernelArgs& args) override; + absl::StatusOr> LoadKernel( + const MultiKernelLoaderSpec& spec) override; DeviceMemoryBase Allocate(uint64_t size, int64_t memory_space) override; void Deallocate(DeviceMemoryBase* mem) override; @@ -90,10 +76,6 @@ class HostExecutor : public StreamExecutorCommon { delete[] static_cast(mem); } - absl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) override; - - // No "synchronize all activity" implemented for this platform at the moment. bool SynchronizeAllActivity() override { return true; } absl::Status SynchronousMemZero(DeviceMemoryBase* location, uint64_t size) override; diff --git a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc index 98157266a74eef..4e766fc92158d5 100644 --- a/third_party/xla/xla/stream_executor/host/host_kernel_test.cc +++ b/third_party/xla/xla/stream_executor/host/host_kernel_test.cc @@ -28,14 +28,14 @@ limitations under the License. #include "absl/types/span.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/host/host_kernel_c_api.h" -#include "xla/stream_executor/kernel_factory.h" +#include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tsl/concurrency/async_value_ref.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/cpu_info.h" #include "tsl/platform/env.h" #include "tsl/platform/statusor.h" @@ -90,10 +90,10 @@ define ptr @LlvmAddI32(ptr noundef %0) { } )"; -static absl::StatusOr> NewStreamExecutor() { - StreamExecutorConfig config(/*ordinal=*/0); +static absl::StatusOr NewStreamExecutor() { TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host")); - TF_ASSIGN_OR_RETURN(auto stream_exec, platform->GetUncachedExecutor(config)); + TF_ASSIGN_OR_RETURN(auto stream_exec, + platform->ExecutorForDevice(/*ordinal=*/0)); return stream_exec; } @@ -157,8 +157,7 @@ TEST(HostKernelTest, Addition3D) { TF_ASSERT_OK_AND_ASSIGN(auto executor, NewStreamExecutor()); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN(auto add, - KernelFactory::Create(executor.get(), spec)); + TF_ASSERT_OK_AND_ASSIGN(auto add, executor->LoadKernel(spec)); const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; TF_ASSERT_OK(stream->Launch(ThreadDim(2, 2, 3), BlockDim(1), *add, kargs)); @@ -184,8 +183,7 @@ TEST(HostKernelTest, JitAddition) { TF_ASSERT_OK_AND_ASSIGN(auto executor, NewStreamExecutor()); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); - TF_ASSERT_OK_AND_ASSIGN(auto add, - KernelFactory::Create(executor.get(), spec)); + TF_ASSERT_OK_AND_ASSIGN(auto add, executor->LoadKernel(spec)); const KernelArgsDeviceMemoryArray kargs{args, /*shared_memory_bytes=*/0}; TF_ASSERT_OK(stream->Launch(ThreadDim(4), BlockDim(1), *add, kargs)); diff --git a/third_party/xla/xla/stream_executor/host/host_platform.cc b/third_party/xla/xla/stream_executor/host/host_platform.cc index c9a12709d70f22..b70ea46fa25825 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.cc +++ b/third_party/xla/xla/stream_executor/host/host_platform.cc @@ -52,25 +52,18 @@ HostPlatform::DescriptionForDevice(int ordinal) const { } absl::StatusOr HostPlatform::ExecutorForDevice(int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); -} - -absl::StatusOr HostPlatform::GetExecutor( - const StreamExecutorConfig& config) { return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> -HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = std::make_unique(this, config.ordinal); +HostPlatform::GetUncachedExecutor(int ordinal) { + auto executor = std::make_unique(this, ordinal); auto init_status = executor->Init(); if (!init_status.ok()) { return absl::InternalError(absl::StrFormat( - "failed initializing StreamExecutor for device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())); + "failed initializing StreamExecutor for device ordinal %d: %s", ordinal, + init_status.ToString().c_str())); } return std::move(executor); diff --git a/third_party/xla/xla/stream_executor/host/host_platform.h b/third_party/xla/xla/stream_executor/host/host_platform.h index 25c1179dcd7565..b8ce8f4340d6c4 100644 --- a/third_party/xla/xla/stream_executor/host/host_platform.h +++ b/third_party/xla/xla/stream_executor/host/host_platform.h @@ -51,13 +51,13 @@ class HostPlatform : public Platform { absl::StatusOr ExecutorForDevice(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; - + private: + // Returns a device constructed with ordinal without + // looking in or storing to the Platform's executor cache. + // Ownership IS transferred to the caller. absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config) override; + int ordinal); - private: // This platform's name. std::string name_; diff --git a/third_party/xla/xla/stream_executor/host/host_stream.cc b/third_party/xla/xla/stream_executor/host/host_stream.cc index ed6e040431e478..76b66711e03d62 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream.cc @@ -33,6 +33,9 @@ limitations under the License. #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/event.h" #include "xla/stream_executor/host/host_event.h" +#include "xla/stream_executor/host/host_kernel.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" #include "tsl/platform/denormal.h" @@ -192,6 +195,21 @@ absl::Status HostStream::BlockUntilDone() { return status; } -} // namespace host +absl::Status HostStream::Launch(const ThreadDim& thread_dims, + const BlockDim& block_dims, + const Kernel& kernel, const KernelArgs& args) { + const HostKernel* host_kernel = AsHostKernel(&kernel); + + const KernelArgsDeviceMemoryArray* device_mem = + DynCast(&args); + + if (device_mem != nullptr) { + return host_kernel->Launch(thread_dims, device_mem->device_memory_args()); + } + return absl::UnimplementedError( + "Host kernel implements Launch method only for DeviceMemoryArray " + "arguments."); +} +} // namespace host } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/host/host_stream.h b/third_party/xla/xla/stream_executor/host/host_stream.h index ed1bbc2011f48f..a43ba610e25417 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream.h +++ b/third_party/xla/xla/stream_executor/host/host_stream.h @@ -13,12 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Class declaration for Stream type that enqueues tasks onto a host/CPU-based -// execution context (as opposed to a GPU device), HostExecutor. #ifndef XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_ #define XLA_STREAM_EXECUTOR_HOST_HOST_STREAM_H_ -#include +#include #include #include @@ -27,13 +25,20 @@ limitations under the License. #include "absl/status/status.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_common.h" +#include "xla/stream_executor/stream_executor.h" #include "tsl/platform/env.h" #include "tsl/platform/thread_annotations.h" namespace stream_executor { namespace host { +// Class declaration for Stream type that enqueues tasks onto a host/CPU-based +// execution context (as opposed to a GPU device), HostExecutor. class HostStream : public StreamCommon { public: explicit HostStream(StreamExecutor* executor); @@ -65,6 +70,8 @@ class HostStream : public StreamCommon { uint64_t size) override; absl::Status DoHostCallbackWithStatus( absl::AnyInvocable callback) override; + absl::Status Launch(const ThreadDim& thread_dims, const BlockDim& block_dims, + const Kernel& kernel, const KernelArgs& args) override; private: bool WorkAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); diff --git a/third_party/xla/xla/stream_executor/host/host_stream_test.cc b/third_party/xla/xla/stream_executor/host/host_stream_test.cc index 522d38781256fd..1f60709ceb4b2f 100644 --- a/third_party/xla/xla/stream_executor/host/host_stream_test.cc +++ b/third_party/xla/xla/stream_executor/host/host_stream_test.cc @@ -19,7 +19,7 @@ limitations under the License. #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/stream_executor/kernel_test.cc b/third_party/xla/xla/stream_executor/kernel_test.cc index cf63e5b0a55281..a554785735d3cd 100644 --- a/third_party/xla/xla/stream_executor/kernel_test.cc +++ b/third_party/xla/xla/stream_executor/kernel_test.cc @@ -66,15 +66,12 @@ static_assert( std::is_same_v, std::tuple>); -static std::unique_ptr NewStreamExecutor() { +static StreamExecutor* NewStreamExecutor() { Platform* platform = PlatformManager::PlatformWithName("Host").value(); - StreamExecutorConfig config(/*ordinal=*/0); - return platform->GetUncachedExecutor(config).value(); + return platform->ExecutorForDevice(/*ordinal=*/0).value(); } TEST(KernelTest, PackDeviceMemoryArguments) { - auto executor = NewStreamExecutor(); - DeviceMemoryBase a(reinterpret_cast(0x12345678)); DeviceMemoryBase b(reinterpret_cast(0x87654321)); @@ -125,7 +122,7 @@ TEST(KernelTest, FailToCreateTypedKernelFromEmptySpec) { MultiKernelLoaderSpec empty_spec(/*arity=*/0); auto executor = NewStreamExecutor(); - auto kernel = TypedKernelFactory<>::Create(executor.get(), empty_spec); + auto kernel = TypedKernelFactory<>::Create(executor, empty_spec); EXPECT_FALSE(kernel.ok()); } diff --git a/third_party/xla/xla/stream_executor/lazy_op_runner.h b/third_party/xla/xla/stream_executor/lazy_op_runner.h index c74a03e1ad5226..bf964e05bbaae6 100644 --- a/third_party/xla/xla/stream_executor/lazy_op_runner.h +++ b/third_party/xla/xla/stream_executor/lazy_op_runner.h @@ -280,76 +280,6 @@ struct FusedMatmulOp { } }; -struct FusedMHAOp { - using Signature = FusedMHASignature; - struct Config { - double scale; - const MatmulTensorDescriptor& bmm1_lhs_descriptor; - const MatmulTensorDescriptor& bmm1_rhs_descriptor; - const MatmulTensorDescriptor& bmm2_rhs_descriptor; - const MatmulTensorDescriptor& intermediate_bmm2_lhs_descriptor; - const TensorDescriptor& output_descriptor; - std::optional bias_descriptor; - std::optional activation_descriptor; - std::optional dropout_rate; - std::optional seed; - FMHAMaskKind mask_type; - }; - - static absl::StatusOr>> - RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, - Stream* stream) { - TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); - return dnn->FusedMHARunnerFromDesc( - stream, desc, config.bmm1_lhs_descriptor, config.bmm1_rhs_descriptor, - config.bmm2_rhs_descriptor, config.intermediate_bmm2_lhs_descriptor, - config.output_descriptor, config.activation_descriptor, - config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.mask_type); - } -}; - -struct FusedMHABackwardOp { - using Signature = FusedMHABackwardSignature; - - struct Config { - double scale; - const MatmulTensorDescriptor& bmm1_grad_gemm1_rhs_descriptor; - const MatmulTensorDescriptor& bmm1_grad_gemm2_rhs_descriptor; - const MatmulTensorDescriptor& bmm2_grad_gemm1_lhs_descriptor; - const MatmulTensorDescriptor& bmm2_grad_gemm2_rhs_descriptor; - const MatmulTensorDescriptor& d_output_descriptor; - const TensorDescriptor& d_bmm1_lhs_descriptor; - const TensorDescriptor& d_bmm1_rhs_descriptor; - const TensorDescriptor& d_bmm2_rhs_descriptor; - std::optional d_s_descriptor; - std::optional d_bias_descriptor; - std::optional fwd_output_descriptor; - std::optional bias_descriptor; - std::optional dropout_rate; - std::optional seed; - FMHAMaskKind mask_type; - bool force_deterministic; - }; - - static absl::StatusOr< - std::unique_ptr>> - RunnerFromAlgorithmDesc(const AlgorithmDesc& desc, Config config, - Stream* stream) { - TF_ASSIGN_OR_RETURN(auto dnn, internal::GetDnnFromStream(stream)); - return dnn->FusedMHABackwardRunnerFromDesc( - stream, desc, config.bmm1_grad_gemm1_rhs_descriptor, - config.bmm1_grad_gemm2_rhs_descriptor, - config.bmm2_grad_gemm1_lhs_descriptor, - config.bmm2_grad_gemm2_rhs_descriptor, config.d_output_descriptor, - config.d_bmm1_lhs_descriptor, config.d_bmm1_rhs_descriptor, - config.d_bmm2_rhs_descriptor, config.d_s_descriptor, - config.d_bias_descriptor, config.fwd_output_descriptor, - config.bias_descriptor, config.scale, config.dropout_rate, config.seed, - config.mask_type, config.force_deterministic); - } -}; - } // namespace dnn } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/mock_platform.h b/third_party/xla/xla/stream_executor/mock_platform.h new file mode 100644 index 00000000000000..7c8e11dcabe7dc --- /dev/null +++ b/third_party/xla/xla/stream_executor/mock_platform.h @@ -0,0 +1,50 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_MOCK_PLATFORM_H_ +#define XLA_STREAM_EXECUTOR_MOCK_PLATFORM_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream_executor.h" +#include "xla/test.h" + +namespace stream_executor { + +// Implements the Platform interface for testing. +class MockPlatform : public Platform { + public: + MockPlatform() = default; + MOCK_METHOD(Id, id, (), (const, override)); + MOCK_METHOD(const std::string&, Name, (), (const, override)); + MOCK_METHOD(int, VisibleDeviceCount, (), (const, override)); + MOCK_METHOD(bool, Initialized, (), (const, override)); + MOCK_METHOD(absl::Status, Initialize, (), (override)); + MOCK_METHOD(absl::StatusOr>, + DescriptionForDevice, (int ordinal), (const, override)); + MOCK_METHOD(absl::StatusOr, ExecutorForDevice, (int ordinal), + (override)); + MOCK_METHOD(absl::StatusOr, FindExisting, (int ordinal), + (override)); +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_MOCK_PLATFORM_H_ diff --git a/third_party/xla/xla/stream_executor/mock_stream.h b/third_party/xla/xla/stream_executor/mock_stream.h new file mode 100644 index 00000000000000..5e9750e124caaa --- /dev/null +++ b/third_party/xla/xla/stream_executor/mock_stream.h @@ -0,0 +1,94 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_MOCK_STREAM_H_ +#define XLA_STREAM_EXECUTOR_MOCK_STREAM_H_ + +#include +#include +#include + +#include "absl/functional/any_invocable.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_memory.h" +#include "xla/stream_executor/event.h" +#include "xla/stream_executor/event_based_timer.h" +#include "xla/stream_executor/kernel.h" +#include "xla/stream_executor/launch_dim.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" +#include "xla/test.h" + +namespace stream_executor { + +// Implements the Stream interface for testing. +class MockStream : public Stream { + public: + MockStream() = default; + MOCK_METHOD(PlatformSpecificHandle, platform_specific_handle, (), + (const, override)); + MOCK_METHOD(bool, ok, (), (const, override)); + MOCK_METHOD(absl::Status, RefreshStatus, (), (override)); + MOCK_METHOD(absl::StatusOr, GetOrCreateSubStream, (), (override)); + MOCK_METHOD(void, ReturnSubStream, (Stream * sub_stream), (override)); + MOCK_METHOD(absl::Status, WaitFor, (Stream * other), (override)); + MOCK_METHOD(absl::Status, WaitFor, (Event * event), (override)); + MOCK_METHOD(absl::Status, RecordEvent, (Event * event), (override)); + MOCK_METHOD(absl::Status, Memcpy, + (void *host_dst, const DeviceMemoryBase &gpu_src, uint64_t size), + (override)); + MOCK_METHOD(absl::Status, Memcpy, + (DeviceMemoryBase * gpu_dst, const void *host_src, uint64_t size), + (override)); + MOCK_METHOD(absl::Status, Memcpy, + (DeviceMemoryBase * gpu_dst, const DeviceMemoryBase &gpu_src, + uint64_t size), + (override)); + MOCK_METHOD(absl::Status, MemZero, + (DeviceMemoryBase * location, uint64_t size), (override)); + MOCK_METHOD(absl::Status, Memset32, + (DeviceMemoryBase * location, uint32_t pattern, uint64_t size), + (override)); + MOCK_METHOD(absl::Status, BlockHostUntilDone, (), (override)); + MOCK_METHOD(absl::Status, DoHostCallbackWithStatus, + (absl::AnyInvocable callback), (override)); + MOCK_METHOD(StreamExecutor *, parent, (), (const, override)); + MOCK_METHOD(CudaComputeCapability, GetCudaComputeCapability, (), + (const, override)); + MOCK_METHOD(RocmComputeCapability, GetRocmComputeCapability, (), + (const, override)); + MOCK_METHOD((std::variant), priority, (), + (const, override)); + MOCK_METHOD(absl::Status, Launch, + (const ThreadDim &thread_dims, const BlockDim &block_dims, + const Kernel &k, const KernelArgs &args), + (override)); + MOCK_METHOD(absl::Status, Launch, + (const ThreadDim &thread_dims, const BlockDim &block_dims, + const ClusterDim &cluster_dims, const Kernel &k, + const KernelArgs &args), + (override)); + MOCK_METHOD(absl::string_view, name, (), (const, override)); + MOCK_METHOD(void, set_name, (absl::string_view name), (override)); + MOCK_METHOD(absl::StatusOr>, + CreateEventBasedTimer, (bool use_delay_kernel), (override)); +}; + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_MOCK_STREAM_H_ diff --git a/third_party/xla/xla/stream_executor/mock_stream_executor.h b/third_party/xla/xla/stream_executor/mock_stream_executor.h index 3787be1133b5d4..0379c2c068dc18 100644 --- a/third_party/xla/xla/stream_executor/mock_stream_executor.h +++ b/third_party/xla/xla/stream_executor/mock_stream_executor.h @@ -1,3 +1,6 @@ +#include "xla/stream_executor/blas.h" +#include "xla/stream_executor/dnn.h" +#include "xla/stream_executor/fft.h" /* Copyright 2024 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -22,7 +25,6 @@ limitations under the License. #include #include -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -33,7 +35,6 @@ limitations under the License. #include "xla/stream_executor/event.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" @@ -43,24 +44,14 @@ limitations under the License. namespace stream_executor { -namespace fft { -class FftSupport; -} -namespace dnn { -class DnnSupport; -} -namespace blas { -class BlasSupport; -} - // Implements StreamExecutor for testing. class MockStreamExecutor : public StreamExecutor { public: MockStreamExecutor() = default; MOCK_METHOD(absl::Status, Init, (), (override)); MOCK_METHOD(int, device_ordinal, (), (const, override)); - MOCK_METHOD(absl::Status, GetKernel, - (const MultiKernelLoaderSpec& spec, Kernel* kernel), (override)); + MOCK_METHOD(absl::StatusOr>, LoadKernel, + (const MultiKernelLoaderSpec& spec), (override)); MOCK_METHOD(bool, UnloadModule, (ModuleHandle module_handle), (override)); MOCK_METHOD(absl::Status, LoadModule, (const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle), @@ -68,19 +59,6 @@ class MockStreamExecutor : public StreamExecutor { MOCK_METHOD(absl::StatusOr>, CreateOrShareConstant, (Stream * stream, absl::Span content), (override)); - MOCK_METHOD(absl::Status, Launch, - (Stream * stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& k, - const KernelArgs& args), - (override)); - MOCK_METHOD(absl::Status, Launch, - (Stream * stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const ClusterDim& cluster_dims, - const Kernel& k, const KernelArgs& args), - (override)); - MOCK_METHOD(absl::Status, Submit, - (Stream * stream, const CommandBuffer& command_buffer)); - MOCK_METHOD(void, UnloadKernel, (const Kernel* kernel), (override)); MOCK_METHOD(DeviceMemoryBase, Allocate, (uint64_t size, int64_t memory_space), (override)); MOCK_METHOD(void, Deallocate, (DeviceMemoryBase * mem), (override)); @@ -104,10 +82,6 @@ class MockStreamExecutor : public StreamExecutor { (void* host_dst, const DeviceMemoryBase& device_src, uint64_t size), (override)); - MOCK_METHOD(absl::Status, Memset, - (Stream * stream, DeviceMemoryBase* location, uint8_t pattern, - uint64_t size), - (override)); MOCK_METHOD(void, DeallocateStream, (Stream * stream), (override)); MOCK_METHOD(absl::Status, BlockHostUntilDone, (Stream * stream), (override)); MOCK_METHOD(absl::Status, EnablePeerAccessTo, (StreamExecutor * other), @@ -124,8 +98,6 @@ class MockStreamExecutor : public StreamExecutor { MOCK_METHOD(blas::BlasSupport*, AsBlas, (), (override)); MOCK_METHOD(fft::FftSupport*, AsFft, (), (override)); MOCK_METHOD(dnn::DnnSupport*, AsDnn, (), (override)); - MOCK_METHOD(absl::StatusOr>, CreateKernel, (), - (override)); MOCK_METHOD(absl::StatusOr>, CreateCommandBuffer, (CommandBuffer::Mode mode), (override)); MOCK_METHOD(std::optional, GetAllocatorStats, (), (override)); diff --git a/third_party/xla/xla/stream_executor/platform.cc b/third_party/xla/xla/stream_executor/platform.cc index 47bf5600d20297..9e8d4a8065c8b9 100644 --- a/third_party/xla/xla/stream_executor/platform.cc +++ b/third_party/xla/xla/stream_executor/platform.cc @@ -32,11 +32,6 @@ std::string StreamPriorityToString(StreamPriority priority) { } } -StreamExecutorConfig::StreamExecutorConfig() : ordinal(-1) {} - -StreamExecutorConfig::StreamExecutorConfig(int ordinal_in) - : ordinal(ordinal_in) {} - bool Platform::Initialized() const { return true; } absl::Status Platform::Initialize() { return absl::OkStatus(); } diff --git a/third_party/xla/xla/stream_executor/platform.h b/third_party/xla/xla/stream_executor/platform.h index 5fbc44b274823c..759a4c0acde70f 100644 --- a/third_party/xla/xla/stream_executor/platform.h +++ b/third_party/xla/xla/stream_executor/platform.h @@ -29,7 +29,6 @@ limitations under the License. namespace stream_executor { class StreamExecutor; -class DeviceDescription; // An enum to represent different levels of stream priorities. // This is to avoid platform-specific representations in abstractions. @@ -38,23 +37,6 @@ enum class StreamPriority { Default = 0, Lowest, Highest }; // Returns a printable description of StreamPriority. std::string StreamPriorityToString(StreamPriority priority); -// StreamExecutorConfig encapsulates the set of options for constructing a -// StreamExecutor for a given platform. -struct StreamExecutorConfig { - // Sets members to defaults: -1 for ordinal (must be changed). - StreamExecutorConfig(); - - // Simple ordinal-setting constructor. - explicit StreamExecutorConfig(int ordinal); - - // The GPU stream for which we are searching the executor. - // If this field is specified for the search, others will be ignored. - void* gpu_stream = nullptr; - - // The ordinal of the device to be managed by the returned StreamExecutor. - int ordinal; -}; - // Abstract base class for a platform registered with the PlatformManager. class Platform { public: @@ -105,25 +87,20 @@ class Platform { virtual absl::StatusOr> DescriptionForDevice(int ordinal) const = 0; - // Returns a device with the given ordinal on this platform with a default - // plugin configuration or, if none can be found with the given ordinal or - // there is an error in opening a context to communicate with the device, an - // error status is returned. + // Returns a StreamExecutor for the given ordinal if one has already been + // created, or an error is returned if none exists. Does not create a new + // context with the device. + virtual absl::StatusOr FindExisting(int ordinal) { + return absl::NotFoundError("Not implemented for this platform."); + } + + // Returns a device with the given ordinal on this platform or, if none can + // be found with the given ordinal or there is an error in opening a context + // to communicate with the device, an error status is returned. // // Ownership of the executor is NOT transferred to the caller -- // the Platform owns the executors in a singleton-like fashion. virtual absl::StatusOr ExecutorForDevice(int ordinal) = 0; - - // Returns a device constructed with the options specified in "config". - // Ownership of the executor is NOT transferred to the caller. - virtual absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) = 0; - - // Returns a device constructed with the options specified in "config" without - // looking in or storing to the Platform's executor cache. - // Ownership IS transferred to the caller. - virtual absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config) = 0; }; } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/rocm/BUILD b/third_party/xla/xla/stream_executor/rocm/BUILD index 06a44b044f2db0..1fbff0912d4b06 100644 --- a/third_party/xla/xla/stream_executor/rocm/BUILD +++ b/third_party/xla/xla/stream_executor/rocm/BUILD @@ -8,11 +8,6 @@ load( "//xla/stream_executor:build_defs.bzl", "stream_executor_friends", ) - -# copybara:comment_begin(oss-only) -load("//xla/stream_executor/rocm:build_defs.bzl", "rocm_embedded_test_modules") - -# copybara:comment_end load( "@local_config_rocm//rocm:build_defs.bzl", "if_rocm_hipblaslt", @@ -184,6 +179,7 @@ cc_library( "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:rocm_headers", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:fingerprint", @@ -704,10 +700,3 @@ cc_library( [":all_runtime"], ), ) - -# copybara:comment_begin(oss-only) -rocm_embedded_test_modules( - name = "add_i32_kernel", - srcs = if_rocm_is_configured(["add_i32_kernel.cu.cc"]), -) -# copybara:comment_end diff --git a/third_party/xla/xla/stream_executor/rocm/build_defs.bzl b/third_party/xla/xla/stream_executor/rocm/build_defs.bzl deleted file mode 100644 index 0be87739c8469f..00000000000000 --- a/third_party/xla/xla/stream_executor/rocm/build_defs.bzl +++ /dev/null @@ -1,68 +0,0 @@ -""" ROCM-specific build macros. -""" - -load("@local_config_rocm//rocm:build_defs.bzl", "rocm_gpu_architectures") - -def rocm_embedded_test_modules(name, srcs, testonly = True, **kwargs): - """Compile srcs into hsaco files and create a header only cc_library. - - Binary files are embedded as constant data. - - Args: - name: name for the generated cc_library target, and the base name for - generated header file - srcs: source files for input modules - testonly: If True, the target can only be used with tests. - **kwargs: keyword arguments passed onto the generated cc_library() rule. - """ - - # Lets piggyback this on top crosstool wrapper for now - hipcc_tool = "@local_config_rocm//crosstool:crosstool_wrapper_driver_is_not_gcc" - target_opts = " ".join(["--amdgpu-target=" + - arch for arch in rocm_gpu_architectures()]) - - header_file = "%s.h" % name - - native.genrule( - name = name + "_header_file", - srcs = srcs, - outs = [header_file], - cmd = """ - tmp_name_for_xxd() { - local filename=$$(basename $$1) - local name="k" - for word in $$(echo $${filename%%%%.*} | tr '_' ' '); do - name="$$name$${word^}" - done - echo "$${name}Module" - } - - echo '#pragma once' > $@ - echo '#include ' >> $@ - for src in $(SRCS); do - tmp=$$(tmp_name_for_xxd $$src); - $(location %s) -x rocm %s --genco -c $$src -o $$tmp && xxd -i $$tmp | sed \ - -e 's/unsigned char/inline constexpr uint8_t/g' \ - -e '$$d' >> $@; - rm -f $$tmp - done - """ % (hipcc_tool, target_opts), - tools = [hipcc_tool], - testonly = testonly, - target_compatible_with = select({ - "@local_config_rocm//rocm:using_hipcc": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - ) - - native.cc_library( - name = name, - srcs = [], - hdrs = [header_file], - testonly = testonly, - target_compatible_with = select({ - "@local_config_rocm//rocm:using_hipcc": [], - "//conditions:default": ["@platforms//:incompatible"], - }), - **kwargs - ) diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc index 465dbbe84b2a00..2f61eae925d846 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_driver.cc @@ -1148,6 +1148,21 @@ struct BitPatternToValue { return absl::OkStatus(); } +absl::Status GpuDriver::LaunchKernel( + GpuContext* context, absl::string_view kernel_name, + GpuFunctionHandle function, unsigned int cluster_dim_x, + unsigned int cluster_dim_y, unsigned int cluster_dim_z, + unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z, + unsigned int block_dim_x, unsigned int block_dim_y, + unsigned int block_dim_z, unsigned int shared_mem_bytes, + GpuStreamHandle stream, void** kernel_params, void** extra) { + if (cluster_dim_x != 1 || cluster_dim_y != 1 || cluster_dim_z != 1) + return absl::UnimplementedError("Not implemented for ROCm"); + return LaunchKernel(context, kernel_name, function, grid_dim_x, grid_dim_y, + grid_dim_z, block_dim_x, block_dim_y, block_dim_z, + shared_mem_bytes, stream, kernel_params, extra); +} + /* static */ absl::Status GpuDriver::LoadPtx(GpuContext* context, const char* ptx_contents, hipModule_t* module) { diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc index c04406a3707499..daa8cd7d4568fb 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_executor.cc @@ -255,9 +255,9 @@ absl::StatusOr GpuExecutor::DelayKernelIsSupported(GpuStream* stream) { return false; } -absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { - GpuKernel* rocm_kernel = AsGpuKernel(kernel); +absl::StatusOr> GpuExecutor::LoadKernel( + const MultiKernelLoaderSpec& spec) { + auto rocm_kernel = std::make_unique(this); hipModule_t module = nullptr; const std::string* kernel_name; @@ -272,7 +272,7 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, if (module == nullptr) { TF_RETURN_IF_ERROR(GpuDriver::LoadHsaco(context_, hsaco, &module)); } - kernel_to_gpu_binary_[kernel] = hsaco; + kernel_to_gpu_binary_[rocm_kernel.get()] = hsaco; } else if (spec.has_in_process_symbol()) { kernel_name = &spec.in_process_symbol().kernel_name(); void* symbol = spec.in_process_symbol().symbol(); @@ -284,10 +284,10 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, TF_ASSIGN_OR_RETURN( GpuFunctionHandle function, GpuRuntime::GetFuncBySymbol(spec.in_process_symbol().symbol())); - *rocm_kernel->gpu_function_ptr() = function; + rocm_kernel->set_gpu_function(function); #else - *rocm_kernel->gpu_function_ptr() = - static_cast(spec.in_process_symbol().symbol()); + rocm_kernel->set_gpu_function( + static_cast(spec.in_process_symbol().symbol())); #endif // TF_ROCM_VERSION >= 60200 } else { @@ -298,9 +298,10 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, // from a module, as ROCm runtime did that automatically for us. if (!spec.has_in_process_symbol()) { VLOG(2) << "getting function " << *kernel_name << " from module " << module; - TF_RETURN_IF_ERROR( - GpuDriver::GetModuleFunction(context_, module, kernel_name->c_str(), - rocm_kernel->gpu_function_ptr())); + GpuFunctionHandle function; + TF_RETURN_IF_ERROR(GpuDriver::GetModuleFunction( + context_, module, kernel_name->c_str(), &function)); + rocm_kernel->set_gpu_function(function); } // We have to trust the kernel loader spec arity because there doesn't appear @@ -310,93 +311,28 @@ absl::Status GpuExecutor::GetKernel(const MultiKernelLoaderSpec& spec, // unable to get kernel metadata for in-process kernel if (!spec.has_in_process_symbol()) { KernelMetadata kernel_metadata; - TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel, &kernel_metadata)); - kernel->set_metadata(kernel_metadata); + TF_RETURN_IF_ERROR(GetKernelMetadata(rocm_kernel.get(), &kernel_metadata)); + rocm_kernel->set_metadata(kernel_metadata); } - kernel->set_name(*kernel_name); - kernel->set_args_packing(spec.kernel_args_packing()); - return absl::OkStatus(); + rocm_kernel->set_name(*kernel_name); + rocm_kernel->set_args_packing(spec.kernel_args_packing()); + return std::move(rocm_kernel); } absl::Status GpuExecutor::GetKernelMetadata(GpuKernel* rocm_kernel, KernelMetadata* kernel_metadata) { int value = 0; TF_RETURN_IF_ERROR(GpuDriver::FuncGetAttribute( - HIP_FUNC_ATTRIBUTE_NUM_REGS, *rocm_kernel->gpu_function_ptr(), &value)); + HIP_FUNC_ATTRIBUTE_NUM_REGS, rocm_kernel->gpu_function(), &value)); kernel_metadata->set_registers_per_thread(value); TF_RETURN_IF_ERROR( GpuDriver::FuncGetAttribute(HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, - *rocm_kernel->gpu_function_ptr(), &value)); + rocm_kernel->gpu_function(), &value)); kernel_metadata->set_shared_memory_bytes(value); return absl::OkStatus(); } -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const Kernel& kernel, const KernelArgs& args) { - GpuStreamHandle hipstream = AsGpuStreamValue(stream); - const GpuKernel* rocm_kernel = AsGpuKernel(&kernel); - hipFunction_t hipfunc = rocm_kernel->AsGpuFunctionHandle(); - - if (rocm_kernel->cache_config() != KernelCacheConfig::kNoPreference) { - TF_RETURN_IF_ERROR(GpuDriver::FuncSetCacheConfig( - hipfunc, rocm_kernel->GetGpuCacheConfig())); - } - - auto launch = [&](const KernelArgsPackedArrayBase& packed) { - CHECK_EQ(kernel.Arity() + (args.number_of_shared_bytes() > 0), - packed.number_of_arguments()); - - void** kernel_params = - const_cast(packed.argument_addresses().data()); - - return GpuDriver::LaunchKernel( - GetGpuContext(stream), kernel.name(), hipfunc, block_dims.x, - block_dims.y, block_dims.z, thread_dims.x, thread_dims.y, thread_dims.z, - args.number_of_shared_bytes(), hipstream, kernel_params, nullptr); - }; - - auto* packed_args = DynCast(&args); - if (packed_args) return launch(*packed_args); - - if (auto* device_mem = DynCast(&args)) { - auto& pack = kernel.args_packing(); - if (!pack) { - return absl::InternalError( - "Kernel is missing a custom arguments packing function for device " - "memory arguments array"); - } - - TF_ASSIGN_OR_RETURN(auto packed_args, pack(kernel, *device_mem)); - return launch(*packed_args); - } - - return absl::InternalError("Unsupported kernel arguments type"); -} - -absl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, - const Kernel& kernel, const KernelArgs& args) { - if (cluster_dims.x != 1 || cluster_dims.y != 1 || cluster_dims.z != 1) - return absl::UnimplementedError("Not implemented for ROCm"); - return Launch(stream, thread_dims, block_dims, kernel, args); -} - -absl::Status GpuExecutor::Submit(Stream* stream, - const CommandBuffer& command_buffer) { - if (command_buffer.mode() != CommandBuffer::Mode::kPrimary) { - return absl::InvalidArgumentError( - "Can't submit non-primary command buffer for execution"); - } - - auto exec = GpuCommandBuffer::Cast(&command_buffer)->executable(); - VLOG(3) << "Launch command buffer execuable graph " << exec - << " on a stream: " << stream; - return GpuDriver::GraphLaunch(exec, AsGpuStreamValue(stream)); -} - absl::Status GpuExecutor::LoadModule(const MultiModuleLoaderSpec& spec, ModuleHandle* module_handle) { // In GpuExecutor we store the pointer to the HSACO binary as @@ -488,16 +424,6 @@ absl::Status GpuExecutor::SynchronousMemcpy(void* host_dst, AsROCmDevicePtr(gpu_src), size); } -absl::Status GpuExecutor::Memset(Stream* stream, DeviceMemoryBase* location, - uint8 pattern, uint64_t size) { - VLOG(2) << "enqueueing memset8 operation onto stream " << stream - << " at location " << location << " with size " << size - << " and pattern " << std::hex << pattern; - return GpuDriver::AsynchronousMemsetUint8(context_, AsROCmDevicePtr(location), - pattern, size, - AsGpuStreamValue(stream)); -} - void GpuExecutor::DeallocateStream(Stream* stream) { { absl::MutexLock lock(&mu_); @@ -507,11 +433,7 @@ void GpuExecutor::DeallocateStream(Stream* stream) { } GpuStream* rocm_stream = AsGpuStream(stream); absl::MutexLock l(&alive_gpu_streams_mu_); - alive_gpu_streams_.erase(rocm_stream->platform_specific_stream()); - if (!rocm_stream->IsIdle()) { - LOG(ERROR) << "Deallocating stream with pending work"; - } - rocm_stream->Destroy(); + alive_gpu_streams_.erase(rocm_stream->gpu_stream()); } absl::Status GpuExecutor::BlockHostUntilDone(Stream* stream) { @@ -650,27 +572,13 @@ absl::StatusOr> GpuExecutor::CreateEvent() { absl::StatusOr> GpuExecutor::CreateStream( std::optional> priority) { - auto gpu_stream = std::make_unique(this); - if (priority.has_value()) { - if (std::holds_alternative(*priority)) { - gpu_stream->SetPriority(std::get(*priority)); - } else { - gpu_stream->SetPriority(std::get(*priority)); - } - } + TF_ASSIGN_OR_RETURN(auto event, CreateGpuEvent(/*allow_timing=*/false)); + auto stream = std::make_unique(this, std::move(event), priority); absl::MutexLock l(&alive_gpu_streams_mu_); - bool init_worked = gpu_stream->Init(); - if (init_worked) { - auto platform_specific_stream = gpu_stream->platform_specific_stream(); - alive_gpu_streams_[platform_specific_stream] = gpu_stream.get(); - return std::move(gpu_stream); - } else { - return absl::InvalidArgumentError("Failed to initialize GPU stream"); - } -} - -absl::StatusOr> GpuExecutor::CreateKernel() { - return std::make_unique(this); + TF_RETURN_IF_ERROR(stream->Init()); + auto gpu_stream = stream->gpu_stream(); + alive_gpu_streams_[gpu_stream] = stream.get(); + return std::move(stream); } absl::StatusOr> GpuExecutor::CreateCommandBuffer( diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc index 0ac3540c4e627d..97413a6347584d 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.cc @@ -28,67 +28,10 @@ limitations under the License. namespace stream_executor { namespace gpu { -ROCmPlatform::ROCmPlatform() - : name_("ROCM"), min_numa_node_(0), limit_numa_node_(0) {} +ROCmPlatform::ROCmPlatform() : name_("ROCM") {} ROCmPlatform::~ROCmPlatform() {} -// Due to legacy issues in user code, we can't currently call InpectNumaNodes -// at module initialization time, because non-GPU programs still include this -// plugin via various methods, so instead, it has to be init-on-reference. -void ROCmPlatform::InspectNumaNodes() { - // To get NUMA node information, we need to create all executors, so we can - // examine their device descriptions to see their bus assignments. - absl::once_flag once; - absl::call_once(once, [&] { - StreamExecutorConfig config; - for (int i = 0; i < VisibleDeviceCount(); i++) { - config.ordinal = i; - StreamExecutor* exec = GetExecutor(config).value(); - if (i == 0) { - // NUMA nodes may not start at 0, so set the minimum node based on the - // first executor we see. - min_numa_node_ = exec->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - } else { - min_numa_node_ = - std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max( - limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1); - } - } - }); -} - -int ROCmPlatform::BusCount() { - InspectNumaNodes(); - return limit_numa_node_ - min_numa_node_; -} - -int ROCmPlatform::DeviceToBus(int device_ordinal) { - StreamExecutorConfig config; - config.ordinal = device_ordinal; - StreamExecutor* exec = GetExecutor(config).value(); - return exec->GetDeviceDescription().numa_node() - min_numa_node_; -} - -absl::StatusOr ROCmPlatform::FirstExecutorForBus( - int bus_ordinal) { - InspectNumaNodes(); - CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range"; - for (int i = 0; i < VisibleDeviceCount(); i++) { - if (DeviceToBus(i) == bus_ordinal) { - StreamExecutorConfig config; - config.ordinal = i; - return GetExecutor(config).value(); - } - } - - return absl::Status{ - absl::StatusCode::kNotFound, - absl::StrFormat("Executor for bus %d not found.", bus_ordinal)}; -} - Platform::Id ROCmPlatform::id() const { return rocm::kROCmPlatformId; } int ROCmPlatform::VisibleDeviceCount() const { @@ -110,47 +53,28 @@ ROCmPlatform::DescriptionForDevice(int ordinal) const { } absl::StatusOr ROCmPlatform::ExecutorForDevice(int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); + return executor_cache_.GetOrCreate( + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } -absl::StatusOr ROCmPlatform::GetExecutor( - const StreamExecutorConfig& config) { - if (config.gpu_stream) { - // If the GPU stream was provided, it's not possible to get-or-create a - // stream with a required pointer: so we are looking for previously - // allocated streams. - return executor_cache_.Get(config); - } - return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); +absl::StatusOr ROCmPlatform::FindExisting(int ordinal) { + return executor_cache_.Get(ordinal); } absl::StatusOr> -ROCmPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = std::make_unique(this, config.ordinal); - auto init_status = executor->Init(); - if (!init_status.ok()) { - return absl::Status{ - absl::StatusCode::kInternal, - absl::StrFormat( - "failed initializing StreamExecutor for ROCM device ordinal %d: %s", - config.ordinal, init_status.ToString().c_str())}; - } - +ROCmPlatform::GetUncachedExecutor(int ordinal) { + auto executor = std::make_unique(this, ordinal); + TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } } // namespace gpu static void InitializeROCmPlatform() { - // Disabling leak checking, PlatformManager does not destroy its - // registered platforms. auto status = PlatformManager::PlatformWithName("ROCM"); if (!status.ok()) { - std::unique_ptr platform(new gpu::ROCmPlatform); - TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK(PlatformManager::RegisterPlatform( + std::make_unique())); } } diff --git a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h index c11f9ac807673b..e37345c5275127 100644 --- a/third_party/xla/xla/stream_executor/rocm/rocm_platform.h +++ b/third_party/xla/xla/stream_executor/rocm/rocm_platform.h @@ -43,16 +43,6 @@ class ROCmPlatform : public Platform { ROCmPlatform(); ~ROCmPlatform() override; - // ROCmPlatform-specific functionality - // Returns the number of distinct buses / NUMA nodes on the machine. - int BusCount(); - - // Returns the bus/NUMA node for the specified device ordinal. - int DeviceToBus(int device_ordinal); - - // Returns the lowest-ordinal-number StreamExecutor on the specified bus. - absl::StatusOr FirstExecutorForBus(int bus_ordinal); - // Platform interface implementation: // Returns the same value as kROCmPlatform above. Platform::Id id() const override; @@ -66,16 +56,14 @@ class ROCmPlatform : public Platform { int ordinal) const override; absl::StatusOr ExecutorForDevice(int ordinal) override; - - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; - - absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config) override; + absl::StatusOr FindExisting(int ordinal) override; private: - // Determines the number of NUMA nodes and the assignment of executor to each. - void InspectNumaNodes(); + // Returns a device constructed with ordinal without + // looking in or storing to the Platform's executor cache. + // Ownership IS transferred to the caller. + absl::StatusOr> GetUncachedExecutor( + int ordinal); // This platform's name. std::string name_; @@ -86,15 +74,6 @@ class ROCmPlatform : public Platform { // Cache of created executors. ExecutorCache executor_cache_; - // The smallest NUMA node value for any device managed by this machine - // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus - // ordinals. The NUMA node space occupied by GPUs is assumed to be dense./ - int min_numa_node_; - - // Larger than the NUMA node value for any device managed by this machine - // manager. - int limit_numa_node_; - ROCmPlatform(const ROCmPlatform&) = delete; void operator=(const ROCmPlatform&) = delete; }; diff --git a/third_party/xla/xla/stream_executor/stream.h b/third_party/xla/xla/stream_executor/stream.h index 71cdf9a35b8da7..0fccf94270a85c 100644 --- a/third_party/xla/xla/stream_executor/stream.h +++ b/third_party/xla/xla/stream_executor/stream.h @@ -272,7 +272,9 @@ class Stream { // platform driver. virtual absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const Kernel &k, - const KernelArgs &args) = 0; + const KernelArgs &args) { + return absl::UnimplementedError("Not implemented"); + } // Launches a data parallel kernel with the given thread/block // dimensionality and already-packed args/sizes to pass to the underlying @@ -280,7 +282,9 @@ class Stream { virtual absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, const ClusterDim &cluster_dims, const Kernel &k, - const KernelArgs &args) = 0; + const KernelArgs &args) { + return absl::UnimplementedError("Not implemented"); + } // Get/set a name for a stream, which can be shown in profiling tools virtual absl::string_view name() const = 0; diff --git a/third_party/xla/xla/stream_executor/stream_common.cc b/third_party/xla/xla/stream_executor/stream_common.cc index 048623da37c01a..e7833bfd25dab2 100644 --- a/third_party/xla/xla/stream_executor/stream_common.cc +++ b/third_party/xla/xla/stream_executor/stream_common.cc @@ -21,13 +21,10 @@ limitations under the License. #include #include -#include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "xla/stream_executor/blas.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/logging.h" @@ -40,19 +37,6 @@ StreamCommon::StreamCommon(StreamExecutor *parent) CHECK_NE(parent, nullptr); } -absl::Status StreamCommon::Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, const Kernel &k, - const KernelArgs &args) { - return parent_->Launch(this, thread_dims, block_dims, k, args); -} - -absl::Status StreamCommon::Launch(const ThreadDim &thread_dims, - const BlockDim &block_dims, - const ClusterDim &cluster_dims, - const Kernel &k, const KernelArgs &args) { - return parent_->Launch(this, thread_dims, block_dims, cluster_dims, k, args); -} - StreamCommon::PlatformSpecificHandle StreamCommon::platform_specific_handle() const { PlatformSpecificHandle handle; diff --git a/third_party/xla/xla/stream_executor/stream_common.h b/third_party/xla/xla/stream_executor/stream_common.h index 3d2ade72ff12e3..f7029c72fbadbf 100644 --- a/third_party/xla/xla/stream_executor/stream_common.h +++ b/third_party/xla/xla/stream_executor/stream_common.h @@ -28,7 +28,6 @@ limitations under the License. #include #include "absl/base/thread_annotations.h" -#include "absl/functional/any_invocable.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -36,8 +35,6 @@ limitations under the License. #include "absl/synchronization/mutex.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/fft.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" @@ -84,11 +81,6 @@ class StreamCommon : public Stream { std::variant priority() const override { return StreamPriority::Default; } - absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, - const Kernel &k, const KernelArgs &args) override; - absl::Status Launch(const ThreadDim &thread_dims, const BlockDim &block_dims, - const ClusterDim &cluster_dims, const Kernel &k, - const KernelArgs &args) override; // Doesn't do anything interesting by default; GpuStream connects this to NVTX absl::string_view name() const override { return name_; } @@ -107,8 +99,6 @@ class StreamCommon : public Stream { // Checks the status and logs the error message, if any. void CheckStatus(absl::Status status) TF_LOCKS_EXCLUDED(mu_); - void SetError() { CheckError(false /* = operation_retcode */); } - std::string name_; private: diff --git a/third_party/xla/xla/stream_executor/stream_executor.h b/third_party/xla/xla/stream_executor/stream_executor.h index 60fc20de835fb7..a0c3c48e521e30 100644 --- a/third_party/xla/xla/stream_executor/stream_executor.h +++ b/third_party/xla/xla/stream_executor/stream_executor.h @@ -1,5 +1,3 @@ -#include "absl/functional/any_invocable.h" -#include "absl/log/log.h" /* Copyright 2015 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); @@ -15,12 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// The StreamExecutor is a single-device abstraction for: -// -// * Loading/launching data-parallel-kernels -// * Invoking pre-canned high-performance library routines (like matrix -// multiply) - #ifndef XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ #define XLA_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ @@ -31,6 +23,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" @@ -43,7 +36,6 @@ limitations under the License. #include "xla/stream_executor/fft.h" #include "xla/stream_executor/kernel.h" #include "xla/stream_executor/kernel_spec.h" -#include "xla/stream_executor/launch_dim.h" #include "xla/stream_executor/memory_allocation.h" #include "xla/stream_executor/module_spec.h" #include "xla/stream_executor/platform.h" @@ -67,6 +59,12 @@ inline std::string MemoryTypeString(MemoryType memory_type) { } } +/// The StreamExecutor is a single-device abstraction for: +// +// * Loading/launching data-parallel-kernels +// * Invoking pre-canned high-performance library routines (like matrix +// multiply) +// // Interface which defines the method for interacting with an accelerator device // (e.g. GPU, TPU). class StreamExecutor { @@ -84,8 +82,10 @@ class StreamExecutor { // Creates and initializes a Stream. virtual absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) = 0; + std::optional> priority) = 0; + absl::StatusOr> CreateStream() { + return CreateStream(std::nullopt); + } // Creates and initializes an Event. virtual absl::StatusOr> CreateEvent() = 0; @@ -107,15 +107,13 @@ class StreamExecutor { return AllocateArray(1); } - // Retrieves (loads) a kernel, if one exists. + // Loads a kernel from a MultiKernelLoaderSpec. // // Parameters: // spec: The MultiKernelLoaderSpec is usually generated as a compile-time // constant into an appropriate namespace. - // kernel: Outparam that the kernel is loaded into. A given Kernel - // instantiation should not be loaded into more than once. - virtual absl::Status GetKernel(const MultiKernelLoaderSpec& spec, - Kernel* kernel) { + virtual absl::StatusOr> LoadKernel( + const MultiKernelLoaderSpec& spec) { return absl::UnimplementedError("Not Implemented"); } @@ -138,35 +136,6 @@ class StreamExecutor { return absl::UnimplementedError("Not Implemented"); } - // Launches a data parallel kernel with the given thread/block - // dimensionality and already-packed args/sizes to pass to the underlying - // platform driver. - - virtual absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, const Kernel& k, - const KernelArgs& args) { - return absl::UnimplementedError("Not Implemented"); - } - - // Launches a data parallel kernel with the given thread/block - // dimensionality and already-packed args/sizes to pass to the underlying - // platform driver. - virtual absl::Status Launch(Stream* stream, const ThreadDim& thread_dims, - const BlockDim& block_dims, - const ClusterDim& cluster_dims, const Kernel& k, - const KernelArgs& args) { - return absl::UnimplementedError("Not Implemented"); - } - - // Submits command buffer for execution to the underlying platform driver. - virtual absl::Status Submit(Stream* stream, - const CommandBuffer& command_buffer) { - return absl::UnimplementedError("Not Implemented"); - } - - // Releases any state associated with the previously loaded kernel. - virtual void UnloadKernel(const Kernel* kernel) {} - // Synchronously allocates size bytes on the underlying platform and returns // a DeviceMemoryBase representing that allocation. In the case of failure, // nullptr is returned. @@ -244,14 +213,6 @@ class StreamExecutor { return SynchronousMemcpy(host_dst, device_src, size); } - // Enqueues an operation onto stream to set 8-bit patterns starting at - // location, for byte count given by size. Returns whether the operation was - // successfully enqueued onto the stream. - virtual absl::Status Memset(Stream* stream, DeviceMemoryBase* location, - uint8_t pattern, uint64_t size) { - return absl::InternalError("Not implemented"); - } - // Deallocates stream resources on the underlying platform. virtual void DeallocateStream(Stream* stream) = 0; @@ -314,12 +275,6 @@ class StreamExecutor { // underlying platform. virtual dnn::DnnSupport* AsDnn() { return nullptr; } - // Creates a new Kernel object. - // TODO(klucke) Combine with GetKernel. - virtual absl::StatusOr> CreateKernel() { - return absl::UnimplementedError("Kernels are not implemented"); - } - // Creates a new CommandBuffer object. virtual absl::StatusOr> CreateCommandBuffer( CommandBuffer::Mode mode) { diff --git a/third_party/xla/xla/stream_executor/stream_executor_test.cc b/third_party/xla/xla/stream_executor/stream_executor_test.cc index 34bd4593cbc475..9a2ca572534db8 100644 --- a/third_party/xla/xla/stream_executor/stream_executor_test.cc +++ b/third_party/xla/xla/stream_executor/stream_executor_test.cc @@ -25,10 +25,9 @@ limitations under the License. namespace stream_executor { -static absl::StatusOr> NewStreamExecutor() { - StreamExecutorConfig config(/*ordinal=*/0); +static absl::StatusOr NewStreamExecutor() { TF_ASSIGN_OR_RETURN(auto platform, PlatformManager::PlatformWithName("Host")); - TF_ASSIGN_OR_RETURN(auto stream_exec, platform->GetUncachedExecutor(config)); + TF_ASSIGN_OR_RETURN(auto stream_exec, platform->ExecutorForDevice(0)); return stream_exec; } diff --git a/third_party/xla/xla/stream_executor/kernel_factory.h b/third_party/xla/xla/stream_executor/stream_finder.cc similarity index 51% rename from third_party/xla/xla/stream_executor/kernel_factory.h rename to third_party/xla/xla/stream_executor/stream_finder.cc index 24e594ed89d10e..e9cdcf02c8c65c 100644 --- a/third_party/xla/xla/stream_executor/kernel_factory.h +++ b/third_party/xla/xla/stream_executor/stream_finder.cc @@ -13,32 +13,30 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_STREAM_EXECUTOR_KERNEL_FACTORY_H_ -#define XLA_STREAM_EXECUTOR_KERNEL_FACTORY_H_ - -#include +#include "xla/stream_executor/stream_finder.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_spec.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" #include "xla/stream_executor/stream_executor.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" namespace stream_executor { -// Creates Kernels from kernel specifications. -class KernelFactory { - public: - // Creates kernel on a given executor from a given kernel specification. - static inline absl::StatusOr> Create( - StreamExecutor *executor, const MultiKernelLoaderSpec &spec) { - TF_ASSIGN_OR_RETURN(auto kernel, executor->CreateKernel()); - TF_RETURN_IF_ERROR(executor->GetKernel(spec, kernel.get())); - return kernel; +absl::StatusOr FindStream(Platform* platform, void* gpu_stream) { + int number_devices = platform->VisibleDeviceCount(); + for (int i = 0; i < number_devices; ++i) { + auto stream_executor = platform->FindExisting(i); + if (!stream_executor.ok()) { + continue; + } + Stream* found_stream = nullptr; + if ((found_stream = (*stream_executor)->FindAllocatedStream(gpu_stream)) != + nullptr) { + return found_stream; + } } -}; + return absl::NotFoundError("Stream not found"); +} } // namespace stream_executor - -#endif // XLA_STREAM_EXECUTOR_KERNEL_FACTORY_H_ diff --git a/third_party/xla/xla/stream_executor/stream_finder.h b/third_party/xla/xla/stream_executor/stream_finder.h new file mode 100644 index 00000000000000..0503d3fbe4e641 --- /dev/null +++ b/third_party/xla/xla/stream_executor/stream_finder.h @@ -0,0 +1,30 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_STREAM_EXECUTOR_STREAM_FINDER_H_ +#define XLA_STREAM_EXECUTOR_STREAM_FINDER_H_ + +#include "absl/status/statusor.h" +#include "xla/stream_executor/platform.h" +#include "xla/stream_executor/stream.h" + +namespace stream_executor { + +// Returns a Stream given the gpu_stream handle. +absl::StatusOr FindStream(Platform* platform, void* gpu_stream); + +} // namespace stream_executor + +#endif // XLA_STREAM_EXECUTOR_STREAM_FINDER_H_ diff --git a/third_party/xla/xla/stream_executor/stream_finder_test.cc b/third_party/xla/xla/stream_executor/stream_finder_test.cc new file mode 100644 index 00000000000000..6bb8ac86779519 --- /dev/null +++ b/third_party/xla/xla/stream_executor/stream_finder_test.cc @@ -0,0 +1,80 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/stream_executor/stream_finder.h" + +#include "absl/status/status.h" +#include "xla/stream_executor/mock_platform.h" +#include "xla/stream_executor/mock_stream.h" +#include "xla/stream_executor/mock_stream_executor.h" +#include "xla/test.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" + +using testing::Return; +namespace stream_executor { +namespace { + +TEST(StreamFinderTest, FindStreamFailsWithNoExecutors) { + MockStreamExecutor stream_executor; + MockPlatform platform; + EXPECT_CALL(platform, VisibleDeviceCount()).WillOnce(Return(0)); + EXPECT_FALSE(FindStream(&platform, nullptr).ok()); +} + +TEST(StreamFinderTest, FindStreamFailsWithNoMatchingStream) { + MockStreamExecutor stream_executor; + MockPlatform platform; + EXPECT_CALL(platform, VisibleDeviceCount()).WillOnce(Return(1)); + EXPECT_CALL(platform, FindExisting(0)).WillOnce(Return(&stream_executor)); + void *gpu_stream = reinterpret_cast(0x1234); + EXPECT_CALL(stream_executor, FindAllocatedStream(gpu_stream)) + .WillOnce(Return(nullptr)); + EXPECT_FALSE(FindStream(&platform, gpu_stream).ok()); +} + +TEST(StreamFinderTest, FindStreamSucceeds) { + MockStreamExecutor stream_executor0; + MockStreamExecutor stream_executor1; + MockPlatform platform; + EXPECT_CALL(platform, VisibleDeviceCount()).WillOnce(Return(2)); + EXPECT_CALL(platform, FindExisting(0)).WillOnce(Return(&stream_executor0)); + EXPECT_CALL(platform, FindExisting(1)).WillOnce(Return(&stream_executor1)); + void *gpu_stream = reinterpret_cast(0x1234); + MockStream stream; + EXPECT_CALL(stream_executor0, FindAllocatedStream(gpu_stream)) + .WillOnce(Return(nullptr)); + EXPECT_CALL(stream_executor1, FindAllocatedStream(gpu_stream)) + .WillOnce(Return(&stream)); + TF_ASSERT_OK_AND_ASSIGN(auto found_stream, FindStream(&platform, gpu_stream)); + EXPECT_EQ(found_stream, &stream); +} + +TEST(StreamFinderTest, OnlyExecutor1Exists) { + MockStreamExecutor stream_executor1; + MockPlatform platform; + EXPECT_CALL(platform, VisibleDeviceCount()).WillOnce(Return(2)); + EXPECT_CALL(platform, FindExisting(0)) + .WillRepeatedly(Return(absl::NotFoundError("Nope"))); + EXPECT_CALL(platform, FindExisting(1)).WillOnce(Return(&stream_executor1)); + void *gpu_stream = reinterpret_cast(0x1234); + MockStream stream; + EXPECT_CALL(stream_executor1, FindAllocatedStream(gpu_stream)) + .WillOnce(Return(&stream)); + TF_ASSERT_OK_AND_ASSIGN(auto found_stream, FindStream(&platform, gpu_stream)); + EXPECT_EQ(found_stream, &stream); +} +} // namespace +} // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/stream_test.cc b/third_party/xla/xla/stream_executor/stream_test.cc index 473472d76fa021..ef5294ebe4260b 100644 --- a/third_party/xla/xla/stream_executor/stream_test.cc +++ b/third_party/xla/xla/stream_executor/stream_test.cc @@ -29,31 +29,30 @@ namespace { class StreamTest : public ::testing::Test { protected: - std::unique_ptr NewStreamExecutor() { + StreamExecutor* NewStreamExecutor() { Platform* platform = PlatformManager::PlatformWithName("Host").value(); - StreamExecutorConfig config(/*ordinal=*/0); - return platform->GetUncachedExecutor(config).value(); + return platform->ExecutorForDevice(/*ordinal=*/0).value(); } }; TEST_F(StreamTest, InitOk) { - std::unique_ptr executor = NewStreamExecutor(); + StreamExecutor* executor = NewStreamExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); } TEST_F(StreamTest, InitWithIntPriorityOk) { - std::unique_ptr executor = NewStreamExecutor(); + StreamExecutor* executor = NewStreamExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream(1)); } TEST_F(StreamTest, InitWithStreamPriorityOk) { - std::unique_ptr executor = NewStreamExecutor(); + StreamExecutor* executor = NewStreamExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream(StreamPriority::Highest)); } TEST_F(StreamTest, OneSubStream) { - std::unique_ptr executor = NewStreamExecutor(); + StreamExecutor* executor = NewStreamExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); // Get and return a sub-stream. Sub-streams are always initialized. @@ -72,7 +71,7 @@ TEST_F(StreamTest, OneSubStream) { } TEST_F(StreamTest, TwoSubStreams) { - std::unique_ptr executor = NewStreamExecutor(); + StreamExecutor* executor = NewStreamExecutor(); TF_ASSERT_OK_AND_ASSIGN(auto stream, executor->CreateStream()); // Get two sub-streams. diff --git a/third_party/xla/xla/stream_executor/sycl/BUILD b/third_party/xla/xla/stream_executor/sycl/BUILD index 8745be98d5a4d6..86c00a09d67028 100644 --- a/third_party/xla/xla/stream_executor/sycl/BUILD +++ b/third_party/xla/xla/stream_executor/sycl/BUILD @@ -48,6 +48,7 @@ cc_library( "//xla/stream_executor/gpu:gpu_driver_header", "//xla/stream_executor/gpu:gpu_executor_header", "//xla/stream_executor/gpu:gpu_collectives_header", + "@local_tsl//tsl/platform:errors", ]), alwayslink = True, # Registers itself with the PlatformManager. ) diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc index 876775b5d3df05..a78e104670bf21 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc +++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.cc @@ -35,65 +35,16 @@ limitations under the License. #include "xla/stream_executor/platform/initialize.h" #include "xla/stream_executor/platform_manager.h" #include "xla/stream_executor/sycl/sycl_platform_id.h" +#include "tsl/platform/errors.h" #include "tsl/platform/status.h" namespace stream_executor { namespace gpu { -SyclPlatform::SyclPlatform() - : name_("SYCL"), min_numa_node_(0), limit_numa_node_(0) {} +SyclPlatform::SyclPlatform() : name_("SYCL") {} SyclPlatform::~SyclPlatform() {} -// Due to legacy issues in user code, we can't currently call InspectNumaNodes -// at module initialization time, because non-GPU programs still include this -// plugin via various methods, so instead, it has to be init-on-reference. -void SyclPlatform::InspectNumaNodes() { - // To get NUMA node information, we need to create all executors, so we can - // examine their device descriptions to see their bus assignments. - static absl::once_flag once; - absl::call_once(once, [&] { - for (int i = 0; i < VisibleDeviceCount(); i++) { - StreamExecutor* exec = *ExecutorForDevice(i); - if (i == 0) { - // NUMA nodes may not start at 0, so set the minimum node based on the - // first executor we see. - min_numa_node_ = exec->GetDeviceDescription().numa_node(); - limit_numa_node_ = min_numa_node_ + 1; - } else { - min_numa_node_ = - std::min(min_numa_node_, exec->GetDeviceDescription().numa_node()); - limit_numa_node_ = std::max( - limit_numa_node_, exec->GetDeviceDescription().numa_node() + 1); - } - } - }); -} - -int SyclPlatform::BusCount() { - InspectNumaNodes(); - return limit_numa_node_ - min_numa_node_; -} - -int SyclPlatform::DeviceToBus(int device_ordinal) { - StreamExecutor* exec = *ExecutorForDevice(device_ordinal); - return exec->GetDeviceDescription().numa_node() - min_numa_node_; -} - -absl::StatusOr SyclPlatform::FirstExecutorForBus( - int bus_ordinal) { - InspectNumaNodes(); - CHECK_LT(bus_ordinal, BusCount()) << "bus ordinal out of available range"; - for (int i = 0; i < VisibleDeviceCount(); i++) { - if (DeviceToBus(i) == bus_ordinal) { - return *ExecutorForDevice(i); - } - } - - return absl::NotFoundError( - absl::StrFormat("Executor for bus %d not found.", bus_ordinal)); -} - Platform::Id SyclPlatform::id() const { return sycl::kSyclPlatformId; } int SyclPlatform::VisibleDeviceCount() const { @@ -113,44 +64,22 @@ SyclPlatform::DescriptionForDevice(int ordinal) const { } absl::StatusOr SyclPlatform::ExecutorForDevice(int ordinal) { - StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); -} - -absl::StatusOr SyclPlatform::GetExecutor( - const StreamExecutorConfig& config) { - if (config.gpu_stream) { - // If the GPU stream was provided, it's not possible to get-or-create a - // stream with a required pointer: so we are looking for previously - // allocated streams. - return executor_cache_.Get(config); - } return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> -SyclPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = std::make_unique(this, config.ordinal); - auto init_status = executor->Init(); - if (!init_status.ok()) { - return absl::InternalError(absl::StrFormat( - "failed initializing StreamExecutor for SYCL device ordinal %d: %s", - config.ordinal, init_status.ToString())); - } - +SyclPlatform::GetUncachedExecutor(int ordinal { + auto executor = std::make_unique(this, ordinal); + TF_RETURN_IF_ERROR(executor->Init()); return std::move(executor); } } // namespace gpu static void InitializeSyclPlatform() { - // Disabling leak checking, PlatformManager does not destroy its - // registered platforms. - - std::unique_ptr platform(new gpu::SyclPlatform); - TF_CHECK_OK(PlatformManager::RegisterPlatform(std::move(platform))); + TF_CHECK_OK( + PlatformManager::RegisterPlatform(std::make_unique())); } } // namespace stream_executor diff --git a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h index 0c687f4eee1179..61f0eb3d5372b9 100644 --- a/third_party/xla/xla/stream_executor/sycl/sycl_platform.h +++ b/third_party/xla/xla/stream_executor/sycl/sycl_platform.h @@ -41,16 +41,6 @@ class SyclPlatform : public Platform { SyclPlatform(); ~SyclPlatform() override; - // SyclPlatform-specific functionality - // Returns the number of distinct buses / NUMA nodes on the machine. - int BusCount(); - - // Returns the bus/NUMA node for the specified device ordinal. - int DeviceToBus(int device_ordinal); - - // Returns the lowest-ordinal-number StreamExecutor on the specified bus. - absl::StatusOr FirstExecutorForBus(int bus_ordinal); - // Platform interface implementation: // Returns the same value as kSyclPlatform above. Platform::Id id() const override; @@ -65,15 +55,12 @@ class SyclPlatform : public Platform { absl::StatusOr ExecutorForDevice(int ordinal) override; - absl::StatusOr GetExecutor( - const StreamExecutorConfig& config) override; - - absl::StatusOr> GetUncachedExecutor( - const StreamExecutorConfig& config) override; - private: - // Determines the number of NUMA nodes and the assignment of executor to each. - void InspectNumaNodes(); + // Returns a device constructed with ordinal without + // looking in or storing to the Platform's executor cache. + // Ownership IS transferred to the caller. + absl::StatusOr> GetUncachedExecutor( + int ordinal) override; // This platform's name. std::string name_; @@ -81,15 +68,6 @@ class SyclPlatform : public Platform { // Cache of created executors. ExecutorCache executor_cache_; - // The smallest NUMA node value for any device managed by this machine - // manager. Used, along with limit_numa_node_, to convert NUMA nodes into bus - // ordinals. The NUMA node space occupied by GPUs is assumed to be dense. - int min_numa_node_; - - // Larger than the NUMA node value for any device managed by this machine - // manager. - int limit_numa_node_; - SyclPlatform(const SyclPlatform&) = delete; void operator=(const SyclPlatform&) = delete; }; diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h index 04f09bedbc92de..6331524142cc7f 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_decl.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_decl.h @@ -64,7 +64,6 @@ typedef struct TpuSerializedProto { typedef struct SE_PlatformId { void* id; // aka stream_executor::Platform::Id } SE_PlatformId; -typedef struct SE_StreamExecutorConfig SE_StreamExecutorConfig; typedef TF_Status* (*SE_StatusCallback)(void*); typedef struct SE_DeviceMemoryBase { diff --git a/third_party/xla/xla/stream_executor/tpu/c_api_defn.h b/third_party/xla/xla/stream_executor/tpu/c_api_defn.h index 59ecd662196daf..2d4f945396ca0b 100644 --- a/third_party/xla/xla/stream_executor/tpu/c_api_defn.h +++ b/third_party/xla/xla/stream_executor/tpu/c_api_defn.h @@ -49,10 +49,6 @@ struct SE_Event { std::unique_ptr event; }; -struct SE_StreamExecutorConfig { - stream_executor::StreamExecutorConfig config; -}; - // Ignored -- these are just used to enforce the interface types struct XLA_TransferManager {}; struct XLA_ComputationPlacer {}; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h index d5b719787f4e6a..85646afbb68762 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/functional/any_invocable.h" @@ -90,8 +91,7 @@ class TpuExecutor : public tensorflow::tpu::TpuExecutorInterface { const override; absl::StatusOr> CreateStream( - std::optional> priority = - std::nullopt) override; + std::optional> priority) override; absl::StatusOr> CreateEvent() override; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h b/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h index 1d0087a02fa98b..a415204e85f592 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor_c_api.h @@ -28,8 +28,7 @@ SE_Platform* TpuPlatform_New(); void TpuPlatform_Free(SE_Platform* platform); void TpuPlatform_Initialize(SE_Platform* platform, TF_Status* status); bool TpuPlatform_Initialized(SE_Platform* platform); -SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform, - SE_StreamExecutorConfig* config, +SE_StreamExecutor* TpuPlatform_GetExecutor(SE_Platform* platform, int ordinal, TF_Status* status); SE_PlatformId TpuPlatform_Id(SE_Platform* platform); int64_t TpuPlatform_VisibleDeviceCount(SE_Platform* platform); @@ -132,10 +131,6 @@ const char* TpuStatus_Message(TF_Status* status); int TpuStatus_Code(TF_Status* status); bool TpuStatus_Ok(TF_Status* status); -SE_StreamExecutorConfig* TpuStreamExecutorConfig_Default(); -void TpuStreamExecutorConfig_SetOrdinal(SE_StreamExecutorConfig*, int ordinal); -void TpuStreamExecutorConfig_Free(SE_StreamExecutorConfig*); - SE_DeviceDescription* TpuDeviceDescription_New(); void TpuDeviceDescription_Free(SE_DeviceDescription* description); void TpuExecutor_CreateDeviceDescription(SE_StreamExecutor* executor, @@ -417,10 +412,6 @@ struct TfTpu_ExecutorApiFn { TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Code); TFTPU_ADD_FN_IN_STRUCT(TpuStatus_Ok); - TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Default); - TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_SetOrdinal); - TFTPU_ADD_FN_IN_STRUCT(TpuStreamExecutorConfig_Free); - TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_New); TFTPU_ADD_FN_IN_STRUCT(TpuDeviceDescription_Free); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc b/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc index 1b30487f03ffa2..5bc6a8ac9c4086 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_executor_init_fns.inc @@ -60,10 +60,6 @@ absl::Status SetExecutorStructFn( TFTPU_SET_FN(executor_fn, TpuStatus_Code); TFTPU_SET_FN(executor_fn, TpuStatus_Ok); - TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Default); - TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_SetOrdinal); - TFTPU_SET_FN(executor_fn, TpuStreamExecutorConfig_Free); - TFTPU_SET_FN(executor_fn, TpuDeviceDescription_New); TFTPU_SET_FN(executor_fn, TpuDeviceDescription_Free); diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc index 16efed10f42179..5aaf8a75e94146 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.cc @@ -77,32 +77,23 @@ int TpuPlatform::VisibleDeviceCount() const { ->TpuPlatform_VisibleDeviceCountFn(platform_); } -absl::StatusOr<::stream_executor::StreamExecutor*> TpuPlatform::GetExecutor( - const ::stream_executor::StreamExecutorConfig& config) { +absl::StatusOr<::stream_executor::StreamExecutor*> +TpuPlatform::ExecutorForDevice(int ordinal) { return executor_cache_.GetOrCreate( - config, [&]() { return GetUncachedExecutor(config); }); + ordinal, [this, ordinal]() { return GetUncachedExecutor(ordinal); }); } absl::StatusOr> -TpuPlatform::GetUncachedExecutor( - const ::stream_executor::StreamExecutorConfig& config) { - SE_StreamExecutorConfig* c_config = stream_executor::tpu::ExecutorApiFn() - ->TpuStreamExecutorConfig_DefaultFn(); - - stream_executor::tpu::ExecutorApiFn()->TpuStreamExecutorConfig_SetOrdinalFn( - c_config, config.ordinal); - +TpuPlatform::GetUncachedExecutor(int ordinal) { StatusHelper status; SE_StreamExecutor* executor = stream_executor::tpu::ExecutorApiFn()->TpuPlatform_GetExecutorFn( - platform_, c_config, status.c_status); - stream_executor::tpu::ExecutorApiFn()->TpuStreamExecutorConfig_FreeFn( - c_config); + platform_, ordinal, status.c_status); if (!status.ok()) { return status.status(); } return std::make_unique(this, executor, - config.ordinal); + ordinal); } ::stream_executor::Platform::Id TpuPlatform::id() const { diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h index f2ea4f9dba8b2e..8eb6f19b7cdd6b 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform.h +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform.h @@ -82,19 +82,13 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { } absl::StatusOr<::stream_executor::StreamExecutor*> ExecutorForDevice( + int ordinal) override; + + absl::StatusOr<::stream_executor::StreamExecutor*> FindExisting( int ordinal) override { - stream_executor::StreamExecutorConfig config; - config.ordinal = ordinal; - return GetExecutor(config); + return executor_cache_.Get(ordinal); } - absl::StatusOr<::stream_executor::StreamExecutor*> GetExecutor( - const ::stream_executor::StreamExecutorConfig& config) override; - - absl::StatusOr> - GetUncachedExecutor( - const ::stream_executor::StreamExecutorConfig& config) override; - StreamMap* stream_map() { return &stream_map_; } void InsertEvent(stream_executor::Event* key, SE_Event* val); @@ -118,6 +112,12 @@ class TpuPlatform : public ::tensorflow::tpu::TpuPlatformInterface { absl::Mutex& mutex() { return event_map_mu_; } private: + // Returns a device constructed with the ordinal without + // looking in or storing to the Platform's executor cache. + // Ownership IS transferred to the caller. + absl::StatusOr> + GetUncachedExecutor(int ordinal); + mutable SE_Platform* platform_; std::string name_; stream_executor::ExecutorCache executor_cache_; diff --git a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc index c7df6619342830..63c83e5696cfc5 100644 --- a/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc +++ b/third_party/xla/xla/stream_executor/tpu/tpu_platform_interface.cc @@ -42,6 +42,7 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, << status_or_tpu_platform.status(); return nullptr; } + LOG(INFO) << "Platform manager status: " << status_or_tpu_platform.status(); // Use any other registered TPU platform. auto status_or_other_tpu_platforms = @@ -72,12 +73,14 @@ TpuPlatformInterface* GetRegisteredPlatformStatic(bool initialize_platform, --tries_left; if (tries_left <= 0) { - LOG(INFO) << "No TPU platform found."; + LOG(INFO) << "No TPU platform found. Platform manager status: " + << status_or_other_tpu_platforms.status(); return nullptr; } LOG(INFO) << "No TPU platform registered. Waiting 1 second and trying again... (" - << tries_left << " tries left)"; + << tries_left << " tries left) Platform manager status: " + << status_or_other_tpu_platforms.status(); tsl::Env::Default()->SleepForMicroseconds(1000000); // 1 second return GetRegisteredPlatformStatic(initialize_platform, tries_left); } diff --git a/third_party/xla/xla/stream_executor/typed_kernel_factory.h b/third_party/xla/xla/stream_executor/typed_kernel_factory.h index 21600d128ea758..65ed14883152e5 100644 --- a/third_party/xla/xla/stream_executor/typed_kernel_factory.h +++ b/third_party/xla/xla/stream_executor/typed_kernel_factory.h @@ -25,7 +25,6 @@ limitations under the License. #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/stream_executor/kernel.h" -#include "xla/stream_executor/kernel_factory.h" #include "xla/stream_executor/kernel_spec.h" #include "xla/stream_executor/stream_executor.h" #include "tsl/platform/statusor.h" @@ -41,7 +40,7 @@ class TypedKernelFactory { static absl::StatusOr> Create( StreamExecutor *executor, const MultiKernelLoaderSpec &spec) { TF_ASSIGN_OR_RETURN(std::unique_ptr kernel, - KernelFactory::Create(executor, spec)); + executor->LoadKernel(spec)); return TypedKernel(std::move(kernel)); } diff --git a/third_party/xla/xla/tests/BUILD b/third_party/xla/xla/tests/BUILD index 3fe2e14f5d9259..10cea42f7f5cc1 100644 --- a/third_party/xla/xla/tests/BUILD +++ b/third_party/xla/xla/tests/BUILD @@ -201,6 +201,7 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/hlo/ir:hlo", "//xla/hlo/ir:hlo_module_group", + "//xla/hlo/utils:hlo_query", "//xla/service:backend", "//xla/service:computation_layout", "//xla/service:hlo_module_util", @@ -213,11 +214,11 @@ cc_library( "//xla/service:platform_util", "//xla/stream_executor", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], @@ -299,8 +300,8 @@ cc_library( ":filecheck", "//xla/service:llvm_compiler", "//xla/service/llvm_ir:llvm_util", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test", ], ) @@ -413,7 +414,7 @@ xla_test( "//xla/service:backend", "//xla/service:executable", "//xla/stream_executor:stream_executor_memory_allocator", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -543,8 +544,8 @@ xla_test( "//xla/client/lib:arithmetic", "//xla/service:platform_util", "//xla/stream_executor:stream_executor_memory_allocator", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -572,11 +573,11 @@ xla_test( "//xla/client:xla_computation", "//xla/service:platform_util", "//xla/service:stream_pool", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/platform:test", ], @@ -664,7 +665,7 @@ xla_test( "//xla/client:local_client", "//xla/client:xla_builder", "//xla/client/lib:arithmetic", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test", ], ) @@ -1162,7 +1163,7 @@ xla_test( "//xla/client:local_client", "//xla/client:xla_builder", "//xla/client/lib:constants", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:test", ], @@ -1614,8 +1615,8 @@ xla_test( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/service:hlo_parser", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test", ], ) @@ -1666,12 +1667,12 @@ xla_test( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/client/lib:arithmetic", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test", ], ) @@ -1718,9 +1719,9 @@ xla_test_library( "//xla/client:xla_builder", "//xla/client:xla_computation", "//xla/client/lib:arithmetic", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", ], @@ -1783,12 +1784,16 @@ xla_test( ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", - "//xla:array2d", + "//xla:array3d", + "//xla:array4d", "//xla:literal", + "//xla:literal_util", + "//xla:shape_util", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", "//xla/hlo/ir:hlo", - "@local_tsl//tsl/platform:protobuf", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", "@local_tsl//tsl/platform:test", ], ) @@ -1901,6 +1906,8 @@ xla_test( ":test_macros_header", ":test_utils", ":xla_internal_test_main", # fixdeps: keep + "//xla:array2d", + "//xla:array3d", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -1911,15 +1918,19 @@ xla_test( "//xla/ffi:attribute_map", "//xla/ffi:ffi_api", "//xla/hlo/ir:hlo", + "//xla/service", "//xla/service:custom_call_status", "//xla/service:custom_call_target_registry", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", + "@eigen_archive//:eigen3", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], @@ -2125,6 +2136,9 @@ xla_test( xla_test( name = "dynamic_reshape_test", srcs = ["dynamic_reshape_test.cc"], + backend_tags = { + "gpu": ["notsan"], # TODO(b/345034145): Fix tsan error. + }, disabled_backends = ["interpreter"], tags = ["test_xla_cpu_thunks"], deps = [ @@ -2229,7 +2243,6 @@ xla_test( tags = ["test_xla_cpu_thunks"], deps = [ ":client_library_test_base", - ":literal_test_util", ":test_macros_header", ":xla_internal_test_main", "//xla:shape_util", @@ -2237,7 +2250,6 @@ xla_test( "//xla:xla_data_proto_cc", "//xla/client:local_client", "//xla/client:xla_builder", - "//xla/stream_executor", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@local_tsl//tsl/platform:ml_dtypes", @@ -2273,15 +2285,9 @@ xla_test( srcs = ["collective_ops_test.cc"], args = ["--xla_force_host_platform_device_count=4"], backend_tags = { - # This test is tagged "manual" because it requires multiple GPUs, and - # Forge only supports single-GPU tests. Guitar skips "manual" tests - # unless they're also tagged "guitar". "gpu": [ - "guitar", - "manual", "multi_gpu", "no_oss", - "notap", ], "cpu": [ "notsan", @@ -2305,9 +2311,9 @@ xla_test( "//xla/service:computation_placer", "//xla/service:executable", "//xla/service:hlo_module_config", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:blocking_counter", "@local_tsl//tsl/platform:env", ], @@ -2318,14 +2324,9 @@ xla_test( srcs = ["collective_pipeline_parallelism_test.cc"], args = ["--xla_force_host_platform_device_count=4"], backend_tags = { - # This test is tagged "manual" because it requires multiple GPUs, and Forge only supports - # single-GPU tests. Guitar skips "manual" tests unless they're also tagged "guitar". "gpu": [ - "guitar", - "manual", "multi_gpu", "no_oss", - "notap", ], "cpu": [ "notsan", @@ -2364,15 +2365,9 @@ xla_test( name = "collective_ops_e2e_test", srcs = ["collective_ops_e2e_test.cc"], backend_tags = { - # This test is tagged "manual" because it requires multiple GPUs, and - # Forge only supports single-GPU tests. Guitar skips "manual" tests - # unless they're also tagged "guitar". "gpu": [ - "guitar", - "manual", "multi_gpu", "no_oss", - "notap", ], }, backends = [ @@ -2387,8 +2382,12 @@ xla_test( "//xla:literal", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_matchers", + "//xla/service:hlo_module_config", "//xla/service/gpu:backend_configs_cc", "//xla/stream_executor", + "@com_google_absl//absl/strings:string_view", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], ) @@ -2416,15 +2415,9 @@ xla_test( name = "replicated_io_feed_test", srcs = ["replicated_io_feed_test.cc"], backend_tags = { - # This test is tagged "manual" because it requires multiple GPUs, and - # Forge only supports single-GPU tests. Guitar skips "manual" tests - # unless they're also tagged "guitar". "gpu": [ - "guitar", - "manual", "multi_gpu", "no_oss", - "notap", ], }, backends = ["gpu"], @@ -2437,7 +2430,7 @@ xla_test( "//xla:shape_util", "//xla:test", "//xla:test_helpers", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -2514,9 +2507,9 @@ xla_test( "//xla:test", "//xla:test_helpers", "//xla/service:hlo_proto_cc", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", @@ -2544,10 +2537,10 @@ xla_test( "//xla/client:xla_computation", "//xla/client/lib:arithmetic", "//xla/client/lib:prng", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", ], @@ -2571,9 +2564,9 @@ xla_test( "//xla/client:global_data", "//xla/client:xla_builder", "//xla/client:xla_computation", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -2745,9 +2738,9 @@ xla_test( "//xla/client:local_client", "//xla/hlo/ir:hlo", "//xla/service:hlo_runner", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_benchmark", @@ -2799,7 +2792,10 @@ xla_test( size = "large", srcs = ["local_client_execute_test.cc"], shard_count = 30, - tags = ["optonly"], + tags = [ + "optonly", + "test_xla_cpu_thunks", + ], deps = [ ":literal_test_util", ":local_client_test_base", @@ -2841,7 +2837,7 @@ xla_test( ":local_client_test_base", ":test_macros_header", ":xla_internal_test_main", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", ], ) @@ -3010,9 +3006,9 @@ xla_test( "//xla:shape_util", "//xla/client:xla_builder", "//xla/service:hlo_parser", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_set", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -3047,8 +3043,8 @@ xla_cc_test( "//xla/client:xla_builder", "//xla/service:cpu_plugin", "//xla/stream_executor:platform_manager", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", ], @@ -3090,6 +3086,9 @@ xla_test( xla_test( name = "set_dimension_size_test", srcs = ["set_dimension_size_test.cc"], + backend_tags = { + "gpu": ["notsan"], # TODO(b/345034145): Fix tsan error. + }, tags = ["test_xla_cpu_thunks"], deps = [ ":hlo_test_base", @@ -3128,9 +3127,9 @@ xla_test( "//xla/client:xla_builder", "//xla/client/lib:math", "//xla/client/lib:matrix", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -3154,8 +3153,8 @@ xla_test( "//xla/client:xla_builder", "//xla/client/lib:arithmetic", "//xla/client/lib:matrix", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/status:statusor", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -3199,6 +3198,7 @@ xla_test( "//xla:types", "//xla/hlo/ir:hlo", "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/tests/buffer_donation_test.cc b/third_party/xla/xla/tests/buffer_donation_test.cc index 732d562871afa9..35d9c648846892 100644 --- a/third_party/xla/xla/tests/buffer_donation_test.cc +++ b/third_party/xla/xla/tests/buffer_donation_test.cc @@ -30,7 +30,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tests/build_defs.bzl b/third_party/xla/xla/tests/build_defs.bzl index 8a42642c8e5ff5..8d73f8969255d9 100644 --- a/third_party/xla/xla/tests/build_defs.bzl +++ b/third_party/xla/xla/tests/build_defs.bzl @@ -31,6 +31,8 @@ GPU_BACKENDS = NVIDIA_GPU_BACKENDS + AMD_GPU_DEFAULT_BACKENDS GPU_DEFAULT_BACKENDS = NVIDIA_GPU_DEFAULT_BACKENDS +DEFAULT_DISABLED_BACKENDS = [] + _ALL_BACKENDS = ["cpu", "interpreter"] + NVIDIA_GPU_BACKENDS + AMD_GPU_DEFAULT_BACKENDS + list(plugins.keys()) # buildifier: disable=function-docstring @@ -175,7 +177,7 @@ def xla_test( deps, xla_test_library_deps = [], backends = [], - disabled_backends = [], + disabled_backends = DEFAULT_DISABLED_BACKENDS, real_hardware_only = False, # @unused, all backends are real hardware. args = [], tags = [], @@ -281,6 +283,8 @@ def xla_test( ] if backend in NVIDIA_GPU_BACKENDS: this_backend_tags += tf_gpu_tests_tags() + if backend in AMD_GPU_DEFAULT_BACKENDS: + this_backend_tags.append("gpu") this_backend_copts.append("-DXLA_TEST_BACKEND_GPU=1") elif backend == "interpreter": backend_deps += [ @@ -320,8 +324,23 @@ def xla_test( # b/317293391. For this reason, if we would create an empty `test_suite`, # instead create a `cc_test` with no srcs that links against `main` to have # more predictable behavior that avoids bugs. + # + # Due to b/317293391, we also mark the test suite `manual`, so that wild card builds + # like in the XLA CI won't try to build the test suite target. Instead the wild card + # build will build the individual test targets and therefore respect the tags on each + # individual test target. + # Example: Assume we have an `xla_test(name=my_test)` in `//xla/service/gpu` with backends `cpu` + # and `gpu`. This generates two test targets `//xla/service/gpu:my_test_{cpu|gpu}`. The latter + # has a tag `gpu`. + # + # - `bazel test --test_tag_filters=-gpu //xla/service/gpu/...` will only run the cpu test. + # - `bazel test //xla/service/gpu/...` will run both tests. + # - `bazel test //xla/service/gpu:my_test` will run both tests. + # Caveat: + # - `bazel test --test_tag_filters=-gpu //xla/service/gpu:my_test` will run both tests and + # not respect the tag filter - but it's way better than the previous behavoir. if test_names: - native.test_suite(name = name, tags = tags, tests = test_names) + native.test_suite(name = name, tags = tags + ["manual"], tests = test_names) else: native.cc_test(name = name, deps = ["@local_tsl//tsl/platform:test_main"]) diff --git a/third_party/xla/xla/tests/cholesky_test.cc b/third_party/xla/xla/tests/cholesky_test.cc index a0c3d4227c7246..9215319bbf8e40 100644 --- a/third_party/xla/xla/tests/cholesky_test.cc +++ b/third_party/xla/xla/tests/cholesky_test.cc @@ -28,8 +28,8 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tests/client_library_test_base.cc b/third_party/xla/xla/tests/client_library_test_base.cc index 07f46e75715d68..743db05f93f73b 100644 --- a/third_party/xla/xla/tests/client_library_test_base.cc +++ b/third_party/xla/xla/tests/client_library_test_base.cc @@ -98,6 +98,12 @@ ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) ->set_xla_hlo_evaluator_use_fast_path(true); } +std::string ClientLibraryTestBase::SuiteName() const { + return ::testing::UnitTest::GetInstance() + ->current_test_info() + ->test_suite_name(); +} + std::string ClientLibraryTestBase::TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } diff --git a/third_party/xla/xla/tests/client_library_test_base.h b/third_party/xla/xla/tests/client_library_test_base.h index 8246d851052429..800f14c4c014ad 100644 --- a/third_party/xla/xla/tests/client_library_test_base.h +++ b/third_party/xla/xla/tests/client_library_test_base.h @@ -71,6 +71,9 @@ class ClientLibraryTestBase : public ManifestCheckingTest { ClientLibraryTestBase(se::Platform* platform, const LocalClientOptions& client_options); + // Returns the name of the suite currently being run. + std::string SuiteName() const; + // Returns the name of the test currently being run. std::string TestName() const; diff --git a/third_party/xla/xla/tests/collective_ops_e2e_test.cc b/third_party/xla/xla/tests/collective_ops_e2e_test.cc index f1d1c78d28bb61..1e399127318242 100644 --- a/third_party/xla/xla/tests/collective_ops_e2e_test.cc +++ b/third_party/xla/xla/tests/collective_ops_e2e_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" @@ -25,12 +26,15 @@ limitations under the License. #include "xla/hlo/utils/hlo_matchers.h" #include "xla/literal.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/service/hlo_module_config.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream_executor.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" namespace xla { namespace { @@ -70,6 +74,31 @@ class CollectiveOpsTestE2E : public HloTestBase { GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); } + void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text, + const DebugOptions& options) { + if (!HasFp8Support()) { + return; + } + const int64_t kNumReplicas = 1; + const int64_t kNumPartitions = 4; + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + config.set_debug_options(options); + config.set_num_partitions(kNumPartitions); + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_text, config)); + + TF_ASSERT_OK_AND_ASSIGN(auto executable, + CreateExecutable(std::move(module), + /*run_hlo_passes=*/true)); + EXPECT_TRUE(executable->has_module()); + HloInstruction* gemm_op = + FindInstruction(&executable->module(), HloOpcode::kCustomCall); + EXPECT_THAT(gemm_op, NotNull()); + EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + } + absl::StatusOr> ExecuteReplicated(Executable* executable, int64_t num_replicas) { DeviceAssignment device_assignment = MakeDeviceAssn(num_replicas); @@ -154,6 +183,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllReduce) { )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_reduce = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -190,6 +220,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllGather) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_gather = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, @@ -231,6 +262,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllGatherMixedTypes) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_gather = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, @@ -268,6 +300,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncCollectiveBroadcast) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_collective_broadcast = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -300,6 +333,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncCollectivePermute) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_collective_permute = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -343,6 +377,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncReduceScatter) { )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_reduce_scatter = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -376,6 +411,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithSplitDim) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_to_all = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -420,6 +456,7 @@ XLA_TEST_P(AsyncCollectiveOps, AsyncAllToAllWithoutSplitDim) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const bool enable_async_all_to_all = GetParam(); TF_ASSERT_OK_AND_ASSIGN(auto executable, CreateExecutable(kModuleStr, kNumReplicas)); @@ -472,6 +509,7 @@ TEST_P(AsyncCollectiveOps, MatmulReplicated) { } )"; const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -592,6 +630,7 @@ TEST_F(CollectiveOpsTestE2E, WhileLoopReduceScatterCodeMotion) { )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); DebugOptions debug_options = GetDebugOptionsForTest(); debug_options.set_xla_gpu_enable_while_loop_reduce_scatter_code_motion(true); @@ -646,6 +685,7 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) { } )"; const int64_t kNumReplicas = 2; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -677,6 +717,7 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { absl::string_view hlo_text, bool disable_dot_merger = false) { const int64_t kNumReplicas = 1; const int64_t kNumPartitions = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas * kNumPartitions); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -813,6 +854,59 @@ ENTRY main.12 { // Custom Calls. CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, /*disable_dot_merger=*/true); + + // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer + // architectures. + DebugOptions opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_graph_min_graph_size(200); + opts.set_xla_gpu_enable_triton_gemm(false); + opts.add_xla_disable_hlo_passes("dot-merger"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); +} + +TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, + WindowedEinsumE2EAllGatherMultiConsumerF8) { + absl::string_view kModuleReplicatedStr = R"( +HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4 + +ENTRY main { + rhs = f8e4m3fn[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]} + lhs0 = f8e4m3fn[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]} + scale_rhs = bf16[] parameter(3) + scale_lhs0 = bf16[] parameter(4) + scale_rhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_rhs), dimensions={} + scale_lhs0_bcast = bf16[48,192]{1,0} broadcast(scale_lhs0), dimensions={} + rhs_bf16 = bf16[2,16,48]{2,1,0} convert(rhs) + lhs0_bf16 = bf16[48,192]{1,0} convert(lhs0) + rhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16) + lhs0_scaled = bf16[48,192]{1,0} multiply(scale_lhs0_bcast, lhs0_bf16) + dot0 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + lhs1 = f8e4m3fn[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]} + scale_lhs1 = bf16[] parameter(5) + scale_lhs1_bcast = bf16[48,192]{1,0} broadcast(scale_lhs1), dimensions={} + lhs1_bf16 = bf16[48,192]{1,0} convert(lhs1) + lhs1_scaled = bf16[48,192]{1,0} multiply(scale_lhs1_bcast, lhs1_bf16) + dot1 = bf16[2,16,192]{2,1,0} dot(rhs_scaled, lhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0} + ROOT add.8 = bf16[2,16,192]{2,1,0} add(dot0, dot1) +} // main +)"; + + // Disable the dot merger pass which can prevent the creation of FP8 GEMM + // Custom Calls. + CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr, + /*disable_dot_merger=*/true); + + // Verify the creation of FP8 GEMM Custom Calls on Hopper and newer + // architectures. + DebugOptions opts = GetDebugOptionsForTest(); + opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0); + opts.set_xla_gpu_multi_streamed_windowed_einsum(true); + opts.set_xla_gpu_graph_min_graph_size(200); + opts.set_xla_gpu_enable_triton_gemm(false); + opts.add_xla_disable_hlo_passes("dot-merger"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); } TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, @@ -959,7 +1053,7 @@ ENTRY entry { )"; const int64_t kNumReplicas = 1; - const int64_t kNumPartitions = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -967,19 +1061,7 @@ ENTRY entry { opts.set_xla_gpu_run_post_layout_collective_pipeliner(true); opts.set_xla_gpu_enable_pipelined_collectives(true); opts.set_xla_gpu_enable_triton_gemm(false); - config.set_debug_options(opts); - config.set_num_partitions(kNumPartitions); - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config)); - - TF_ASSERT_OK_AND_ASSIGN(auto executable, - CreateExecutable(std::move(module), - /*run_hlo_passes=*/true)); - EXPECT_TRUE(executable->has_module()); - HloInstruction* gemm_op = - FindInstruction(&executable->module(), HloOpcode::kCustomCall); - EXPECT_THAT(gemm_op, NotNull()); - EXPECT_EQ(gemm_op->custom_call_target(), "__cublas$lt$matmul$f8"); + CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts); } TEST_F(CollectiveOpsTestE2E, @@ -1052,6 +1134,7 @@ ENTRY entry { )"; const int64_t kNumReplicas = 1; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1085,6 +1168,7 @@ ENTRY entry { )"; const int64_t kNumReplicas = 1; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); const int64_t kNumPartitions = 4; HloModuleConfig config = diff --git a/third_party/xla/xla/tests/collective_ops_test.cc b/third_party/xla/xla/tests/collective_ops_test.cc index 0d8c3062bf2238..9cd874c9e03c13 100644 --- a/third_party/xla/xla/tests/collective_ops_test.cc +++ b/third_party/xla/xla/tests/collective_ops_test.cc @@ -34,28 +34,22 @@ limitations under the License. #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/blocking_counter.h" #include "tsl/platform/env.h" #include "tsl/platform/threadpool.h" +namespace xla { +namespace { + // Tests cross-GPU operations. // // Several tests requires at least four GPUs. For instructions on running this // within Google, see go/multi-gpu-unit-test. - -#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ - if (num_devices_ < x) { \ - GTEST_SKIP() << "Test requires at least " << x << " devices"; \ - } - -namespace xla { -namespace { - class CollectiveOpsTest : public HloTestBase { public: - CollectiveOpsTest() : num_devices_(backend().device_count()) { - VLOG(1) << "Running with " << num_devices_ << " devices"; + CollectiveOpsTest() { + VLOG(1) << "Running with " << num_devices() << " devices"; } protected: @@ -180,9 +174,6 @@ class CollectiveOpsTest : public HloTestBase { /*expected_value=*/to_literal({cast(-1), cast(-2), cast(-3)})); } } - - protected: - const int64_t num_devices_; }; // Returns the non-empty subsets of {0, 1, ..., n}. For example, @@ -370,7 +361,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduceOr_Pred) { XLA_TEST_F(CollectiveOpsTest, AllReduce_AllCombinations) { const int64_t kNumElems = 1024; - for (std::vector devices : PowerSetOfIota(num_devices_)) { + for (std::vector devices : PowerSetOfIota(num_devices())) { SCOPED_TRACE(absl::StrFormat("Running on devices {%s}", absl::StrJoin(devices, ", "))); @@ -494,7 +485,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_ThreeReplicaGroups) { // Test a prime number so it's not all powers of 2. const int64_t kNumElems = 137; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -541,7 +532,7 @@ XLA_TEST_F(CollectiveOpsTest, AllReduce_Degenerate) { } )"; static constexpr int kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -577,19 +568,19 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduce)) { )"; HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/num_devices_); + GetModuleConfigForTest(/*replica_count=*/num_devices()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), absl::Span{}, - num_devices_, + num_devices(), /*use_threads=*/true, /*run_hlo_passes=*/false)); - ASSERT_EQ(results.size(), num_devices_); + ASSERT_EQ(results.size(), num_devices()); // sum [0, num_devices) - uint32_t expected = num_devices_ * (num_devices_ - 1) / 2; - for (int i = 0; i < num_devices_; ++i) { + uint32_t expected = num_devices() * (num_devices() - 1) / 2; + for (int i = 0; i < num_devices(); ++i) { LiteralTestUtil::ExpectR0Equal(expected, results[i]); } } @@ -613,22 +604,22 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncAllReduceTwoOperands)) { )"; HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/num_devices_); + GetModuleConfigForTest(/*replica_count=*/num_devices()); TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleStr, config)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), absl::Span{}, - num_devices_, + num_devices(), /*use_threads=*/true, /*run_hlo_passes=*/false)); - ASSERT_EQ(results.size(), num_devices_); + ASSERT_EQ(results.size(), num_devices()); // sum [0, num_devices) - uint32_t expected0 = num_devices_ * (num_devices_ - 1) / 2; + uint32_t expected0 = num_devices() * (num_devices() - 1) / 2; // sum squares [0, num_devices) uint32_t expected1 = - num_devices_ * (num_devices_ - 1) * (2 * num_devices_ - 1) / 6; - for (int i = 0; i < num_devices_; ++i) { + num_devices() * (num_devices() - 1) * (2 * num_devices() - 1) / 6; + for (int i = 0; i < num_devices(); ++i) { std::vector replica_results = results[i].DecomposeTuple(); LiteralTestUtil::ExpectR0Equal(expected0, replica_results[0]); LiteralTestUtil::ExpectR0Equal(expected1, replica_results[1]); @@ -645,18 +636,18 @@ XLA_TEST_F(CollectiveOpsTest, ReplicaId) { )"; HloModuleConfig config = - GetModuleConfigForTest(/*replica_count=*/num_devices_); + GetModuleConfigForTest(/*replica_count=*/num_devices()); TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModuleStr)); TF_ASSERT_OK_AND_ASSIGN( std::vector results, ExecuteReplicated(std::move(module), absl::Span{}, - num_devices_, + num_devices(), /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), num_devices_); - for (uint32_t i = 0; i < num_devices_; ++i) { + ASSERT_EQ(results.size(), num_devices()); + for (uint32_t i = 0; i < num_devices(); ++i) { EXPECT_TRUE(LiteralTestUtil::Equal(LiteralUtil::CreateR0(i), results[i])); } } @@ -680,7 +671,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(CollectiveBroadcast_Simple)) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -716,7 +707,7 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Simple) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -753,7 +744,7 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Degenerate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -789,7 +780,7 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_NotDegenerate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -826,7 +817,7 @@ XLA_TEST_F(CollectiveOpsTest, CollectivePermute_Rotate) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -864,7 +855,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AsyncCollectivePermute)) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -906,7 +897,7 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_EmptyReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -952,7 +943,7 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_OrderedReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -992,7 +983,7 @@ XLA_TEST_F(CollectiveOpsTest, AllToAll_TwoReplicaGroups) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -1024,7 +1015,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_SplitDimension)) { } )"; const int64_t kNumReplicas = 4; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2003,7 +1994,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_Simple)) { )"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2083,7 +2074,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_TwoConcurrentChains)) { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2162,7 +2153,7 @@ XLA_TEST_F(CollectiveOpsTest, DISABLED_ON_CPU(SendRecv_ValidationAttr1)) { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); @@ -2263,7 +2254,7 @@ body { })"; const int64_t kNumReplicas = 2; - SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); diff --git a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc index a88b1000f4737a..ee844727b9c7f8 100644 --- a/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc +++ b/third_party/xla/xla/tests/collective_pipeline_parallelism_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include @@ -32,28 +33,18 @@ limitations under the License. #include "xla/tests/verified_hlo_module.h" #include "tsl/platform/statusor.h" +namespace xla { +namespace { + // Tests cross-GPU operations. // // Several tests requires at least four GPUs. For instructions on running this // within Google, see go/multi-gpu-unit-test. - -// TODO: Move this to hlo_test_base.h -#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ - if (num_devices_ < x) { \ - GTEST_SKIP() << "Test requires at least " << x << " devices"; \ - } - -namespace xla { -namespace { - class CollectivePipelineParallelismTest : public HloTestBase { public: - CollectivePipelineParallelismTest() : num_devices_(backend().device_count()) { - VLOG(1) << "Running with " << num_devices_ << " devices"; + CollectivePipelineParallelismTest() { + VLOG(1) << "Running with " << num_devices() << " devices"; } - - protected: - const int64_t num_devices_; }; XLA_TEST_F(CollectivePipelineParallelismTest, @@ -73,13 +64,13 @@ XLA_TEST_F(CollectivePipelineParallelismTest, iter = u32[] get-tuple-element(param), index=0 data = f32[2,2] get-tuple-element(param), index=1 weights = f32[2,2] get-tuple-element(param), index=2 - matmul = f32[2,2] dot(weights, data), lhs_contracting_dims={1}, - rhs_contracting_dims={0} - cp = f32[2,2] collective-permute(matmul), - source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + cp = f32[2,2] collective-permute(data), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + matmul = f32[2,2] dot(weights, cp), + lhs_contracting_dims={1}, rhs_contracting_dims={0} iter_increment = u32[] constant(1) next_iter = u32[] add(iter, iter_increment) - ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, cp, weights) + ROOT result = (u32[], f32[2,2], f32[2,2]) tuple(next_iter, matmul, weights) } ENTRY test_computation { @@ -126,26 +117,59 @@ XLA_TEST_F(CollectivePipelineParallelismTest, LiteralTestUtil::ExpectR2Equal({{0, 0}, {1, 1}}, results[3]); } -// Naive implementation of pipeline parallelism: -// - 4 devices -// - 4 microbatches -// - no circular repeat -// - no disabled collectives -// - no collective pipelining -// -// Every stage of the pipeline is a single linear layer. -XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { - const absl::string_view kModuleStr = R"( - HloModule test +std::string GetModuleStrWithCommonComputations( + const std::string name, const std::string more_computations) { + static constexpr char kCommonComputationsStr[] = R"( + read_buffer_mb4 { + buffer = f32[4,16] parameter(0) + offset = u32[] parameter(1) + index = u32[] parameter(2) + c0 = u32[] constant(0) + c4 = u32[] constant(4) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c4) + slice = f32[1,16] dynamic-slice(buffer, index__, c0), + dynamic_slice_sizes={1,16} + ROOT slice_ = f32[16] reshape(slice) + } + + read_buffer_mb5 { + buffer = f32[5,16] parameter(0) + offset = u32[] parameter(1) + index = u32[] parameter(2) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + slice = f32[1,16] dynamic-slice(buffer, index__, c0), + dynamic_slice_sizes={1,16} + ROOT slice_ = f32[16] reshape(slice) + } + + update_buffer_mb4 { + buffer = f32[4,16] parameter(0) + update = f32[16] parameter(1) + offset = u32[] parameter(2) + index = u32[] parameter(3) + c0 = u32[] constant(0) + c4 = u32[] constant(4) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c4) + update_ = f32[1,16] reshape(update) + ROOT buffer_ = f32[4,16] dynamic-update-slice(buffer, update_, index__, c0) + } - get_circ_buffer_index { - offset = u32[] parameter(0) - index = u32[] parameter(1) - size = u32[] parameter(2) - t0 = u32[] add(offset, index) - t1 = u32[] divide(t0, size) - t2 = u32[] multiply(t1, size) - ROOT t4 = u32[] subtract(t0, t2) + update_buffer_mb5 { + buffer = f32[5,16] parameter(0) + update = f32[16] parameter(1) + offset = u32[] parameter(2) + index = u32[] parameter(3) + c0 = u32[] constant(0) + c5 = u32[] constant(5) + index_ = u32[] add(index, offset) + index__ = u32[] remainder(index_, c5) + update_ = f32[1,16] reshape(update) + ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) } is_input_replica { @@ -156,10 +180,40 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { is_output_replica { replica_id = u32[] replica-id() - c1 = u32[] constant(1) - ROOT predicate = pred[] compare(replica_id, c1), direction=EQ + c3 = u32[] constant(3) + ROOT predicate = pred[] compare(replica_id, c3), direction=EQ + } + + is_read_input_mb4 { + is_input_replica = pred[] call(), to_apply=is_input_replica + i = u32[] parameter(0) + c4 = u32[] constant(4) + is_input_iteration = pred[] compare(i, c4), direction=LT + ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) } + is_read_input_mb5 { + is_input_replica = pred[] call(), to_apply=is_input_replica + i = u32[] parameter(0) + c5 = u32[] constant(5) + is_input_iteration = pred[] compare(i, c5), direction=LT + ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) + } + )"; + return "HloModule " + name + "\n" + kCommonComputationsStr + "\n" + + more_computations; +} + +// Naive implementation of pipeline parallelism: +// - 4 devices +// - 4 microbatches +// - no circular repeat +// - no disabled collectives +// - no collective pipelining +// +// Every stage of the pipeline is a single linear layer. +XLA_TEST_F(CollectivePipelineParallelismTest, NaiveBFSMicrobatch4Replica4) { + constexpr char kMoreComputationsStr[] = R"( while_condition { tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0) i = u32[] get-tuple-element(tuple), index=4 @@ -172,36 +226,34 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { weights = f32[16,16] get-tuple-element(tuple), index=0 input = f32[4,16] get-tuple-element(tuple), index=1 output = f32[4,16] get-tuple-element(tuple), index=2 - tmp = f32[16] get-tuple-element(tuple), index=3 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3 i = u32[] get-tuple-element(tuple), index=4 - c1 = u32[] constant(1) c0 = u32[] constant(0) + c1 = u32[] constant(1) c4 = u32[] constant(4) - input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index - input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), - dynamic_slice_sizes={1,16} - input_slice_ = f32[16] reshape(input_slice) + // Read from buffers. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb4 - prev_stage_slice = f32[16] collective-permute(tmp), + // Shift data to the next stage in the pipeline. + prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + // Select compute argument from previous stage or from input and perform + // compute. read_input = pred[] call(), to_apply=is_input_replica - compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice) - - compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + compute_arg = f32[16] select(read_input, input_slice, prev_stage_slice) + compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1}, rhs_contracting_dims={0} - output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index - output_slice = f32[1,16] reshape(compute_out) - output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index, - c0) + // Update buffers. + output_ = call(output, compute_res, c1, i), to_apply=update_buffer_mb4 i_ = add(i, c1) ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple( - weights, input, output_, compute_out, i_) + weights, input, output_, compute_res, i_) } ENTRY main { @@ -210,11 +262,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { cf0 = f32[] constant(0) output = f32[4,16] broadcast(cf0), dimensions={} - tmp = f32[16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} c0 = u32[] constant(0) tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights, - input, output, tmp, c0) + input, output, prev_iteration_compute_res, c0) tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple), condition=while_condition, body=while_body @@ -227,8 +279,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); // This pipeline consists of 4 layers, each of which is a single linear layer. // We assign the weights to the replicas such that the layers scale the input @@ -260,7 +315,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { // Check pipeline output for last replica. // The combined effect of the pipeline is to scale the input data by 24.0. const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0; - Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( kMicrobatches, kInputSize, kExpectedFactor); EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], ErrorSpec{1e-5, 1e-5})); @@ -274,32 +329,8 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch4Replica4) { // - no collective pipelining // // Every stage of the pipeline is a single linear layer. -XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) { - const absl::string_view kModuleStr = R"( - HloModule test - - get_circ_buffer_index { - offset = u32[] parameter(0) - index = u32[] parameter(1) - size = u32[] parameter(2) - t0 = u32[] add(offset, index) - t1 = u32[] divide(t0, size) - t2 = u32[] multiply(t1, size) - ROOT t4 = u32[] subtract(t0, t2) - } - - is_input_replica { - replica_id = u32[] replica-id() - c0 = u32[] constant(0) - ROOT predicate = pred[] compare(replica_id, c0), direction=EQ - } - - is_output_replica { - replica_id = u32[] replica-id() - c1 = u32[] constant(1) - ROOT predicate = pred[] compare(replica_id, c1), direction=EQ - } - +XLA_TEST_F(CollectivePipelineParallelismTest, NaiveBFSMicrobatch5Replica4) { + constexpr char kMoreComputationsStr[] = R"( while_condition { tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) parameter(0) i = u32[] get-tuple-element(tuple), index=4 @@ -312,37 +343,35 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) { weights = f32[16,16] get-tuple-element(tuple), index=0 input = f32[5,16] get-tuple-element(tuple), index=1 output = f32[5,16] get-tuple-element(tuple), index=2 - tmp = f32[16] get-tuple-element(tuple), index=3 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3 i = u32[] get-tuple-element(tuple), index=4 + c0 = u32[] constant(0) c1 = u32[] constant(1) c2 = u32[] constant(2) - c0 = u32[] constant(0) c5 = u32[] constant(5) - input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index - input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), - dynamic_slice_sizes={1,16} - input_slice_ = f32[16] reshape(input_slice) + // Read from buffers. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5 - prev_stage_slice = f32[16] collective-permute(tmp), + // Shift data to the next stage in the pipeline. + prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + // Select compute argument from previous stage or from input and perform + // compute. read_input = pred[] call(), to_apply=is_input_replica - compute_in = f32[16] select(read_input, input_slice_, prev_stage_slice) - - compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + compute_arg = f32[16] select(read_input, input_slice, prev_stage_slice) + compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1}, rhs_contracting_dims={0} - output_index = u32[] call(c2, i, c5), to_apply=get_circ_buffer_index - output_slice = f32[1,16] reshape(compute_out) - output_ = f32[5,16] dynamic-update-slice(output, output_slice, output_index, - c0) + // Update buffers. + output_ = call(output, compute_res, c2, i), to_apply=update_buffer_mb5 i_ = add(i, c1) ROOT tuple1 = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) - tuple(weights, input, output_, compute_out, i_) + tuple(weights, input, output_, compute_res, i_) } ENTRY main { @@ -351,11 +380,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) { cf0 = f32[] constant(0) output = f32[5,16] broadcast(cf0), dimensions={} - tmp = f32[16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} c0 = u32[] constant(0) tuple = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) - tuple(weights, input, output, tmp, c0) + tuple(weights, input, output, prev_iteration_compute_res, c0) tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[16], u32[]) while(tuple), condition=while_condition, body=while_body @@ -368,8 +397,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) { HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); // This pipeline consists of 4 layers, each of which is a single linear layer. // We assign the weights to the replicas such that the layers scale the input @@ -415,40 +447,8 @@ XLA_TEST_F(CollectivePipelineParallelismTest, NaiveDFSMicrobatch5Replica4) { // // Every stage of the pipeline is a single linear layer. XLA_TEST_F(CollectivePipelineParallelismTest, - NaiveDFSMicrobatch4CircularRepeat2Replica4) { - const absl::string_view kModuleStr = R"( - HloModule test - - get_circ_buffer_index { - offset = u32[] parameter(0) - index = u32[] parameter(1) - size = u32[] parameter(2) - t0 = u32[] add(offset, index) - t1 = u32[] divide(t0, size) - t2 = u32[] multiply(t1, size) - ROOT t4 = u32[] subtract(t0, t2) - } - - is_input_replica { - replica_id = u32[] replica-id() - c0 = u32[] constant(0) - ROOT predicate = pred[] compare(replica_id, c0), direction=EQ - } - - is_output_replica { - replica_id = u32[] replica-id() - c3 = u32[] constant(3) - ROOT predicate = pred[] compare(replica_id, c3), direction=EQ - } - - is_read_input { - is_input_replica = pred[] call(), to_apply=is_input_replica - i = u32[] parameter(0) - c4 = u32[] constant(4) - is_input_iteration = pred[] compare(i, c4), direction=LT - ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) - } - + NaiveBFSMicrobatch4CircularRepeat2Replica4) { + constexpr char kMoreComputationsStr[] = R"( while_condition { tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) parameter(0) i = u32[] get-tuple-element(tuple), index=4 @@ -461,36 +461,35 @@ XLA_TEST_F(CollectivePipelineParallelismTest, weights = f32[16,16] get-tuple-element(tuple), index=0 input = f32[4,16] get-tuple-element(tuple), index=1 output = f32[4,16] get-tuple-element(tuple), index=2 - tmp = f32[16] get-tuple-element(tuple), index=3 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=3 i = u32[] get-tuple-element(tuple), index=4 - c1 = u32[] constant(1) c0 = u32[] constant(0) + c1 = u32[] constant(1) c4 = u32[] constant(4) - input_idx = u32[] call(c0, i, c4), to_apply=get_circ_buffer_index - input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), - dynamic_slice_sizes={1,16} - input_slice_ = f32[16] reshape(input_slice) + // Read from buffers. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb4 - prev_stage_slice = f32[16] collective-permute(tmp), + // Shift data to the next stage in the pipeline. + prev_stage_slice = f32[16] collective-permute(prev_iteration_compute_res), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} - is_read_input = pred[] call(i), to_apply=is_read_input - compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice) - - compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + // Select compute argument from previous stage or from input and perform + // compute. + is_read_input = pred[] call(i), to_apply=is_read_input_mb4 + compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice) + compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1}, rhs_contracting_dims={0} - output_index = u32[] call(c1, i, c4), to_apply=get_circ_buffer_index - output_slice = f32[1,16] reshape(compute_out) - output_ = f32[4,16] dynamic-update-slice(output, output_slice, output_index, - c0) + // Update buffers. + output_ = f32[4,16] call(output, compute_res, c1, i), + to_apply=update_buffer_mb4 i_ = add(i, c1) ROOT tuple1 = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) - tuple(weights, input, output_, compute_out, i_) + tuple(weights, input, output_, compute_res, i_) } ENTRY main { @@ -499,11 +498,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, cf0 = f32[] constant(0) output = f32[4,16] broadcast(cf0), dimensions={} - tmp = f32[16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} c0 = u32[] constant(0) tuple = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) tuple(weights, - input, output, tmp, c0) + input, output, prev_iteration_compute_res, c0) tuple_ = (f32[16,16], f32[4,16], f32[4,16], f32[16], u32[]) while(tuple), condition=while_condition, body=while_body @@ -516,8 +515,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); // This pipeline consists of a total of 8 layers (2 per replica), each of // which is a single linear layer. We assign the weights to the replicas such @@ -556,7 +558,7 @@ XLA_TEST_F(CollectivePipelineParallelismTest, ErrorSpec{1e-5, 1e-5})); } -// Naive implementation if pipeline parallelism: +// Naive implementation of pipeline parallelism: // - 4 devices // - 5 microbatches // - 2 circular repeat @@ -565,66 +567,146 @@ XLA_TEST_F(CollectivePipelineParallelismTest, // // Every stage of the pipeline is a single linear layer. XLA_TEST_F(CollectivePipelineParallelismTest, - NaiveDFSMicrobatch5CircularRepeat2Replica4) { - const absl::string_view kModuleStr = R"( - HloModule test - - get_circ_buffer_index { - offset = u32[] parameter(0) - index = u32[] parameter(1) - size = u32[] parameter(2) - t0 = u32[] add(offset, index) - t1 = u32[] divide(t0, size) - t2 = u32[] multiply(t1, size) - ROOT t4 = u32[] subtract(t0, t2) + NaiveBFSMicrobatch5CircularRepeat2Replica4) { + constexpr char kMoreComputationsStr[] = R"( + while_condition { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + parameter(0) + i = u32[] get-tuple-element(tuple), index=5 + n = u32[] constant(13) + ROOT predicate = pred[] compare(i, n), direction=LT } - read_buffer { - buffer = f32[5,16] parameter(0) - offset = u32[] parameter(1) - index = u32[] parameter(2) - c0 = u32[] constant(0) - c5 = u32[] constant(5) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c5) - slice = f32[1,16] dynamic-slice(buffer, index__, c0), - dynamic_slice_sizes={1,16} - ROOT slice_ = f32[16] reshape(slice) - } + while_body { + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + parameter(0) + weights = f32[16,16] get-tuple-element(tuple), index=0 + input = f32[5,16] get-tuple-element(tuple), index=1 + output = f32[5,16] get-tuple-element(tuple), index=2 + buffer = f32[5,16] get-tuple-element(tuple), index=3 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4 + i = u32[] get-tuple-element(tuple), index=5 - update_buffer { - buffer = f32[5,16] parameter(0) - update = f32[16] parameter(1) - offset = u32[] parameter(2) - index = u32[] parameter(3) c0 = u32[] constant(0) + c1 = u32[] constant(1) + c2 = u32[] constant(2) + c3 = u32[] constant(3) + c4 = u32[] constant(4) c5 = u32[] constant(5) - index_ = u32[] add(index, offset) - index__ = u32[] remainder(index_, c5) - update_ = f32[1,16] reshape(update) - ROOT buffer_ = f32[5,16] dynamic-update-slice(buffer, update_, index__, c0) + + // Read from buffers. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5 + buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5 + + // Shift data to the next stage in the pipeline. + // Directly depends on the updated buffer of the previous iteration and, + // therefore, depends on the previous iteration's compute. + is_output_replica = pred[] call(), to_apply=is_output_replica + next_stage_slice = select(is_output_replica, buffer_slice, + prev_iteration_compute_res) + prev_stage_slice = f32[16] collective-permute(next_stage_slice), + source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} + + // Select compute argument from previous stage or from input and perform + // compute. + is_read_input = pred[] call(i), to_apply=is_read_input_mb5 + compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice) + compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1}, + rhs_contracting_dims={0} + + // Update buffers. + output_ = f32[5,16] call(output, compute_res, c2, i), + to_apply=update_buffer_mb5 + buffer_ = f32[5,16] call(buffer, compute_res, c0, i), + to_apply=update_buffer_mb5 + + i_ = add(i, c1) + + ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output_, buffer_, compute_res, i_) } - is_input_replica { - replica_id = u32[] replica-id() + ENTRY main { + weights = f32[16,16] parameter(0) + input = f32[5,16] parameter(1) + + cf0 = f32[] constant(0) + output = f32[5,16] broadcast(cf0), dimensions={} + buffer = f32[5,16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} c0 = u32[] constant(0) - ROOT predicate = pred[] compare(replica_id, c0), direction=EQ - } - is_output_replica { - replica_id = u32[] replica-id() - c3 = u32[] constant(3) - ROOT predicate = pred[] compare(replica_id, c3), direction=EQ - } + // Iterate through pipeline stages. + tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + tuple(weights, input, output, buffer, prev_iteration_compute_res, c0) + tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) + while(tuple), condition=while_condition, body=while_body - is_read_input { - is_input_replica = pred[] call(), to_apply=is_input_replica - i = u32[] parameter(0) - c5 = u32[] constant(5) - is_input_iteration = pred[] compare(i, c5), direction=LT - ROOT is_read_input = pred[] and(is_input_replica, is_input_iteration) + ROOT output_ = f32[5,16] get-tuple-element(tuple_), index=2 } + )"; + + const int64_t kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas) + + HloModuleConfig config = + GetModuleConfigForTest(/*replica_count=*/kNumReplicas); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); + + // This pipeline consists of a total of 8 layers (2 per replica), each of + // which is a single linear layer. We assign the weights to the replicas such + // that the layers scale the input data by 1.0, 2.0, 3.0 and 4.0 in the first + // and second cycle. The combined effect is to scale the input data by 576.0 + // (24.0 * 24.0). + const int64_t kInputSize = 16; + Literal weights_r0 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 1.0); + Literal weights_r1 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 2.0); + Literal weights_r2 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 3.0); + Literal weights_r3 = LiteralUtil::MakeScalarMatrixR2(kInputSize, 4.0); + + // Only the first replica holds the input to the pipeline in this naive + // implementation. The remaining replicas get zero/dummy input. + const int64_t kMicrobatches = 5; + Literal real_input = + LiteralUtil::CreateFingerprintMatixR2(kMicrobatches, kInputSize); + Literal fake_input = + LiteralUtil::CreateFull({kMicrobatches, kInputSize}, 0.0); + + // Check pipeline output for last replica. + // The combined effect of the pipeline is to scale the input data by 576.0 + // (24.0 * 24.0). + const float kExpectedFactor = 1.0 * 2.0 * 3.0 * 4.0 * 1.0 * 2.0 * 3.0 * 4.0; + Literal expected_output = LiteralUtil::CreateFingerprintMatixR2( + kMicrobatches, kInputSize, /*scale=*/kExpectedFactor); + std::vector> args = {{&weights_r0, &real_input}, + {&weights_r1, &fake_input}, + {&weights_r2, &fake_input}, + {&weights_r3, &fake_input}}; + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), args, kNumReplicas, + /*run_hlo_passes=*/true)); + EXPECT_TRUE(LiteralTestUtil::NearOrEqual(expected_output, results[3], + ErrorSpec{1e-5, 1e-5})); +} +// Naive implementation of pipeline parallelism, which breaks the direct data +// dependency between the collective permute and the previous iteration's +// compute. +// - 4 devices +// - 4 microbatches +// - 2 circular repeat +// - no disabled collectives +// - no collective pipelining +// +// Every stage of the pipeline is a single linear layer. +XLA_TEST_F(CollectivePipelineParallelismTest, + NaiveWoDirectBufferDependencyBFSMicrobatch5CircularRepeat2Replica4) { + constexpr char kMoreComputationsStr[] = R"( while_condition { tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) parameter(0) @@ -640,43 +722,46 @@ XLA_TEST_F(CollectivePipelineParallelismTest, input = f32[5,16] get-tuple-element(tuple), index=1 output = f32[5,16] get-tuple-element(tuple), index=2 buffer = f32[5,16] get-tuple-element(tuple), index=3 - prev_iteration_compute_out = f32[16] get-tuple-element(tuple), index=4 + prev_iteration_compute_res = f32[16] get-tuple-element(tuple), index=4 i = u32[] get-tuple-element(tuple), index=5 c0 = u32[] constant(0) c1 = u32[] constant(1) c2 = u32[] constant(2) c3 = u32[] constant(3) + c4 = u32[] constant(4) c5 = u32[] constant(5) - input_idx = u32[] call(c0, i, c5), to_apply=get_circ_buffer_index - input_slice = f32[1,16] dynamic-slice(input, input_idx, c0), - dynamic_slice_sizes={1,16} - input_slice_ = f32[16] reshape(input_slice) - - buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer + // Read from buffers before they are updated. + input_slice = f32[16] call(input, c0, i), to_apply=read_buffer_mb5 + buffer_slice = f32[16] call(buffer, c3, i), to_apply=read_buffer_mb5 + // Shift data to the next stage in the pipeline. + // Depends on the non-updated buffer of the previous iteration and, + // therefore, does not depend on the previous iteration's compute. is_output_replica = pred[] call(), to_apply=is_output_replica next_stage_slice = select(is_output_replica, buffer_slice, - prev_iteration_compute_out) - + prev_iteration_compute_res) prev_stage_slice = f32[16] collective-permute(next_stage_slice), source_target_pairs={{0,1}, {1,2}, {2,3}, {3,0}} - is_read_input = pred[] call(i), to_apply=is_read_input - compute_in = f32[16] select(is_read_input, input_slice_, prev_stage_slice) - - compute_out = f32[16] dot(weights, compute_in), lhs_contracting_dims={1}, + // Select compute argument from previous stage or from input and perform + // compute. + is_read_input = pred[] call(i), to_apply=is_read_input_mb5 + compute_arg = f32[16] select(is_read_input, input_slice, prev_stage_slice) + compute_res = f32[16] dot(weights, compute_arg), lhs_contracting_dims={1}, rhs_contracting_dims={0} - output_ = f32[5,16] call(output, compute_out, c2, i), to_apply=update_buffer - - buffer_ = f32[5,16] call(buffer, compute_out, c0, i), to_apply=update_buffer + // Update buffers. + buffer_ = f32[5,16] call(buffer, prev_iteration_compute_res, c4, i), + to_apply=update_buffer_mb5 + output_ = f32[5,16] call(output, compute_res, c2, i), + to_apply=update_buffer_mb5 i_ = add(i, c1) ROOT tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) - tuple(weights, input, output_, buffer_, compute_out, i_) + tuple(weights, input, output_, buffer_, compute_res, i_) } ENTRY main { @@ -686,11 +771,12 @@ XLA_TEST_F(CollectivePipelineParallelismTest, cf0 = f32[] constant(0) output = f32[5,16] broadcast(cf0), dimensions={} buffer = f32[5,16] broadcast(cf0), dimensions={} - prev_iteration_compute_out = f32[16] broadcast(cf0), dimensions={} + prev_iteration_compute_res = f32[16] broadcast(cf0), dimensions={} c0 = u32[] constant(0) + // Iterate through pipeline stages. tuple = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) - tuple(weights, input, output, buffer, prev_iteration_compute_out, c0) + tuple(weights, input, output, buffer, prev_iteration_compute_res, c0) tuple_ = (f32[16,16], f32[5,16], f32[5,16], f32[5,16], f32[16], u32[]) while(tuple), condition=while_condition, body=while_body @@ -703,8 +789,11 @@ XLA_TEST_F(CollectivePipelineParallelismTest, HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN(auto module, - ParseAndReturnVerifiedModule(kModuleStr, config)); + TF_ASSERT_OK_AND_ASSIGN( + auto module, + ParseAndReturnVerifiedModule(GetModuleStrWithCommonComputations( + /*name=*/"test", kMoreComputationsStr), + config)); // This pipeline consists of a total of 8 layers (2 per replica), each of // which is a single linear layer. We assign the weights to the replicas such diff --git a/third_party/xla/xla/tests/compute_constant_test.cc b/third_party/xla/xla/tests/compute_constant_test.cc index d991b580a26bbd..8742656f17ff7a 100644 --- a/third_party/xla/xla/tests/compute_constant_test.cc +++ b/third_party/xla/xla/tests/compute_constant_test.cc @@ -31,8 +31,8 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tests/constants_test.cc b/third_party/xla/xla/tests/constants_test.cc index a926d24819fd68..26407b42790526 100644 --- a/third_party/xla/xla/tests/constants_test.cc +++ b/third_party/xla/xla/tests/constants_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/ml_dtypes.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/convert_test.cc b/third_party/xla/xla/tests/convert_test.cc index 13ca51a4025ebb..4db6394d1503ce 100644 --- a/third_party/xla/xla/tests/convert_test.cc +++ b/third_party/xla/xla/tests/convert_test.cc @@ -14,20 +14,19 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include +#include #include #include "absl/algorithm/container.h" #include "absl/base/casts.h" #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" -#include "xla/primitive_util.h" #include "xla/shape_util.h" -#include "xla/stream_executor/stream_executor.h" #include "xla/tests/client_library_test_base.h" -#include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/types.h" #include "xla/xla_data.pb.h" @@ -672,6 +671,59 @@ XLA_TEST_F(ConvertTest, ConvertBF16F32) { } } +XLA_TEST_F(ConvertTest, ConvertF32BF16) { + XlaBuilder builder(TestName()); + + std::vector floats(100); + std::minstd_rand0 generator; + for (int i = 0; i < floats.size(); ++i) { + floats[i] = generator(); + + // Ensure the first 10 cases has rounding. + if (i < 10) { + auto val = absl::bit_cast(floats[i]); + val |= 1 << 15; + floats[i] = absl::bit_cast(val); + } + } + // Test NaN and -Nan. + floats.push_back(std::numeric_limits::quiet_NaN()); + floats.push_back(-std::numeric_limits::quiet_NaN()); + + std::vector expected(floats.size()); + for (int i = 0; i < expected.size(); ++i) { + expected[i] = static_cast(floats[i]); + } + + xla::XlaOp lit_f32 = ConstantR1(&builder, floats); + xla::XlaOp lit_bf16 = ConvertElementType(lit_f32, BF16); + BitcastConvertType(lit_bf16, U16); + + TF_ASSERT_OK_AND_ASSIGN(const auto results, ExecuteAndTransfer(&builder, {})); + for (int i = 0; i < expected.size(); ++i) { + const auto result = results.Get({i}); + const auto correct = absl::bit_cast(expected[i]); + if (floats[i] != 0.0f && floats[i] < std::numeric_limits::min()) { + // Subnormals may not be preserved, zero will do. + const bfloat16 same_signed_zero = + bfloat16(std::signbit(floats[i]) ? -0.0f : 0.0f); + if (result != correct) { + EXPECT_EQ(result, absl::bit_cast(same_signed_zero)); + } + } else if (std::isnan(floats[i])) { + // NaNs may not be preserved, any NaN will do. + ASSERT_TRUE(std::isnan(absl::bit_cast(correct))); + EXPECT_TRUE(std::isnan(absl::bit_cast(result))); + if (client_->platform()->Name() == "Host") { + // The sign bits must match. + EXPECT_EQ(result >> 15, correct >> 15); + } + } else { + EXPECT_EQ(result, correct); + } + } +} + XLA_TEST_F(ConvertTest, ConvertF16F8e5m2Roundtrip) { // Convert from FP16 to FP8, then back to FP16 XlaBuilder builder(TestName()); diff --git a/third_party/xla/xla/tests/copy_test.cc b/third_party/xla/xla/tests/copy_test.cc index 45d94ab0333838..91b7fa2a1473c8 100644 --- a/third_party/xla/xla/tests/copy_test.cc +++ b/third_party/xla/xla/tests/copy_test.cc @@ -13,22 +13,31 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include +#include -#include "xla/array2d.h" +#include +#include "absl/types/span.h" +#include "xla/array3d.h" +#include "xla/array4d.h" #include "xla/client/xla_builder.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" +#include "xla/literal_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/protobuf.h" #include "tsl/platform/test.h" namespace xla { @@ -50,6 +59,25 @@ class CopyOpTest : public HloTestBase { EXPECT_TRUE(LiteralTestUtil::Equal(literal, result)); } + // TODO(vsytch): Remove special handling for dynamic shapes once *all* of XLA + // supports those as module inputs/outputs. + void TestDynamicCopyOp(const Literal& literal, const Shape& bounded_shape) { + Literal dynamic_literal = literal.ToBoundedDynamic(bounded_shape); + auto builder = HloComputation::Builder(TestName()); + auto parameter = builder.AddInstruction( + HloInstruction::CreateParameter(0, dynamic_literal.shape(), "param")); + builder.AddInstruction(HloInstruction::CreateUnary( + parameter->shape(), HloOpcode::kCopy, parameter)); + auto computation = builder.Build(); + auto module = CreateNewVerifiedModule(); + module->AddEntryComputation(std::move(computation)); + + std::vector args = {&dynamic_literal}; + Literal result = ExecuteAndTransfer(std::move(module), args); + Literal dynamic_result = result.ToBoundedDynamic(bounded_shape); + EXPECT_TRUE(LiteralTestUtil::Equal(dynamic_literal, dynamic_result)); + } + void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3); void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4, absl::Span permutation); @@ -67,6 +95,59 @@ XLA_TEST_F(CopyOpTest, CopyR1S3U32) { TestCopyOp(LiteralUtil::CreateR1({1, 2, 3})); } +XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic0) { + // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. + if (backend().platform()->Name() == "Host") { + GTEST_SKIP(); + } + Shape bounded_shape = + ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true}); + TestDynamicCopyOp(LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(PrimitiveType::F32, {0}), 0, 1) + .value(), + bounded_shape); +} + +XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic106632) { + // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. + if (backend().platform()->Name() == "Host") { + GTEST_SKIP(); + } + Shape bounded_shape = + ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true}); + TestDynamicCopyOp( + LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(PrimitiveType::F32, {106632}), 0, 1) + .value(), + bounded_shape); +} + +XLA_TEST_F(CopyOpTest, CopyDynamicR1S1310720U32Dynamic1310720) { + // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. + if (backend().platform()->Name() == "Host") { + GTEST_SKIP(); + } + Shape bounded_shape = + ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}, {true}); + TestDynamicCopyOp( + LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(PrimitiveType::F32, {1310720}), 0, 1) + .value(), + bounded_shape); +} + +XLA_TEST_F(CopyOpTest, CopyDynamicR1S512U32Dynamic64) { + // TODO(vsytch): CPU emitter doesn't handle dynamic shapes. + if (backend().platform()->Name() == "Host") { + GTEST_SKIP(); + } + Shape bounded_shape = ShapeUtil::MakeShape(PrimitiveType::F32, {512}, {true}); + TestDynamicCopyOp(LiteralUtil::CreateRandomLiteral( + ShapeUtil::MakeShape(PrimitiveType::F32, {64}), 0, 1) + .value(), + bounded_shape); +} + XLA_TEST_F(CopyOpTest, CopyR3F32_2x2x3) { TestCopyOp(LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); diff --git a/third_party/xla/xla/tests/custom_call_test.cc b/third_party/xla/xla/tests/custom_call_test.cc index 2d0f370c664a4b..2ada7f0b22152b 100644 --- a/third_party/xla/xla/tests/custom_call_test.cc +++ b/third_party/xla/xla/tests/custom_call_test.cc @@ -26,17 +26,19 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "absl/synchronization/blocking_counter.h" #include "absl/types/span.h" -#include "xla/client/lib/constants.h" +#include "xla/array2d.h" +#include "xla/array3d.h" #include "xla/client/xla_builder.h" #include "xla/ffi/ffi.h" #include "xla/ffi/ffi_api.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" #include "xla/literal.h" @@ -44,6 +46,7 @@ limitations under the License. #include "xla/primitive_util.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" +#include "xla/service/service.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/tests/client_library_test_base.h" @@ -55,6 +58,9 @@ limitations under the License. #include "tsl/platform/statusor.h" #include "tsl/platform/test.h" +#define EIGEN_USE_THREADS +#include "unsupported/Eigen/CXX11/Tensor" + namespace { void R0F32Add2(float* out, float** in) { ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); @@ -862,6 +868,40 @@ XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$HandleTupleDifferentRanks", "Host", kHandleTupleDifferentRanks); +static absl::Status CustomCallWithIntraOpThreadPool( + ffi::Result, + const Eigen::ThreadPoolDevice* intra_op_thread_pool) { + // We use two blocking counters to ensure that the task is actually running + // inside a thread pool. + absl::BlockingCounter counter0(1); + absl::BlockingCounter counter1(1); + + intra_op_thread_pool->getPool()->Schedule([&]() { + counter0.Wait(); + counter1.DecrementCount(); + }); + + // Unblock submitted task. + counter0.DecrementCount(); + + // TODO(b/356389210): It is unsafe to wait for the completion of a task + // submitted into an intra-op thread pool as we might be running on a thread + // inside the same thread pool, and this can lead to deadlocks. Custom calls + // should return `AsyncValue` to signal completion of all submitted tasks. + counter1.Wait(); + + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kIntraOpThreadPool, CustomCallWithIntraOpThreadPool, + ffi::Ffi::Bind() + .Ret() // unused out buffer + .Ctx()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), + "__xla_test$$intra_op_thread_pool", "Host", + kIntraOpThreadPool); + } // namespace // __xla_test$$ConcatVectors @@ -1610,5 +1650,113 @@ XLA_TEST_F(FfiCustomCallTest, FfiNestedTupleInputAndOutput) { EXPECT_EQ(result, expected); } +XLA_TEST_F(FfiCustomCallTest, IntraOpThreadPool) { + auto module = CreateNewVerifiedModule(); + auto builder = HloComputation::Builder(TestName()); + + builder.AddInstruction(HloInstruction::CreateCustomCall( + r0f32_, {}, "__xla_test$$intra_op_thread_pool", "", + /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI)); + + module->AddEntryComputation(builder.Build()); + + auto status = Execute(std::move(module), {}).status(); + EXPECT_EQ(status, absl::OkStatus()); +} + +//===----------------------------------------------------------------------===// +// Stateful XLA:FFI handler +//===----------------------------------------------------------------------===// + +struct SomeState { + explicit SomeState(float value) : value(value) {} + float value = 0; +}; + +int instantiate_called_counter = 0; + +// Every time custom call HLO operation is instantiated as a CPU runtime Thunk, +// XLA calls instantiate callback to create a new instance of the handler state, +// that will be passed to all other FFI handler calls. +static absl::StatusOr> InstantiateState() { + ++instantiate_called_counter; + return std::make_unique(42.f); +} + +// At run time we can access the state created by the instantiate callback. +static absl::Status IncrementState(R0F32ResultBuffer out, SomeState* state) { + state->value += 1.f; + auto out_data = out->typed_data(); + *out_data = state->value; + return absl::OkStatus(); +} + +XLA_FFI_DEFINE_HANDLER(kInstantiateState, InstantiateState, + ffi::Ffi::BindInstantiate()); + +XLA_FFI_DEFINE_HANDLER( + kIncrementState, IncrementState, + ffi::Ffi::Bind().Ret().Ctx>()); + +XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "__xla_test$$ffi_execution_state", + "Host", + { + /*instantiate=*/kInstantiateState, + /*prepare=*/nullptr, + /*initialize=*/nullptr, + /*execute=*/kIncrementState, + }); + +// This test doesn't care about execution results, its intent is just to test if +// instantiate function was called. +TEST_F(CustomCallTest, FfiExecutionStateInstantiate) { + const char* const kModuleStr = R"( + HloModule m + ENTRY test { + ROOT result = f32[] custom-call(), custom_call_target= + "__xla_test$$ffi_execution_state", api_version=API_VERSION_TYPED_FFI + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + // Execute the module, but don't verify the results. + instantiate_called_counter = 0; + auto result = Execute(std::move(module), {}); + + // Check that instantiate callback was called. + EXPECT_EQ(instantiate_called_counter, 1); +} + +TEST_F(CustomCallTest, FfiExecutionStateExecute) { + // Execution state is only partially implemented at the moment. + GTEST_SKIP() << "Not implemented yet."; + + // TODO(abanas): Actually, this HLO probably creates two custom call thunks, + // each one is called once. If yes then fix it, cause the intent is to call + // the same custom call twice. + const char* const kModuleStr = R"( + HloModule m + ENTRY test { + first = f32[] custom-call(), custom_call_target= + "__xla_test$$ffi_execution_state", api_version=API_VERSION_TYPED_FFI + second = f32[] custom-call(), custom_call_target= + "__xla_test$$ffi_execution_state", api_version=API_VERSION_TYPED_FFI + ROOT result = (f32[], f32[]) tuple(first, second) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + Literal expected0 = + LiteralUtil::CreateR0(43.f); // Incremented once. + Literal expected1 = + LiteralUtil::CreateR0(44.f); // Incremented twice. + Literal expected = LiteralUtil::MakeTuple({&expected0, &expected1}); + + TF_ASSERT_OK_AND_ASSIGN(auto result, Execute(std::move(module), {})); + EXPECT_EQ(result, expected); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/BUILD b/third_party/xla/xla/tests/exhaustive/BUILD index 54c3820cd8648e..dcd74bcc34750c 100644 --- a/third_party/xla/xla/tests/exhaustive/BUILD +++ b/third_party/xla/xla/tests/exhaustive/BUILD @@ -29,6 +29,7 @@ cc_library( deps = [ "//xla:bit_cast", "//xla:executable_run_options", + "//xla:fp_util", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -39,6 +40,8 @@ cc_library( "//xla/client:xla_computation", "//xla/service:shaped_buffer", "//xla/tests:client_library_test_base", + "//xla/tsl/lib/core:status_test_util", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -48,7 +51,10 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", ], ) @@ -83,9 +89,11 @@ xla_test( "//xla/client:xla_builder", "//xla/client/lib:constants", "//xla/client/lib:math", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], ) @@ -116,9 +124,11 @@ xla_test( "//xla/client/lib:constants", "//xla/client/lib:math", "//xla/tests:xla_internal_test_main", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], ) @@ -147,17 +157,26 @@ xla_test( "//xla/client:xla_builder", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], ) +filegroup( + name = "exhaustive_binary_16_bit_test_srcs", + srcs = [ + "exhaustive_binary_16_bit_test.cc", + ], +) + xla_test( name = "exhaustive_binary_16_bit_test", srcs = [ - "exhaustive_binary_16_bit_test.cc", "exhaustive_test_main.cc", + ":exhaustive_binary_16_bit_test_srcs", ], backends = [ "gpu", @@ -173,12 +192,15 @@ xla_test( deps = [ ":exhaustive_op_test_utils", "//xla:literal", + "//xla:types", "//xla/client:xla_builder", "//xla/tests:xla_internal_test_main", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", "@eigen_archive//:eigen3", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], ) @@ -205,9 +227,11 @@ xla_test( "//xla:literal", "//xla/client:xla_builder", "//xla/tests:xla_internal_test_main", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc index f9da77dcd144c4..3f61111c974c84 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_16_bit_test.cc @@ -13,11 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include +#include #include #include -#include +#include +#include +#include #include +#include #include #include "absl/log/check.h" @@ -28,6 +33,7 @@ limitations under the License. #include "xla/literal.h" #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" #include "xla/tests/test_macros.h" +#include "xla/types.h" #include "tsl/platform/test.h" #ifdef __FAST_MATH__ @@ -42,7 +48,7 @@ namespace { // including float16 and bfloat. // // Test parameter is a pair of (begin, end) for range under test. -template +template class Exhaustive16BitBinaryTest : public ExhaustiveBinaryTest, public ::testing::WithParamInterface> { @@ -54,9 +60,13 @@ class Exhaustive16BitBinaryTest } // Given a range of uint64_t representation, uses bits 0..15 and bits 16..31 - // for the values of src0 and src1 for a 16 bit binary operation being tested, - // and generates the cartesian product of the two sets as the two inputs for - // the test. + // for the values of src0 and src1 (see below for ordering) for the 16 bit + // binary operation being tested, and generates the cartesian product of the + // two sets as the two inputs for the test. + // + // If `kLeftToRightPacking == true`, bit 31..16 become src0 and 15..0 becomes + // src1. If `kLeftToRightPacking == false`, then bits 31..16 become src1 + // and 15..0 becomes src0. void FillInput(std::array* input_literals) override { int64_t input_size = GetInputSize(); CHECK_EQ(input_size, (*input_literals)[0].element_count()); @@ -64,17 +74,53 @@ class Exhaustive16BitBinaryTest int64_t begin, end; std::tie(begin, end) = GetParam(); - VLOG(2) << "Checking range [" << begin << ", " << end << "]"; + + uint16_t left_begin, left_end, right_begin, right_end; + if constexpr (kLeftToRightPacking) { + left_begin = std::bit_cast(static_cast(begin >> 16)); + left_end = std::bit_cast(static_cast(end >> 16)); + right_begin = std::bit_cast(static_cast(begin)); + right_end = std::bit_cast(static_cast(end)); + } else { + left_begin = std::bit_cast(static_cast(begin)); + left_end = std::bit_cast(static_cast(end)); + right_begin = std::bit_cast(static_cast(begin >> 16)); + right_end = std::bit_cast(static_cast(end >> 16)); + } + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + LOG(INFO) << "\tfrom=(" << left_begin << ", " << right_begin << "); hex=(" + << std::hex << left_begin << ", " << right_begin << "); float=(" + << *reinterpret_cast(&left_begin) << ", " + << *reinterpret_cast(&right_begin) + << ") (inclusive)"; + LOG(INFO) << "\tto=(" << left_end << ", " << right_end << "); hex=(" + << std::hex << left_end << ", " << right_end << "); float=(" + << *reinterpret_cast(&left_end) << ", " + << *reinterpret_cast(&right_end) + << ") (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } absl::Span input_arr_0 = (*input_literals)[0].data(); absl::Span input_arr_1 = (*input_literals)[1].data(); for (int64_t i = 0; i < input_size; i++) { uint32_t input_val = i + begin; - // Convert the lower 16 bits to the NativeT and replaced known incorrect - // input values with 0. - input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); - input_arr_1[i] = - ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0); + // Convert the packed bits to a pair of NativeT and replace known + // incorrect input values with 0. + // + // In either case, we only use 32 bits out of the 64 bits possible. + if constexpr (kLeftToRightPacking) { + // Left is stored at higher 16 bits. + input_arr_0[i] = + ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0); + input_arr_1[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); + } else { + // Left is stored at lower 16 bits. + input_arr_0[i] = ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); + input_arr_1[i] = + ConvertAndReplaceKnownIncorrectValueWith(input_val >> 16, 0); + } } } @@ -105,51 +151,557 @@ using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; BINARY_TEST_F16(test_name, __VA_ARGS__) \ BINARY_TEST_BF16(test_name, __VA_ARGS__) +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double AddCpuTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = static_cast(left) + static_cast(right); + + // Hardware flushes subnormal outputs to 0. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + BINARY_TEST_16BIT(Add, { - auto host_add = [](float x, float y) { return x + y; }; - Run(AddEmptyBroadcastDimension(Add), host_add); + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if ((IsCpu(platform_) || IsTpu(platform_))) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(AddCpuTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + Run( + AddEmptyBroadcastDimension(Add), [](float x, float y) { return x + y; }, + error_spec_gen); }) +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double SubCpuTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = static_cast(left) - static_cast(right); + + // Hardware flushes subnormal outputs to 0. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + BINARY_TEST_16BIT(Sub, { - auto host_sub = [](float x, float y) { return x - y; }; - Run(AddEmptyBroadcastDimension(Sub), host_sub); + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_) || IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(SubCpuTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + Run( + AddEmptyBroadcastDimension(Sub), [](float x, float y) { return x - y; }, + error_spec_gen); }) -// TODO(bixia): Mul fails with bfloat16 on CPU. -BINARY_TEST_16BIT(DISABLED_ON_CPU(Mul), { - auto host_mul = [](float x, float y) { return x * y; }; - Run(AddEmptyBroadcastDimension(Mul), host_mul); +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double MulCpuTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = static_cast(left) * static_cast(right); + + // CPU BF16 and TPU (all types) flush subnormals to 0. + auto output_is_subnormal = IsSubnormal(output); + if (output_is_subnormal) { + return std::numeric_limits::min(); + } + + return 0.0; +} + +bool MulCpuTpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { + // For CPU and TPU BF16, multiplying a subnormal by infinity will lead to + // calculating 0 multiplied by infinity due to subnormal flushing, which is + // defined to be NaN. However, the calculation in higher precision does not + // flush the subnormal value to 0, leading to a result of infinity. + if ((IsSubnormal(left) && std::isinf(right)) || + (std::isinf(left) && IsSubnormal(right))) { + return true; + } + return false; +} + +BINARY_TEST_16BIT(Mul, { + ErrorSpecGen error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_) || IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(MulCpuTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .skip_comparison( + MulCpuTpuBf16Skip(static_cast(left), + static_cast(right))) + .build(); + }; + } + } + + Run( + AddEmptyBroadcastDimension(Mul), [](float x, float y) { return x * y; }, + error_spec_gen); }) -// TODO(bixia): Div fails with bfloat16 on CPU. -BINARY_TEST_16BIT(DISABLED_ON_CPU(Div), { - auto host_div = [](float x, float y) { return x / y; }; - Run(AddEmptyBroadcastDimension(Div), host_div); +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double DivCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = static_cast(left) / static_cast(right); + + // Subnormals are flushed to 0 so we add a absolute error margin that is + // larger than any subnormal. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + +double DivTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float reciprocal = 1.0f / static_cast(right); + xla::bfloat16 output = left / right; + float output_as_float = static_cast(left) / static_cast(right); + + // If we calculate NaN, we don't need to adjust tolerances. + if (std::isnan(output_as_float)) { + return 0.0; + } + + // TPUs perform `left * (1 / right)`, where `left` and `1 / right` are + // flushed to `0` if they are subnormal. Also applies to if reciprocal is min + // normal. + if (IsSubnormal(left) || IsSubnormalOrMinNormal(reciprocal)) { + // Subnormals can have a larger value in BF16 than float due to rounding to + // the nearest BF16 value during conversion while having less representation + // bits. For normals, the float value is usually always bigger due to + // greater precision. + return std::max(std::abs(output), std::abs(output_as_float)); + } + + // For subnormals, we need to set absolute error to the smallest positive + // representable value due to hardware implementations that truncate + // subnormals to zero. + if (IsSubnormalOrMinNormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + +bool DivTpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { + float reciprocal = 1.0f / right; + + // TPU calculates `left * (1 / right)` and flushed `(1 / right)` to `0` when + // it is subnormal or min normal. It also follows the IEEE multiplication spec + // that inf * 0 is NaN. However, IEEE division of infinity by a subnormal is + // infinity, so we must skip comparison. + if (std::isinf(left) && IsSubnormalOrMinNormal(reciprocal)) { + return true; + } + + return false; +} + +BINARY_TEST_16BIT(Div, { + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(DivCpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + if (IsGpu(platform_) && std::is_same_v) { + error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().distance_err(1).strict_signed_zeros().build(); + }; + } + + if (IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(DivTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .skip_comparison(DivTpuBf16Skip(static_cast(left), + static_cast(right))) + .build(); + }; + } else if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(std::numeric_limits::min()) + .strict_signed_zeros() + .build(); + }; + } + } + if (IsPreV5Tpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(DivTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .rel_err(std::numeric_limits::epsilon()) + .strict_signed_zeros() + .skip_comparison(DivTpuBf16Skip(static_cast(left), + static_cast(right))) + .build(); + }; + } + } + + Run( + AddEmptyBroadcastDimension(Div), [](float x, float y) { return x / y; }, + error_spec_gen); }) +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double MaxMinCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + // Subnormals are treated as 0 and max returns the first if all are + // 0-equivalent. + if (IsSubnormal(left) && (right == 0.0 || IsSubnormal(right))) { + return std::abs(left); + } + return 0.0; +} + BINARY_TEST_16BIT(Max, { - Run(AddEmptyBroadcastDimension(Max), ReferenceMax); + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(MaxMinCpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + if (IsGpu(platform_) || IsTpu(platform_)) { + error_spec_gen = +[](NativeT, NativeT) { + // A100 and H100 return -0 for max(-0,0). + // + // TPUs return -0 for max(0,-0) and 0 for max(-0,0). + return ErrorSpec::Builder().strict_signed_zeros(false).build(); + }; + } + + Run(AddEmptyBroadcastDimension(Max), ReferenceMax, error_spec_gen); }) BINARY_TEST_16BIT(Min, { - Run(AddEmptyBroadcastDimension(Min), ReferenceMin); + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(MaxMinCpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } + } + + if (IsGpu(platform_) || IsTpu(platform_)) { + error_spec_gen = +[](NativeT, NativeT) { + // A100 and H100 return 0 for min(0,-0). + // + // TPUs return 0 for min(-0,0) and -0 for min(0,-0). + return ErrorSpec::Builder().strict_signed_zeros(false).build(); + }; + } + + Run(AddEmptyBroadcastDimension(Min), ReferenceMin, error_spec_gen); }) -// TODO(bixia): Pow fails with bfloat16 on CPU. -BINARY_TEST_16BIT(DISABLED_ON_GPU(DISABLED_ON_CPU(Pow)), { - // See b/162664705. - known_incorrect_fn_ = [](int64_t val) { - Eigen::bfloat16 f; - uint16_t val_16 = val; - memcpy(&f, &val_16, 2); - return std::isnan(f); +template +bool PowCpuGpuF16Skip(NativeT left, NativeT right) { + // Hardware always returns 1 if right is 0, no matter if left is NaN. + if (std::isnan(left) && right == 0.0f) { + return true; + } + // Hardware always returns 1 if left is 1, no matter if right is NaN. + if (left == 1.0f && std::isnan(right)) { + return true; + } + return false; +} + +double PowCpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = std::pow(static_cast(left), static_cast(right)); + + // Output is flushed to 0 if subnormal. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + // TODO(b/359325328): pow computation for subnormal bases is different from + // std::pow. + // + // If the base is subnormal, the output computation selects a different base. + // The minimum value ever chosen is slightly greater than the 1e-91 used + // below. We return an absolute error from this value to the "real" output. + // + // Because the exponent (right) can be any floating point value, this allows + // an arbitrary absolute error for subnormal values. + if (IsSubnormal(left)) { + xla::bfloat16 output_as_bf16 = static_cast(output); + auto expected = std::pow(1e-91, static_cast(right)); + auto err = std::abs(expected - output_as_bf16); + if (!std::isnan(err)) { + return err; + } + } + + return 0.0; +} + +double PowTpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = std::pow(static_cast(left), static_cast(right)); + + // Output is flushed to 0 if subnormal. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + +template +bool PowTpuSkip(NativeT left, NativeT right) { + // Hardware always returns 1 if right is 0 (or subnormal due to + // flushing subnormals to zero before the operation), no matter if left is + // NaN. + if (std::isnan(left) && (right == 0.0f || IsSubnormal(right))) { + return true; + } + // Hardware always returns 1 if left is 1, no matter if right is NaN. + if (left == 1.0f && std::isnan(right)) { + return true; + } + + return false; +} + +BINARY_TEST_16BIT(Pow, { + ErrorSpecGen error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); }; - Run(AddEmptyBroadcastDimension(Pow), std::pow); + + if (IsCpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .strict_signed_zeros() + .skip_comparison(PowCpuGpuF16Skip(left, right)) + .build(); + }; + } else if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(PowCpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .build(); + }; + } else if constexpr (std::is_same_v || + std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + }; + } + } + + if (IsGpu(platform_)) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .skip_comparison(PowCpuGpuF16Skip(left, right)) + .build(); + }; + } + + if (IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(PowTpuBf16AbsErr(static_cast(left), + static_cast(right))) + .distance_err(1) + .strict_signed_zeros() + .skip_comparison(PowTpuSkip(left, right)) + .build(); + }; + } else if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .skip_comparison(PowTpuSkip(left, right)) + .build(); + }; + } + } + + Run(AddEmptyBroadcastDimension(Pow), std::pow, error_spec_gen); }) -// TODO(bixia): Atan2 fails with bfloat16 on CPU. -BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2), - { Run(AddEmptyBroadcastDimension(Atan2), std::atan2); }) +// Can be thought of as an absolute error of +// `<= |std::numeric_limits::::min()|`. +double Atan2CpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + float output = + std::atan2(static_cast(left), static_cast(right)); + + // If the output would be a subnormal float, we allow some error to account + // for BF16 implementation flushing subnormals to zero. + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + +bool Atan2CpuBf16Skip(xla::bfloat16 left, xla::bfloat16 right) { + // Subnormals are flushed to 0, but 0/0 returns NaN instead of + // / which returns some positive number. We cannot set + // an error to compare against NaN. + if (IsSubnormal(left) && IsSubnormal(right)) { + return true; + } + + return false; +} + +double Atan2TpuBf16AbsErr(xla::bfloat16 left, xla::bfloat16 right) { + xla::bfloat16 output = static_cast(std::atan2(left, right)); + float output_as_float = + std::atan2(static_cast(left), static_cast(right)); + + // If the output would be a subnormal float, we allow some error to account + // for BF16 implementation flushing subnormals to zero. TPUs also seem to + // flush the minimum value to 0 along with subnormals. + if (IsSubnormalOrMinNormal(output_as_float)) { + return std::numeric_limits::min(); + } + + // Implementation of Atan2 on TPUs is that they take the reciprocal of the + // larger of left or right. If this is subnormal or the minimum value, the TPU + // flushes it to 0 before using it in multiplication. When this happens, the + // error is the output calculation, either in BF16 or float, or PI/2, + // depending on which of the three is bigger. + float reciprocal_as_float = + 1.0f / std::max(std::abs(static_cast(left)), + std::abs(static_cast(right))); + if (!std::isnan(output_as_float) && + IsSubnormalOrMinNormal(reciprocal_as_float)) { + return std::max({std::abs(output_as_float), std::abs(output), + static_cast(M_PI_2)}); + } + + return 0.0; +} + +BINARY_TEST_16BIT(Atan2, { + auto error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().strict_signed_zeros().build(); + }; + + if (IsCpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(Atan2CpuBf16AbsErr(static_cast(left), + static_cast(right))) + .strict_signed_zeros() + .skip_comparison( + Atan2CpuBf16Skip(static_cast(left), + static_cast(right))) + .build(); + }; + } + } + + if (IsGpu(platform_)) { + error_spec_gen = +[](NativeT, NativeT) { + return ErrorSpec::Builder().distance_err(1).strict_signed_zeros().build(); + }; + } + + if (IsTpu(platform_)) { + if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .abs_err(Atan2TpuBf16AbsErr(static_cast(left), + static_cast(right))) + .distance_err(1) + .strict_signed_zeros() + .build(); + }; + } else if constexpr (std::is_same_v) { + error_spec_gen = +[](NativeT left, NativeT right) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + }; + } + } + + Run(AddEmptyBroadcastDimension(Atan2), std::atan2, error_spec_gen); +}) #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc index 57d1c3fd2a371a..06cc4b0822f153 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_binary_f32_f64_test.cc @@ -63,9 +63,12 @@ class Exhaustive32BitOrMoreBinaryTest FpValues values_0; FpValues values_1; std::tie(values_0, values_1) = GetParam(); - - VLOG(2) << " testing " << values_0.ToString() << " " << values_1.ToString() - << "total values " << input_size; + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\tleft values=" << values_0.ToString(); + LOG(INFO) << "\tright values=" << values_1.ToString(); + LOG(INFO) << "\ttotal values to test=" << input_size; + } CHECK(input_size == (*input_literals)[0].element_count() && input_size == (*input_literals)[1].element_count()); diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc index 4f606a11dc0220..ccea6e55388c0c 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -15,11 +15,13 @@ limitations under the License. #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include #include #include #include #include #include +#include #include #include #include @@ -30,21 +32,49 @@ limitations under the License. #include "absl/meta/type_traits.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_replace.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "Eigen/Core" #include "xla/literal.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" +#include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" +#include "tsl/platform/path.h" +#include "tsl/platform/test.h" namespace xla { namespace exhaustive_op_test { +int eup_version = 0; + +int GetEupVersion() { return eup_version; } + +bool dump_values = false; + +bool ShouldDumpValues() { return dump_values; } + +void AddExhaustiveFlags(std::vector& flag_list) { + flag_list.push_back( + tsl::Flag("dump_values", &xla::exhaustive_op_test::dump_values, + "Include to dump files of the expected and actual results " + "(default false).")); +} + bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); } bool IsSubnormalReal(xla::complex128 value) { return IsSubnormal(value.real()); } +bool IsMinNormalReal(xla::complex64 value) { return IsMinNormal(value.real()); } + +bool IsMinNormalReal(xla::complex128 value) { + return IsMinNormal(value.real()); +} + bool IsSubnormalImaginary(xla::complex64 value) { return IsSubnormal(value.imag()); } @@ -53,6 +83,72 @@ bool IsSubnormalImaginary(xla::complex128 value) { return IsSubnormal(value.imag()); } +bool IsMinNormalImaginary(xla::complex64 value) { + return IsMinNormal(value.imag()); +} + +bool IsMinPositiveImaginary(xla::complex128 value) { + return IsMinNormal(value.imag()); +} + +/*static*/ ErrorSpec::Builder builder() { return ErrorSpecBuilder(); } + +ErrorSpecBuilder& ErrorSpecBuilder::abs_err(double abs_err) & { + spec_.abs_err = abs_err; + return *this; +} + +ErrorSpecBuilder& ErrorSpecBuilder::rel_err(double rel_err) & { + spec_.rel_err = rel_err; + return *this; +} + +ErrorSpecBuilder& ErrorSpecBuilder::distance_err(int64_t distance_err) & { + spec_.distance_err = distance_err; + return *this; +} + +ErrorSpecBuilder& ErrorSpecBuilder::strict_signed_zeros( + bool strict_signed_zeros) & { + spec_.strict_signed_zeros = strict_signed_zeros; + return *this; +} + +ErrorSpecBuilder& ErrorSpecBuilder::skip_comparison(bool skip_comparison) & { + spec_.skip_comparison = skip_comparison; + return *this; +} + +ErrorSpecBuilder&& ErrorSpecBuilder::abs_err(double abs_err) && { + spec_.abs_err = abs_err; + return std::move(*this); +} + +ErrorSpecBuilder&& ErrorSpecBuilder::rel_err(double rel_err) && { + spec_.rel_err = rel_err; + return std::move(*this); +} + +ErrorSpecBuilder&& ErrorSpecBuilder::distance_err(int64_t distance_err) && { + spec_.distance_err = distance_err; + return std::move(*this); +} + +ErrorSpecBuilder&& ErrorSpecBuilder::strict_signed_zeros( + bool strict_signed_zeros) && { + spec_.strict_signed_zeros = strict_signed_zeros; + return std::move(*this); +} + +ErrorSpecBuilder&& ErrorSpecBuilder::skip_comparison(bool skip_comparison) && { + spec_.skip_comparison = skip_comparison; + return std::move(*this); +} + +ErrorSpecBuilder::operator ErrorSpec() && { return std::move(*this).build(); } + +ErrorSpec ErrorSpecBuilder::build() && { return spec_; } + // For f64, f32, f16, and bf16, we need 17, 9, 5, and 4 decimal places of // precision to be guaranteed that we're printing the full number. // @@ -328,6 +424,22 @@ std::string StringifyNum(const std::array& inputs) { return absl::StrCat("(", absl::StrJoin(str_vals, ", "), ")"); } +template +void PrintSkipped(int64_t* skipped, const ErrorGenerator& err_generator) { + // We send some fixed amount of skipped messages to the log. The remainder we + // squelch unless we're at vlog level 2. + constexpr int64_t kMaxMismatchesLoggedToErr = 1000; + + (*skipped)++; + if (*skipped < kMaxMismatchesLoggedToErr || VLOG_IS_ON(2)) { + LOG(WARNING) << err_generator(); + } else if (*skipped == kMaxMismatchesLoggedToErr) { + LOG(WARNING) << "Not printing any more skipped messages; pass " + "--vmodule=exhaustive_op_test=2 to see " + "all of them."; + } +} + template void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) { // We send a few mismatches to gunit so they show up nicely in test logs. @@ -347,6 +459,7 @@ void PrintMismatch(int64_t* mismatches, const ErrorGenerator& err_generator) { "all of them."; } } + } // namespace template @@ -356,17 +469,45 @@ void ExhaustiveOpTestBase::ExpectNear( OutputRangeCheck check_valid_range) { // Cache for when all components are subnormal testing values. std::vector pure_subnormal_cache; - // Since we take the cross product of all possible test values, and each - // component has kNumSubnormalSubstitutionValues possible test values, then - // the total number of different cache locations are - // kNumSubnormalSubstitutionValues raised to the num_components. - // num_components = N for the reals, and 2*N for the complex. - int64_t max_cache_size = - pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1)); - pure_subnormal_cache.reserve(max_cache_size); - for (int i = 0; i < max_cache_size; ++i) { - pure_subnormal_cache.push_back(CallOperation( - evaluate_op, FromCacheLocation(i))); + // TODO(b/353790524): Subnormal cache does not seem to work properly with + // more than 1 input. + if constexpr (N == 1) { + // Since we take the cross product of all possible test values, and each + // component has kNumSubnormalSubstitutionValues possible test values, then + // the total number of different cache locations are + // kNumSubnormalSubstitutionValues raised to the num_components. + // num_components = N for the reals, and 2*N for the complex. + int64_t max_cache_size = + pow(kNumSubnormalSubstitutionValues, N * (kIsComplex ? 2 : 1)); + pure_subnormal_cache.reserve(max_cache_size); + for (int i = 0; i < max_cache_size; ++i) { + pure_subnormal_cache.push_back(CallOperation( + evaluate_op, FromCacheLocation(i))); + } + } + + // Dump file for the test. This is unused unless this->should_dump_values is + // true. + std::unique_ptr dump_file; + if (should_dump_values_) { + auto* env = tsl::Env::Default(); + + std::string cleaned_suite_name = + absl::StrReplaceAll(SuiteName(), {{"/", "__"}}); + std::string cleaned_test_name = + absl::StrReplaceAll(TestName(), {{"/", "__"}}); + std::string dump_filename = absl::StrFormat( + "%s_%s_dump.txt", cleaned_suite_name, cleaned_test_name); + + std::string outdir; + if (tsl::io::GetTestUndeclaredOutputsDir(&outdir)) { + dump_filename = tsl::io::JoinPath(outdir, dump_filename); + } + + TF_EXPECT_OK(env->NewWritableFile(dump_filename, &dump_file)); + TF_EXPECT_OK( + dump_file->Append("input values -> actual output {expected output}\n" + "-----------------------------------------------\n")); } NativeInputsList inputs_arr; @@ -377,6 +518,7 @@ void ExhaustiveOpTestBase::ExpectNear( absl::Span result_arr = result_literal.data(); + int64_t skipped = 0; int64_t mismatches = 0; for (int64_t i = 0; i < result_arr.size(); ++i) { @@ -391,7 +533,36 @@ void ExhaustiveOpTestBase::ExpectNear( NativeT actual = result_arr[i]; NativeT expected = static_cast(CallOperation(evaluate_op, inputs_ref_ty)); + + // Dump input, actual, and expected values _before_ we do error checking to + // avoid the continues. + if (should_dump_values_) { + std::string result_string; + absl::StrAppend( + &result_string, + StringifyNum(inputs), " -> ", + StringifyNum(actual)); + absl::StrAppend(&result_string, " {", + StringifyNum(expected), + "}"); + absl::StrAppend(&result_string, "\n"); + TF_EXPECT_OK(dump_file->Append(result_string)); + } + ErrorSpec error_spec = CallErrorSpec(error_spec_gen, inputs); + ASSERT_GE(error_spec.abs_err, 0.0); + ASSERT_GE(error_spec.rel_err, 0.0); + ASSERT_GE(error_spec.distance_err, 0.0); + + if (error_spec.skip_comparison) { + PrintSkipped(&skipped, [&] { + return absl::StrFormat( + "skipping tolerance check for input %s due to " + "ErrorSpec::skip_comparison", + StringifyNum(inputs)); + }); + continue; + } if (check_valid_range != nullptr && !check_valid_range(inputs, actual)) { PrintMismatch(&mismatches, [&] { @@ -431,13 +602,19 @@ void ExhaustiveOpTestBase::ExpectNear( for (NativeRefInputs test_value : subnormal_test_inputs) { NativeRefT result; - int cache_loc = - GetCacheLocation( - test_value); - if (cache_loc == kInvalidCacheIndex) { - result = CallOperation(evaluate_op, test_value); + // TODO(b/353790524): Subnormal cache does not seem to work properly with + // more than 1 input. + if constexpr (N == 1) { + int cache_loc = + GetCacheLocation(test_value); + if (cache_loc == kInvalidCacheIndex) { + result = CallOperation(evaluate_op, test_value); + } else { + result = pure_subnormal_cache[cache_loc]; + } } else { - result = pure_subnormal_cache[cache_loc]; + result = CallOperation(evaluate_op, test_value); } if (IsClose(result, static_cast(actual), error_spec)) { @@ -476,8 +653,19 @@ void ExhaustiveOpTestBase::ExpectNear( StringifyNum(actual))); PrintMismatch(&mismatches, [mismatch] { return mismatch; }); + + // If we have emitted debug logging, we fail the test execution at the first + // comparison failure to avoid dumping too much log data and ensure the + // relevant debugging information is the last logged data. + if (should_emit_debug_logging_) { + ASSERT_TRUE(false); + } } EXPECT_EQ(mismatches, 0); + + if (should_dump_values_) { + TF_EXPECT_OK(dump_file->Close()); + } } template class ExhaustiveOpTestBase; diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h index 80c69703dfb96c..30c9acbe69c86f 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -32,10 +32,12 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "Eigen/Core" #include "xla/bit_cast.h" @@ -43,11 +45,13 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/fp_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/primitive_util.h" #include "xla/service/shaped_buffer.h" #include "xla/tests/client_library_test_base.h" +#include "xla/tsl/util/command_line_flags.h" #include "xla/types.h" #include "xla/xla_data.pb.h" #include "tsl/platform/statusor.h" @@ -55,19 +59,49 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { -// Determines if the real component of the complex number is subnormal. +// Access this through GetEupVersion. +extern int eup_version; + +// Get the TPU EUP version (if it was provided). +int GetEupVersion(); + +// Return if the user specified dumping all tested values with their expected +// and actual results. +bool ShouldDumpValues(); + +void AddExhaustiveFlags(std::vector& flag_list); + +// Determines if the real component of the complex number is subnormal (either +// sign). // // See also IsSubnormal to check if either component is subnormal. bool IsSubnormalReal(xla::complex64); bool IsSubnormalReal(xla::complex128); -// Determines if the imaginary component of the complex number is subnormal. +// Determines if the real component of the complex number is the minimum +// normal floating point value (either sign). +// +// See also IsMinPositive to check if either component is the minimum normal +// floating point value. +bool IsMinNormalReal(xla::complex64); +bool IsMinNormalReal(xla::complex128); + +// Determines if the imaginary component of the complex number is subnormal +// (either sign). // // See also IsSubnormal to check if either component is subnormal. bool IsSubnormalImaginary(xla::complex64); bool IsSubnormalImaginary(xla::complex128); -// Determines if the NativeT is subnormal. +// Determines if the imaginary component of the complex number is the minimum +// normal floating point value (either sign). +// +// See also IsMinPositive to check if either component is the minimum normal +// floating point value. +bool IsMinNormalImaginary(xla::complex64); +bool IsMinNormalImaginary(xla::complex128); + +// Determines if the NativeT is subnormal (either sign). // // For complex numbers, this will return true if either real or imaginary // component is subnormal. See IsSubnormalReal and IsSubnormalImaginary if you @@ -82,14 +116,104 @@ bool IsSubnormal(NativeT value) { } } +// Determines if the NativeT is the minimum normal floating point value +// (either sign). +// +// For complex numbers, this will return true if either real or imaginary +// component is the minimum normal floating point value. See IsMinPositiveReal +// and IsMinPositiveImaginary if you only care about one component. +template +bool IsMinNormal(NativeT value) { + if constexpr (std::is_same_v || + std::is_same_v) { + return IsMinNormalReal(value) || IsMinNormalImaginary(value); + } else { + return std::abs(value) == std::numeric_limits::min(); + } +} + +// Determines if the NativeT is subnormal or the minimum normal floating point +// value (either sign). +// +// For complex numbers, this will return true if either real or imaginary +// component is subnormal or the minimum normal floating point value. +template +bool IsSubnormalOrMinNormal(NativeT value) { + return IsSubnormal(value) || IsMinNormal(value); +} + +// Get the floating point distance (number of floating point values between) +// expected and actual. +// +// This is a wrapper around xla::CalculateDistanceInFloats for most types. For +// complex types, this returns the maximum distance between the real and +// imaginary components. +template +int64_t GetDistanceErr(NativeT expected, NativeT actual) { + if constexpr (std::is_same_v || + std::is_same_v) { + return std::max( + CalculateDistanceInFloats(expected.real(), actual.real()), + CalculateDistanceInFloats(expected.imag(), expected.imag())); + } else { + return CalculateDistanceInFloats(expected, actual); + } +} + +class ErrorSpecBuilder; + struct ErrorSpec { - double abs_err = 0; - double rel_err = 0; + using Builder = ErrorSpecBuilder; + + double abs_err = 0.0; + double rel_err = 0.0; + // The acceptable amount of floating point values between the expected and + // actual (also calling floating point distance). + // + // This is similar to absolute error, but the same distance_err can have + // different floating point values as the exponent changes. In some way, it is + // a hybrid of absolute and relative error, as it allows a fixed binary + // difference (like abs_err), but that has a varied floating point value based + // on the number (like rel_err). + int64_t distance_err = 0; // If true, will consider -0 not near to +0 and vice versa. Note that // +epsilon may still be considered close to -0, depending on the error // spec; this only covers the case when both `expected` and `actual` are // equal to 0. bool strict_signed_zeros = false; + // If true, this will skip comparing the output of the test to the expected + // value. This should be used only as a last resort, since it is effectively + // turning off the test for a specific input value set. + bool skip_comparison = false; +}; + +// Builder pattern to construct an ErrorSpec without a proliferation of +// constructors or requiring extensive argument name comments. +// +// You can use an lvalue or rvalue to call the setter functions, but you can +// only build (explicitly or implicitly) using an rvalue from std::move. +class ErrorSpecBuilder { + public: + ErrorSpecBuilder() : spec_() {} + + ErrorSpecBuilder& abs_err(double abs_err) &; + ErrorSpecBuilder& rel_err(double rel_err) &; + ErrorSpecBuilder& distance_err(int64_t distance_err) &; + ErrorSpecBuilder& strict_signed_zeros(bool strict_signed_zeros = true) &; + ErrorSpecBuilder& skip_comparison(bool skip_comparison = true) &; + + ErrorSpecBuilder&& abs_err(double abs_err) &&; + ErrorSpecBuilder&& rel_err(double rel_err) &&; + ErrorSpecBuilder&& distance_err(int64_t distance_err) &&; + ErrorSpecBuilder&& strict_signed_zeros(bool strict_signed_zeros = true) &&; + ErrorSpecBuilder&& skip_comparison(bool skip_comparison = true) &&; + + ErrorSpec build() &&; + + explicit operator ErrorSpec() &&; + + private: + ErrorSpec spec_; }; // Representations of the reference function passed in by the user. @@ -201,7 +325,10 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { using OutputRangeCheck = std::function; explicit ExhaustiveOpTestBase() - : ty_(T), platform_(client_->platform()->Name()) { + : ty_(T), + platform_(client_->platform()->Name()), + eup_version_(xla::exhaustive_op_test::GetEupVersion()), + should_dump_values_(xla::exhaustive_op_test::ShouldDumpValues()) { SetFastMathDisabled(true); // Run all HLO passes. In particular, constant folding is disabled by @@ -209,6 +336,21 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { mutable_debug_options()->clear_xla_disable_hlo_passes(); } + // Enable debug logging for the invocation of the lambda. + // + // This is intended to be used to wrap a call to `Run`, which will then log + // extra debug information for a failure such as the calculated absolute, + // relative, and distance errors. In addition, in an effort to reduce output + // log size, this will trigger an ASSERT failure to early return from a test + // at the first failure. + template , int> = 0> + void EnableDebugLoggingForScope(Callable&& work) { + should_emit_debug_logging_ = true; + work(); + should_emit_debug_logging_ = false; + } + void Run(EnqueueOp enqueue_op, EvaluateOp evaluate_op, OutputRangeCheck check_valid_range = nullptr) { Run(enqueue_op, evaluate_op, GetDefaultSpecGenerator(), @@ -239,6 +381,7 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { TF_ASSERT_OK_AND_ASSIGN(XlaComputation comp, builder.Build()); TF_ASSERT_OK_AND_ASSIGN(Literal result_literal, RunComputationHelper(comp, input_literals)); + ExpectNear(input_literals, result_literal, evaluate_op, error_spec_gen, check_valid_range); } @@ -321,6 +464,20 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { const std::string& Platform() { return platform_; } + bool IsGpu(const std::string& platform) const { return platform == "CUDA"; } + bool IsCpu(const std::string& platform) const { return platform == "Host"; } + bool IsTpu(const std::string& platform) const { + return !IsGpu(platform) && !IsCpu(platform); + } + + int EupVersion() const { return eup_version_; } + bool IsPreV5Tpu(const std::string& platform) const { + return IsTpu(platform) && eup_version_ < 2; + } + bool IsPreV6Tpu(const std::string& platform) const { + return IsTpu(platform) && eup_version_ < 3; + } + // Returns the number of elements in each input literal. virtual int64_t GetInputSize() = 0; @@ -512,8 +669,21 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { double abs_err = std::abs(ReplaceInfWithMax(expected) - ReplaceInfWithMax(actual)); double rel_err = abs_err / std::abs(ReplaceInfWithMax(expected)); - - return abs_err <= spec.abs_err || rel_err <= spec.rel_err; + // N.B.: For sub-32-bit floats, NativeRefT is `float`, so ULP comparisons + // will be wildly off. We convert back to NativeT for this comparison. + int64_t distance_err = GetDistanceErr(NativeT(expected), NativeT(actual)); + + bool passed = abs_err <= spec.abs_err || rel_err <= spec.rel_err || + distance_err <= spec.distance_err; + if (should_emit_debug_logging_ && !passed) { + LOG(INFO) << "actual: " << actual << "; expected: " << expected + << "\n\tabs_err: " << abs_err + << "; spec.abs_err: " << spec.abs_err + << "\n\trel_err: " << rel_err << "; spec.rel_err: " << rel_err + << "\n\tdistance_err: " << distance_err + << "; spec.distance_err: " << spec.distance_err; + } + return passed; } // Converts part or all bits in an uint64_t to the value of the floating point @@ -545,9 +715,15 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // The platform under test. const std::string platform_; - // Testing will ignore inputs for which known_incorrect_fn_ returns true. The - // argument to the function is the raw bits for the data being test, zero - // extended to 64 bits if the data type is less than 64 bits. + // Version of the EUP for a TPU target. Only relevant for TPU platforms. + const int eup_version_; + + // Testing will ignore inputs for which known_incorrect_fn_ returns true. + // The argument to the function is the raw bits for the data being test, + // zero extended to 64 bits if the data type is less than 64 bits. + // + // DEPRECATED: Please see ErrorSpec::skip_comparison for an easier framework + // to skip nearness checks for certain unary or binary inputs. std::function known_incorrect_fn_; // If true, allows denormals to be flushed to non-sign-preserving 0. @@ -558,6 +734,13 @@ class ExhaustiveOpTestBase : public ClientLibraryTestBase { // // XLA:GPU preserves denormal signs, but other backends don't. bool relaxed_denormal_signs_ = platform_ != "CUDA"; + + // Indicates if files of the expected and actual values should be dumped. + bool should_dump_values_ = false; + + // Indicates if additional (potentially costly) logging should be emitted to + // ease with debugging. + bool should_emit_debug_logging_ = false; }; // Represents a set of 64 bit chunks by representing the starting bit chunk, @@ -1000,7 +1183,7 @@ inline ErrorSpec DefaultSpecGenerator(complex128) { kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits::min(); double rtol = kDefaultRelativeToleranceSlackFactor * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1008,7 +1191,7 @@ inline ErrorSpec DefaultSpecGenerator(complex64) { double atol = kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits::min(); double rtol = 40 * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1017,7 +1200,7 @@ inline ErrorSpec DefaultSpecGenerator(double) { kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits::min(); double rtol = kDefaultRelativeToleranceSlackFactor * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1026,7 +1209,7 @@ inline ErrorSpec DefaultSpecGenerator(float) { kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits::min(); double rtol = kDefaultRelativeToleranceSlackFactor * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1035,7 +1218,7 @@ inline ErrorSpec DefaultSpecGenerator(Eigen::half) { std::numeric_limits::min(); // epsilon for FP16 is quite large, so a slack factor of 5 suffices. double rtol = 5 * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1044,7 +1227,7 @@ inline ErrorSpec DefaultSpecGenerator(bfloat16) { std::numeric_limits::min(); // epsilon for BF16 is quite large, so a slack factor of 2 suffices. double rtol = 2 * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1053,7 +1236,7 @@ inline ErrorSpec DefaultSpecGenerator(double, double) { kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits::min(); double rtol = kDefaultRelativeToleranceSlackFactor * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1062,7 +1245,7 @@ inline ErrorSpec DefaultSpecGenerator(float, float) { kDefaultAbsoluteToleranceSlackFactor * std::numeric_limits::min(); double rtol = kDefaultRelativeToleranceSlackFactor * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1071,7 +1254,7 @@ inline ErrorSpec DefaultSpecGenerator(Eigen::half, Eigen::half) { std::numeric_limits::min(); // epsilon for FP16 is quite large, so a slack factor of 5 suffices. double rtol = 5 * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template <> @@ -1080,7 +1263,7 @@ inline ErrorSpec DefaultSpecGenerator(bfloat16, bfloat16) { std::numeric_limits::min(); // epsilon for BF16 is quite large, so a slack factor of 5 suffices. double rtol = 2 * std::numeric_limits::epsilon(); - return ErrorSpec{atol, rtol}; + return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } template @@ -1138,7 +1321,13 @@ class ExhaustiveUnaryTest : public ExhaustiveOpTestBase { }; template -using ExhaustiveBinaryTest = ExhaustiveOpTestBase; +class ExhaustiveBinaryTest : public ExhaustiveOpTestBase { + public: + using typename ExhaustiveOpTestBase::ErrorSpecGen; + static ErrorSpecGen GetDefaultSpecGenerator() { + return exhaustive_op_test::GetDefaultSpecGenerator(); + } +}; } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc index 88a9befba9c74e..cc1bc9dd5533e8 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_test_main.cc @@ -18,18 +18,23 @@ limitations under the License. // the --benchmark_filter flag which specifies which benchmarks to run, // we will either run benchmarks or run the gtest tests in the program. -#include "tsl/platform/test.h" - -namespace xla { -namespace exhaustive_op_test { - -static int eup_version = 0; -int GetEupVersion() { return eup_version; } +#include +#include -} // namespace exhaustive_op_test -} // namespace xla +#include "xla/tests/exhaustive/exhaustive_op_test_utils.h" +#include "xla/tsl/util/command_line_flags.h" +#include "tsl/platform/logging.h" +#include "tsl/platform/test.h" GTEST_API_ int main(int argc, char** argv) { + std::vector flag_list; + xla::exhaustive_op_test::AddExhaustiveFlags(flag_list); + std::string usage = tsl::Flags::Usage(argv[0], flag_list); + if (!tsl::Flags::Parse(&argc, argv, flag_list)) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc index b7cc275bfff11e..3bba8b80f9967d 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_complex_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +namespace { // T is the Primitive Type of the complex number // Test parameter is a tuple containing @@ -71,14 +72,16 @@ class ExhaustiveComplexUnaryTestBase void FillInput(std::array* input_literal) override { FpValues real_values = std::get<0>(GetParam()); FpValues imag_values = std::get<1>(GetParam()); - - VLOG(2) << " testing input total " - << real_values.GetTotalNumValues() * imag_values.GetTotalNumValues() - << ", range " << real_values.ToString() << " " - << imag_values.ToString(); + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\treal values=" << real_values.ToString(); + LOG(INFO) << "\timag values=" << imag_values.ToString(); + LOG(INFO) << "\ttotal values to test=" + << real_values.GetTotalNumValues() * + imag_values.GetTotalNumValues(); + } absl::Span input_arr = (*input_literal)[0].data(); - uint64_t i = 0; for (auto real : real_values) { for (auto imag : imag_values) { @@ -325,5 +328,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn( GetFpValuesForMagnitudeExtremeNormals(40000, 2000)))); +} // namespace } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc index fff08fafc4b8db..c40cd8a8131656 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f32_or_smaller_test.cc @@ -21,9 +21,8 @@ limitations under the License. #include #include #include +#include #include -#include -#include #include #include @@ -45,8 +44,7 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { - -extern int GetEupVersion(); +namespace { using Eigen::half; @@ -195,35 +193,16 @@ template class Exhaustive32BitOrLessUnaryTest : public ExhaustiveUnaryTest, public ::testing::WithParamInterface> { - public: - public: - Exhaustive32BitOrLessUnaryTest() - : eup_version_(xla::exhaustive_op_test::GetEupVersion()) {} - public: // Sets error parameters appropriately for testing tan. void SetParamsForTan(); - bool IsGpu(const std::string& platform) const { return platform == "CUDA"; } - bool IsCpu(const std::string& platform) const { return platform == "Host"; } - bool IsTpu(const std::string& platform) const { - return !IsGpu(platform) && !IsCpu(platform); - } - int EupVersion() const { return eup_version_; } - bool IsPreV5Tpu(const std::string& platform) const { - return IsTpu(platform) && eup_version_ < 2; - } - bool IsPreV6Tpu(const std::string& platform) const { - return IsTpu(platform) && eup_version_ < 3; - } - protected: using typename ExhaustiveUnaryTest::NativeT; private: int64_t GetInputSize() override { auto [begin, end] = GetParam(); - VLOG(2) << "Checking range [" << begin << ", " << end << ")"; return end - begin; } @@ -238,8 +217,18 @@ class Exhaustive32BitOrLessUnaryTest typename ExhaustiveOpTestBase::ComponentIntegralNativeT; auto [begin, end] = GetParam(); + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + LOG(INFO) << "\tfrom=" << begin << "; hex=" << std::hex << begin + << "; float=" << *reinterpret_cast(&begin) + << " (inclusive)"; + LOG(INFO) << "\tto=" << end << "; hex=" << std::hex << end + << "; float=" << *reinterpret_cast(&end) + << " (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + int64_t input_size = (*input_literal)[0].element_count(); - VLOG(2) << "Checking range [" << begin << ", " << end << ")"; CHECK_EQ(input_size, end - begin); absl::Span input_arr = (*input_literal)[0].data(); @@ -249,8 +238,6 @@ class Exhaustive32BitOrLessUnaryTest this->ConvertAndReplaceKnownIncorrectValueWith(input_val, 0); } } - - const int eup_version_; }; using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; @@ -635,76 +622,93 @@ UNARY_TEST_FLOAT_32_BITS_OR_LESS(RoundNearestEven, { fesetround(curr_direction); }) +// Can be thought of as an absolute error of `<= +// |std::numeric_limits::min()|`. template -double reciprocal_abs_error(NativeT val) { - double abs_err = 0.0; - - // For subnormals, we need to set absolute error to the smallest positive - // representable value due to hardware implementations that truncate - // subnormals to zero. - bool is_subnormal_output = - std::numeric_limits::denorm_min() <= std::abs(1 / val) && - std::abs(1 / val) <= std::numeric_limits::min(); - if (is_subnormal_output) { - abs_err = std::numeric_limits::min(); +double ReciprocalCpuGpuAbsError(NativeT val) { + float output = 1.0f / static_cast(val); + + if (IsSubnormal(output)) { + return std::numeric_limits::min(); + } + + return 0.0; +} + +// Can be thought of as an absolute error of `<= +// |std::numeric_limits::min()|`. +template +double ReciprocalTpuAbsError(NativeT val) { + float output = 1.0f / static_cast(val); + + // TPU seems to flush subnormals or minimum normal to 0. We set the error to + // the minimum normal in these cases. + if (IsSubnormalOrMinNormal(output)) { + return std::numeric_limits::min(); } - return abs_err; + return 0.0; } UNARY_TEST_FLOAT_32_BITS_OR_LESS(Reciprocal, { ErrorSpecGen error_spec_gen = - +[](NativeT) { return ErrorSpec{.abs_err = 0.0, .rel_err = 0.0}; }; + +[](NativeT) { return ErrorSpec{.strict_signed_zeros = true}; }; if (IsCpu(platform_)) { error_spec_gen = +[](NativeT val) { - return ErrorSpec{.abs_err = reciprocal_abs_error(val), .rel_err = 0.0}; + return ErrorSpec{.abs_err = ReciprocalCpuGpuAbsError(val), + .strict_signed_zeros = true}; }; } if (IsGpu(platform_)) { error_spec_gen = +[](NativeT val) { NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec{.abs_err = reciprocal_abs_error(val), .rel_err = eps}; + return ErrorSpec{.abs_err = ReciprocalCpuGpuAbsError(val), + .rel_err = eps, + .strict_signed_zeros = true}; }; } if (IsTpu(platform_)) { error_spec_gen = +[](NativeT val) { - auto abs_err = reciprocal_abs_error(val); - if constexpr (std::is_same()) { - return ErrorSpec{.abs_err = abs_err, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { - // N.B.: Does not require absolute error. - return ErrorSpec{.abs_err = 0.0, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { + if constexpr (std::is_same_v) { + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { + return ErrorSpec{.strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec{.abs_err = abs_err, .rel_err = eps}; + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .rel_err = eps, + .strict_signed_zeros = true}; } }; } if (IsPreV6Tpu(platform_)) { error_spec_gen = +[](NativeT val) { - auto abs_err = reciprocal_abs_error(val); - if constexpr (std::is_same()) { - return ErrorSpec{.abs_err = abs_err, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { - // N.B.: Does not require absolute error. - return ErrorSpec{.abs_err = 0.0, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { + if constexpr (std::is_same_v) { + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { + return ErrorSpec{.strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec{.abs_err = abs_err, .rel_err = 34 * eps}; + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .rel_err = 34 * eps, + .strict_signed_zeros = true}; } }; } if (IsPreV5Tpu(platform_)) { error_spec_gen = +[](NativeT val) { - auto abs_err = reciprocal_abs_error(val); - if constexpr (std::is_same()) { - return ErrorSpec{.abs_err = abs_err, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { - // N.B.: Does not require absolute error. - return ErrorSpec{.abs_err = 0.0, .rel_err = 0.0}; - } else if constexpr (std::is_same()) { + if constexpr (std::is_same_v) { + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { + return ErrorSpec{.strict_signed_zeros = true}; + } else if constexpr (std::is_same_v) { NativeT eps = std::numeric_limits::epsilon(); - return ErrorSpec{.abs_err = abs_err, .rel_err = 136 * eps}; + return ErrorSpec{.abs_err = ReciprocalTpuAbsError(val), + .rel_err = 136 * eps, + .strict_signed_zeros = true}; } }; } @@ -724,5 +728,6 @@ INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, ::testing::Values(std::make_pair(0, 1 << 16))); #endif +} // namespace } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f64_test.cc b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f64_test.cc index 8e81769afe8ea0..3f3b9de811fa60 100644 --- a/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f64_test.cc +++ b/third_party/xla/xla/tests/exhaustive/exhaustive_unary_f64_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace exhaustive_op_test { +namespace { // Exhaustive test for unary operations for double. // @@ -52,11 +53,14 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, void FillInput(std::array* input_literal) override { FpValues fp_values = GetParam(); int64_t input_size = (*input_literal)[0].element_count(); - LOG(INFO) << "Checking fp values " << fp_values.ToString() << ", " - << input_size; - absl::Span input_arr = (*input_literal)[0].data(); + if (VLOG_IS_ON(2)) { + LOG(INFO) << this->SuiteName() << this->TestName() << " Values:"; + LOG(INFO) << "\t" << fp_values.ToString(); + LOG(INFO) << "\ttotal values to test=" << input_size; + } uint64_t i = 0; + absl::Span input_arr = (*input_literal)[0].data(); for (auto bits : fp_values) { input_arr[i] = this->ConvertAndReplaceKnownIncorrectValueWith(bits, 1); ++i; @@ -146,5 +150,6 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(GetFpValuesForMagnitudeExtremeNormals( 4000000000ull, 16000000))); +} // namespace } // namespace exhaustive_op_test } // namespace xla diff --git a/third_party/xla/xla/tests/hlo_test_base.cc b/third_party/xla/xla/tests/hlo_test_base.cc index e7367e75a760b9..fe44f7020cbf54 100644 --- a/third_party/xla/xla/tests/hlo_test_base.cc +++ b/third_party/xla/xla/tests/hlo_test_base.cc @@ -15,7 +15,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" +#include #include +#include #include #include #include @@ -27,6 +29,9 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" #include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/hlo/utils/hlo_query.h" #include "xla/layout_util.h" #include "xla/service/hlo_module_util.h" #include "xla/service/hlo_parser.h" @@ -42,8 +47,8 @@ limitations under the License. #include "xla/tests/pjrt_client_registry.h" #include "xla/tests/test_utils.h" #include "xla/tests/verified_hlo_module.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" @@ -532,8 +537,22 @@ ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( fake_arguments, std::back_inserter(fake_argument_ptrs), [](const Literal& literal) { return const_cast(&literal); }); - return RunAndCompareNoHloPasses(std::move(module), fake_argument_ptrs, error, - reference_preprocessor); + auto assertion_result = RunAndCompareNoHloPasses( + std::move(module), fake_argument_ptrs, error, reference_preprocessor); + if (!assertion_result) { + for (const auto& literal : fake_arguments) { + uint64_t total_elements = 1; + absl::c_for_each(literal.shape().dimensions(), + [&](int64_t dim) { total_elements *= dim; }); + if (total_elements > 1000) { + LOG(ERROR) << "argument literal is too large to print: " + << literal.shape().ToString(); + continue; + } + LOG(ERROR) << "argument literal: " << literal.ToString(); + } + } + return assertion_result; } ::testing::AssertionResult HloTestBase::Run(std::unique_ptr module, @@ -1010,23 +1029,15 @@ ::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( HloComputation* HloTestBase::FindComputation(HloModule* module, absl::string_view name) { - auto computations = module->computations(); - auto it = absl::c_find_if( - computations, [&](HloComputation* c) { return c->name() == name; }); - if (it == computations.end()) { - return nullptr; - } - return *it; + return hlo_query::FindComputation(module, name); } HloInstruction* HloTestBase::FindInstruction(HloModule* module, absl::string_view name) { - for (const HloComputation* c : module->computations()) { - auto instructions = c->instructions(); - auto it = absl::c_find_if( - instructions, [&](HloInstruction* i) { return i->name() == name; }); - if (it != instructions.end()) { - return *it; + for (const HloComputation* computation : module->computations()) { + if (auto instruction = hlo_query::FindFirstInstruction(computation, name); + instruction.first != nullptr) { + return instruction.first; } } return nullptr; @@ -1034,17 +1045,25 @@ HloInstruction* HloTestBase::FindInstruction(HloModule* module, HloInstruction* HloTestBase::FindInstruction(HloModule* module, HloOpcode opcode) { - for (const HloComputation* c : module->computations()) { - auto instructions = c->instructions(); - auto it = absl::c_find_if( - instructions, [&](HloInstruction* i) { return i->opcode() == opcode; }); - if (it != instructions.end()) { - return *it; + for (const HloComputation* computation : module->computations()) { + if (auto instruction = hlo_query::FindFirstInstruction(computation, opcode); + instruction.first != nullptr) { + return instruction.first; } } return nullptr; } +std::vector HloTestBase::FindInstructions(HloModule* module, + HloOpcode opcode) { + std::vector instructions; + for (const HloComputation* c : module->computations()) { + absl::c_copy_if(c->instructions(), std::back_inserter(instructions), + [&](HloInstruction* i) { return i->opcode() == opcode; }); + } + return instructions; +} + se::DeviceMemoryAllocator* HloTestBase::GetAllocator() { if (allocator_ == nullptr) { allocator_ = std::make_unique( diff --git a/third_party/xla/xla/tests/hlo_test_base.h b/third_party/xla/xla/tests/hlo_test_base.h index 9858ed6f53997d..4c194c8de351a0 100644 --- a/third_party/xla/xla/tests/hlo_test_base.h +++ b/third_party/xla/xla/tests/hlo_test_base.h @@ -25,8 +25,10 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module_group.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/backend.h" #include "xla/service/computation_layout.h" #include "xla/service/hlo_runner.h" @@ -423,13 +425,19 @@ class HloTestBase : public ManifestCheckingTest { } // Gets the computation/instruction from the given module with the given name. - // + // Note that it is encouraged to use these functions directly via the + // hlo_query.h header instead since they are independent from any test-time + // variables or contexts. + // This is useful for tests which create HLOs from a string and then want to // inspect a particular computation or instruction. HloComputation* FindComputation(HloModule* module, absl::string_view name); HloInstruction* FindInstruction(HloModule* module, absl::string_view name); // Gets the instruction from the given module with the given opcode. HloInstruction* FindInstruction(HloModule* module, HloOpcode opcode); + // Gets all the instructions from the given module with the given opcode. + std::vector FindInstructions(HloModule* module, + HloOpcode opcode); // Return an HLO verifier constructed for the test backend. HloVerifier& verifier() const { return *hlo_verifier_; } @@ -438,6 +446,7 @@ class HloTestBase : public ManifestCheckingTest { // Returns the backend owned by the test runner. Backend& backend(); + int64_t num_devices() { return backend().device_count(); } HloRunner test_runner_; HloRunner reference_runner_; @@ -513,6 +522,13 @@ class HloTestBase : public ManifestCheckingTest { se::Platform* test_platform); }; +#define SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(x) \ + int64_t num_devices = backend().device_count(); \ + if (num_devices < x) { \ + GTEST_SKIP() << "Test requires at least " << x << " devices (" \ + << num_devices << " available)"; \ + } + } // namespace xla #endif // XLA_TESTS_HLO_TEST_BASE_H_ diff --git a/third_party/xla/xla/tests/llvm_irgen_test_base.cc b/third_party/xla/xla/tests/llvm_irgen_test_base.cc index fae82d29c84954..db3d06c69f62dd 100644 --- a/third_party/xla/xla/tests/llvm_irgen_test_base.cc +++ b/third_party/xla/xla/tests/llvm_irgen_test_base.cc @@ -21,7 +21,7 @@ limitations under the License. #include "absl/status/status.h" #include "xla/service/llvm_ir/llvm_util.h" #include "xla/tests/filecheck.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/tests/multioutput_fusion_test.cc b/third_party/xla/xla/tests/multioutput_fusion_test.cc index f6ec09a1e35a2f..97ee8b70575426 100644 --- a/third_party/xla/xla/tests/multioutput_fusion_test.cc +++ b/third_party/xla/xla/tests/multioutput_fusion_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/protobuf.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/tests/multiple_devices_on_host_test.cc b/third_party/xla/xla/tests/multiple_devices_on_host_test.cc index a24b5594f484bc..8aa1502a3a951d 100644 --- a/third_party/xla/xla/tests/multiple_devices_on_host_test.cc +++ b/third_party/xla/xla/tests/multiple_devices_on_host_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include "xla/client/xla_builder.h" #include "xla/shape_util.h" #include "xla/stream_executor/platform_manager.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/multithreaded_compilation_test.cc b/third_party/xla/xla/tests/multithreaded_compilation_test.cc index cbbfedab4f7e84..b9cb8d253cb511 100644 --- a/third_party/xla/xla/tests/multithreaded_compilation_test.cc +++ b/third_party/xla/xla/tests/multithreaded_compilation_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" #include "tsl/platform/threadpool.h" diff --git a/third_party/xla/xla/tests/numerics_test.cc b/third_party/xla/xla/tests/numerics_test.cc index 8a542423f5f909..b1bfcd9ed24d4c 100644 --- a/third_party/xla/xla/tests/numerics_test.cc +++ b/third_party/xla/xla/tests/numerics_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" #include "xla/types.h" +#include "tsl/platform/statusor.h" #include "tsl/platform/test.h" namespace xla { @@ -86,5 +87,79 @@ ENTRY entry { std::numeric_limits::quiet_NaN(), 0)); } +// Case from one of XLA users, the following code produced incorrect results on +// CPU thunks backend (due to incorrect LLVM IR generated). +// This is an HLO module optimized for CPU backend, it may be invalid for other +// backends. +XLA_TEST_F(NumericsTest, + DISABLED_ON_GPU(DISABLED_ON_TPU(MultiplySubtractConcatTest))) { + const char* test_hlo = R"( + HloModule jit_step, is_scheduled=true + + fused_computation { + param_0.2 = f32[1,5] parameter(0) + slice.11 = f32[1,1] slice(param_0.2), slice={[0:1], [1:2]} + slice.10 = f32[1,1] slice(param_0.2), slice={[0:1], [4:5]} + multiply.11 = f32[1,1] multiply(slice.11, slice.10) + slice.9 = f32[1,1] slice(param_0.2), slice={[0:1], [2:3]} + slice.8 = f32[1,1] slice(param_0.2), slice={[0:1], [3:4]} + multiply.10 = f32[1,1] multiply(slice.9, slice.8) + subtract.5 = f32[1,1] subtract(multiply.11, multiply.10) + slice.6 = f32[1,1] slice(param_0.2), slice={[0:1], [0:1]} + multiply.8 = f32[1,1] multiply(slice.6, slice.10) + subtract.4 = f32[1,1] subtract(slice.9, multiply.8) + ROOT concatenate.1 = f32[1,3] concatenate( + subtract.5, subtract.4, subtract.4), dimensions={1} + } // fused_computation + + ENTRY main { + Arg_0.0 = f32[1,5] parameter(0) + ROOT fusion = f32[1,3] fusion(Arg_0.0), kind=kLoop, + calls=fused_computation + } // main + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto test_module, + ParseAndReturnVerifiedModule(test_hlo)); + auto argument = LiteralUtil::CreateR2( + {{0.261473775, -0.642940283, -0.719902277, 0.712947428, 0.543724537}}); + + TF_ASSERT_OK_AND_ASSIGN(auto test_result, + Execute(std::move(test_module), {&argument}, + /*run_hlo_passes=*/false)); + + // Reference HLO module. It's a subgraph of the test module, it performs only + // the calculations needed for the first output element from the test module. + const char* reference_hlo = R"( + HloModule jit_step, is_scheduled=true + + fused_computation { + param_0.2 = f32[1,5] parameter(0) + slice.11 = f32[1,1] slice(param_0.2), slice={[0:1], [1:2]} + slice.10 = f32[1,1] slice(param_0.2), slice={[0:1], [4:5]} + multiply.11 = f32[1,1] multiply(slice.11, slice.10) + slice.9 = f32[1,1] slice(param_0.2), slice={[0:1], [2:3]} + slice.8 = f32[1,1] slice(param_0.2), slice={[0:1], [3:4]} + multiply.10 = f32[1,1] multiply(slice.9, slice.8) + ROOT subtract.5 = f32[1,1] subtract(multiply.11, multiply.10) + } // fused_computation + + ENTRY main { + Arg_0.0 = f32[1,5] parameter(0) + ROOT fusion = f32[1,1] fusion(Arg_0.0), kind=kLoop, + calls=fused_computation + } // main + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto reference_module, + ParseAndReturnVerifiedModule(reference_hlo)); + TF_ASSERT_OK_AND_ASSIGN(auto reference_result, + Execute(std::move(reference_module), {&argument}, + /*run_hlo_passes=*/false)); + + // Only compare the first element. + EXPECT_EQ(reference_result.data()[0], test_result.data()[0]); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/tests/outfeed_in_nested_computation_test.cc b/third_party/xla/xla/tests/outfeed_in_nested_computation_test.cc index b111c620e96f2e..44250f502e7a1b 100644 --- a/third_party/xla/xla/tests/outfeed_in_nested_computation_test.cc +++ b/third_party/xla/xla/tests/outfeed_in_nested_computation_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include "xla/tests/local_client_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tests/pred_test.cc b/third_party/xla/xla/tests/pred_test.cc index 89ba59cd70b356..9f8af6a013d677 100644 --- a/third_party/xla/xla/tests/pred_test.cc +++ b/third_party/xla/xla/tests/pred_test.cc @@ -21,7 +21,7 @@ limitations under the License. #include "xla/client/local_client.h" #include "xla/client/xla_builder.h" #include "xla/tests/client_library_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/tests/reduce_test.cc b/third_party/xla/xla/tests/reduce_test.cc index 460b920d821eaa..f5db7397cad818 100644 --- a/third_party/xla/xla/tests/reduce_test.cc +++ b/third_party/xla/xla/tests/reduce_test.cc @@ -57,9 +57,9 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/tests/reduce_window_test.cc b/third_party/xla/xla/tests/reduce_window_test.cc index ccbb8f4f3cb8ba..c65cd9c9af1969 100644 --- a/third_party/xla/xla/tests/reduce_window_test.cc +++ b/third_party/xla/xla/tests/reduce_window_test.cc @@ -35,8 +35,8 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tests/replicated_io_feed_test.cc b/third_party/xla/xla/tests/replicated_io_feed_test.cc index d3600e4602f135..0164f8b6b30e69 100644 --- a/third_party/xla/xla/tests/replicated_io_feed_test.cc +++ b/third_party/xla/xla/tests/replicated_io_feed_test.cc @@ -20,7 +20,7 @@ limitations under the License. #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" // Tests replicated infeed/outfeed operations. @@ -50,7 +50,10 @@ XLA_TEST_F(ReplicatedIOFeedTest, InfeedAndOutfeed) { result = u32[] add(infeed.data, replica_id) outfeed = token[] outfeed(result, infeed.token), outfeed_shape=u32[] })"; + const int kNumReplicas = 4; + SKIP_TEST_IF_NUM_DEVICES_LESS_THAN(kNumReplicas); + auto config = GetModuleConfigForTest(); config.set_replica_count(kNumReplicas); std::unique_ptr module = diff --git a/third_party/xla/xla/tests/test_utils_test.cc b/third_party/xla/xla/tests/test_utils_test.cc index 82a95589e6b907..22212a02998239 100644 --- a/third_party/xla/xla/tests/test_utils_test.cc +++ b/third_party/xla/xla/tests/test_utils_test.cc @@ -24,7 +24,7 @@ limitations under the License. #include "xla/shape_util.h" #include "xla/tests/local_client_test_base.h" #include "xla/tests/test_macros.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tests/triangular_solve_test.cc b/third_party/xla/xla/tests/triangular_solve_test.cc index a82720008f9a42..b04ac99d4110e4 100644 --- a/third_party/xla/xla/tests/triangular_solve_test.cc +++ b/third_party/xla/xla/tests/triangular_solve_test.cc @@ -29,9 +29,9 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/tests/tuple_test.cc b/third_party/xla/xla/tests/tuple_test.cc index b0d765414d256c..8d6c1c641579e9 100644 --- a/third_party/xla/xla/tests/tuple_test.cc +++ b/third_party/xla/xla/tests/tuple_test.cc @@ -29,8 +29,8 @@ limitations under the License. #include "xla/tests/hlo_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/tests/value_inference_test.cc b/third_party/xla/xla/tests/value_inference_test.cc index 4fbc3356b62717..50da08967a01eb 100644 --- a/third_party/xla/xla/tests/value_inference_test.cc +++ b/third_party/xla/xla/tests/value_inference_test.cc @@ -36,8 +36,8 @@ limitations under the License. #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" diff --git a/third_party/xla/xla/tests/while_test.cc b/third_party/xla/xla/tests/while_test.cc index 7e6c2af60f4183..473875960fdd16 100644 --- a/third_party/xla/xla/tests/while_test.cc +++ b/third_party/xla/xla/tests/while_test.cc @@ -31,8 +31,8 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/literal_test_util.h" #include "xla/tests/test_macros.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/logging.h" #include "tsl/platform/test.h" #include "tsl/platform/test_benchmark.h" diff --git a/third_party/xla/xla/tests/xla_hlo_profile_test.cc b/third_party/xla/xla/tests/xla_hlo_profile_test.cc index 2436635dea5ef0..72e4387da2beb8 100644 --- a/third_party/xla/xla/tests/xla_hlo_profile_test.cc +++ b/third_party/xla/xla/tests/xla_hlo_profile_test.cc @@ -34,7 +34,7 @@ limitations under the License. #include "xla/tests/client_library_test_base.h" #include "xla/tests/test_macros.h" #include "xla/tests/test_utils.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/regexp.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/text_literal_writer_test.cc b/third_party/xla/xla/text_literal_writer_test.cc index 3c9c2d6161eef1..e517279a4c447d 100644 --- a/third_party/xla/xla/text_literal_writer_test.cc +++ b/third_party/xla/xla/text_literal_writer_test.cc @@ -22,8 +22,8 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/test.h" #include "xla/test_helpers.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" namespace xla { diff --git a/third_party/xla/xla/tools/BUILD b/third_party/xla/xla/tools/BUILD index ef4fb3d452f279..390188c7048e93 100644 --- a/third_party/xla/xla/tools/BUILD +++ b/third_party/xla/xla/tools/BUILD @@ -56,7 +56,6 @@ xla_cc_binary( name = "hex_floats_to_packed_literal", srcs = ["hex_floats_to_packed_literal.cc"], deps = [ - "//xla:types", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/base", "@com_google_absl//absl/strings", @@ -127,13 +126,12 @@ xla_cc_binary( name = "convert_computation", srcs = ["convert_computation.cc"], deps = [ - "//xla:types", "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/status:statusor", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:status", ], ) @@ -193,19 +191,24 @@ xla_cc_binary( name = "dumped_computation_to_text", srcs = ["dumped_computation_to_text.cc"], deps = [ - "//xla:types", - "//xla/client", + "//xla:shape_util", + "//xla:xla_proto_cc", "//xla/client:client_library", + "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/client:xla_computation", + "//xla/hlo/ir:hlo", "//xla/service", "//xla/service:hlo_proto_cc", "//xla/service:interpreter_plugin", + "//xla/service:local_service", + "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", ], ) @@ -220,15 +223,16 @@ xla_cc_binary( name = "dumped_computation_to_operation_list", srcs = ["dumped_computation_to_operation_list.cc"], deps = [ - "//xla:types", - "//xla/client", + "//xla:shape_util", "//xla/client:client_library", + "//xla/client:executable_build_options", "//xla/client:local_client", "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/service", "//xla/service:hlo_proto_cc", "//xla/service:interpreter_plugin", + "//xla/service:local_service", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -237,6 +241,7 @@ xla_cc_binary( "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", ], ) @@ -309,6 +314,7 @@ xla_cc_binary( xla_cc_binary( name = "hlo-opt", testonly = True, + linkopts = ["-Wl,-rpath,$$ORIGIN/../lit_lib"], deps = [ "//xla/tools/hlo_opt:opt_main", ], @@ -485,7 +491,7 @@ xla_cc_test( ":hlo_module_loader", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test", ], ) @@ -508,6 +514,19 @@ cc_library( ], ) +xla_cc_test( + name = "prepare_reference_module_test", + srcs = ["prepare_reference_module_test.cc"], + deps = [ + ":prepare_reference_module", + "//xla:test", + "//xla/hlo/ir:hlo", + "//xla/tests:hlo_test_base", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test_main", + ], +) + tf_proto_library( name = "run_hlo_module_proto", srcs = ["run_hlo_module.proto"], @@ -587,7 +606,7 @@ xla_cc_test( "//xla:literal", "//xla:literal_util", "//xla:xla_data_proto_cc", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -633,6 +652,7 @@ xla_cc_test( "data/add.hlo", "data/add_mhlo.mlir", "data/add_stablehlo.mlir", + "data/input_literal_f32_2_2.pbtxt", "data/must_alias.hlo", "data/must_alias_with_sharding.hlo", ":run_hlo_module", @@ -642,8 +662,9 @@ xla_cc_test( "//xla:literal_util", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", - "@local_tsl//tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:statusor", @@ -692,8 +713,8 @@ xla_cc_test( "//xla/service/spmd:spmd_partitioner", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/strings", - "@local_tsl//tsl/lib/core:status_test_util", ], ) @@ -718,8 +739,14 @@ xla_cc_binary( srcs = ["compute_cost.cc"], deps = [ ":hlo_module_loader", + "//xla:debug_options_flags", + "//xla:shape_util", "//xla/service:hlo_cost_analysis", + "//xla/tsl/util:command_line_flags", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:platform_port", + "@local_tsl//tsl/platform:status", ], ) @@ -798,32 +825,63 @@ tsl_gpu_library( "//xla/service/gpu:amdgpu_compiler", "//xla/service/gpu:amdgpu_compiler_impl", ]) + if_gpu_is_configured([ - "//xla/service/gpu:autotuner_util", "//xla/service/gpu:executable_proto_cc", "//xla/service/gpu:gpu_compiler", + "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor/gpu:gpu_init", "//xla/service/gpu:gpu_symbol_repository", ]) + if_google(["@com_google_protobuf//:duration_cc_proto"]), ) xla_test( - name = "xla_compile_lib_test", - srcs = ["xla_compile_lib_test.cc"], + name = "xla_cpu_compile_lib_test", + srcs = ["xla_cpu_compile_lib_test.cc"], + backends = [ + "cpu", + ], + data = [ + ":data/add.hlo", + ], + deps = [ + ":xla_compile_lib", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/service:platform_util", + "//xla/service:symbol_repository", + "//xla/service:xla_compile_result_proto_cc_impl", + "//xla/tests:hlo_test_base", + "//xla/tests:xla_internal_test_main", # fixdeps: keep + "//xla/tsl/lib/core:status_test_util", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:env", + "@local_tsl//tsl/platform:env_time", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", + "@local_tsl//tsl/protobuf:status_proto_cc", + ] + if_google(["@com_google_protobuf//:duration_cc_proto"]), +) + +xla_test( + name = "xla_gpu_compile_lib_test", + srcs = ["xla_gpu_compile_lib_test.cc"], backend_tags = { "gpu": ["requires-gpu-nvidia"] + if_google(["config-cuda-only"]), }, backends = [ - "cpu", "gpu", ], tags = ["no_rocm"], data = [ ":data/add.hlo", "//xla/service:xla_aot_compile_test_gpu_target_config.prototxt", + "//xla/service/gpu:gpu_compiler_test_autotune_db.textproto", ], - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]) + if_rocm_is_configured([ - "TENSORFLOW_USE_ROCM=1", - ]), deps = [ ":xla_compile_lib", "//xla:util", @@ -831,23 +889,21 @@ xla_test( "//xla/service:platform_util", "//xla/service:symbol_repository", "//xla/service:xla_compile_result_proto_cc_impl", + "//xla/service/gpu:gpu_symbol_repository", + "//xla/service/gpu/autotuning:autotuner_util", "//xla/stream_executor:device_description_proto_cc", "//xla/tests:hlo_test_base", - "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", # fixdeps: keep - "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", + "//xla/tsl/lib/core:status_test_util", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", - "@local_tsl//tsl/platform:env_time", "@local_tsl//tsl/platform:path", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", "@local_tsl//tsl/protobuf:status_proto_cc", - ] + if_google(["@com_google_protobuf//:duration_cc_proto"]), + ], ) xla_test( diff --git a/third_party/xla/xla/tools/compute_cost.cc b/third_party/xla/xla/tools/compute_cost.cc index c5e0ddcda4e221..9615ae01b59940 100644 --- a/third_party/xla/xla/tools/compute_cost.cc +++ b/third_party/xla/xla/tools/compute_cost.cc @@ -21,9 +21,16 @@ limitations under the License. #include #include +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "xla/debug_options_flags.h" #include "xla/service/hlo_cost_analysis.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/tools/hlo_module_loader.h" +#include "xla/tsl/util/command_line_flags.h" #include "tsl/platform/init_main.h" +#include "tsl/platform/status.h" namespace { const char* const kUsage = R"( diff --git a/third_party/xla/xla/tools/convert_computation.cc b/third_party/xla/xla/tools/convert_computation.cc index f81d517b06e847..7ebc5d3f5aa4cd 100644 --- a/third_party/xla/xla/tools/convert_computation.cc +++ b/third_party/xla/xla/tools/convert_computation.cc @@ -16,6 +16,7 @@ limitations under the License. // Usage: convert_computation serialized_computation_proto // // bin2txt spits out the result to stdout. txt2bin modifies the file in place. +#include "tsl/platform/status.h" #ifndef _WIN32 #include #endif @@ -23,9 +24,7 @@ limitations under the License. #include -#include "absl/status/statusor.h" #include "xla/service/hlo.pb.h" -#include "xla/types.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" diff --git a/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt b/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt new file mode 100644 index 00000000000000..6c39d030e855ae --- /dev/null +++ b/third_party/xla/xla/tools/data/input_literal_f32_2_2.pbtxt @@ -0,0 +1,20 @@ +# proto-file: third_party/tensorflow/compiler/xla/tools/run_hlo_module.proto +# proto-message: RunHloModuleIterationLiterals +arguments { + shape { + element_type: F32 + dimensions: 2 + dimensions: 2 + layout { + minor_to_major: 1 + minor_to_major: 0 + tail_padding_alignment_in_elements: 1 + } + is_dynamic_dimension: false + is_dynamic_dimension: false + } + f32s: 0.1 + f32s: 0.2 + f32s: 0.3 + f32s: 0.4 +} \ No newline at end of file diff --git a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc index e17d487c36aede..645021031107ca 100644 --- a/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc +++ b/third_party/xla/xla/tools/dumped_computation_to_operation_list.cc @@ -22,20 +22,26 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" -#include "xla/client/client.h" #include "xla/client/client_library.h" +#include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" #include "xla/client/xla_computation.h" #include "xla/hlo/ir/dfs_hlo_visitor_with_default.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo.pb.h" -#include "xla/service/service.h" -#include "xla/types.h" +#include "xla/service/local_service.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" namespace xla { namespace tools { diff --git a/third_party/xla/xla/tools/dumped_computation_to_text.cc b/third_party/xla/xla/tools/dumped_computation_to_text.cc index df9116ec618c29..695d4c928a6866 100644 --- a/third_party/xla/xla/tools/dumped_computation_to_text.cc +++ b/third_party/xla/xla/tools/dumped_computation_to_text.cc @@ -21,16 +21,21 @@ limitations under the License. #include "absl/status/statusor.h" #include "absl/types/span.h" -#include "xla/client/client.h" #include "xla/client/client_library.h" +#include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" #include "xla/client/xla_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" #include "xla/service/hlo.pb.h" -#include "xla/service/service.h" -#include "xla/types.h" +#include "xla/service/local_service.h" +#include "xla/shape.h" +#include "xla/tsl/util/command_line_flags.h" +#include "xla/xla.pb.h" #include "tsl/platform/env.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" +#include "tsl/platform/status.h" namespace xla { namespace tools { diff --git a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc index c4d591ba34928a..6388a8fb84d71c 100644 --- a/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc +++ b/third_party/xla/xla/tools/hex_floats_to_packed_literal.cc @@ -21,10 +21,10 @@ limitations under the License. #include "absl/base/casts.h" #include "absl/strings/string_view.h" #include "xla/tsl/util/command_line_flags.h" -#include "xla/types.h" #include "tsl/lib/io/buffered_inputstream.h" #include "tsl/lib/io/random_inputstream.h" #include "tsl/platform/env.h" +#include "tsl/platform/file_system.h" #include "tsl/platform/init_main.h" #include "tsl/platform/logging.h" #include "tsl/platform/status.h" diff --git a/third_party/xla/xla/tools/hlo_bisect/BUILD b/third_party/xla/xla/tools/hlo_bisect/BUILD index 3df7fbbccdbad9..7f9747bad6f252 100644 --- a/third_party/xla/xla/tools/hlo_bisect/BUILD +++ b/third_party/xla/xla/tools/hlo_bisect/BUILD @@ -84,6 +84,7 @@ cc_library( "//xla:protobuf_util", "//xla:util", "//xla/hlo/ir:hlo", + "//xla/service:dump", "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/service:hlo_proto_util", diff --git a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc index 0e381514275a7b..d4e6d0d60e70fa 100644 --- a/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc +++ b/third_party/xla/xla/tools/hlo_bisect/hlo_bisect_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "xla/error_spec.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/protobuf_util.h" +#include "xla/service/dump.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_parser.h" #include "xla/service/hlo_proto_util.h" @@ -137,7 +138,7 @@ absl::Status DumpHloModule(HloModule* module, const std::string& file_name, HloProto proto = MakeHloProto(*module); if (output_format == "hlo") { tsl::Env* env = tsl::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(std::string(dir_path))); + TF_RETURN_IF_ERROR(CreateDirIfNeeded(std::string(dir_path), env)); std::string file_path = tsl::io::JoinPath(dir_path, SanitizeFileName(file_name)) + ".hlo"; LOG(INFO) << "Dumped HLO text to " << file_path; @@ -148,8 +149,8 @@ absl::Status DumpHloModule(HloModule* module, const std::string& file_name, .set_compact_operands(false)))); } else if (output_format == "pb") { std::string path; - TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, std::string(dir_path), file_name, &path)); + TF_RETURN_IF_ERROR( + DumpProtoToDirectory(proto, std::string(dir_path), file_name, &path)); LOG(INFO) << "Dumped HLO module proto to " << path; } else { diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc index d37c6b04fcec14..e81f3ef6ee8604 100644 --- a/third_party/xla/xla/tools/hlo_control_flow_flattening.cc +++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.cc @@ -496,7 +496,7 @@ absl::StatusOr HloControlFlowFlattening::Run( TF_RETURN_IF_ERROR(RemoveCollective(instruction).status()); } changed = true; - } else if (remove_comm_ && + } else if ((remove_comm_ || remove_id_) && (instruction->opcode() == HloOpcode::kPartitionId || instruction->opcode() == HloOpcode::kReplicaId || (instruction->opcode() == HloOpcode::kCustomCall && diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening.h b/third_party/xla/xla/tools/hlo_control_flow_flattening.h index cff9db4c11a7dd..450aeabc5a25ec 100644 --- a/third_party/xla/xla/tools/hlo_control_flow_flattening.h +++ b/third_party/xla/xla/tools/hlo_control_flow_flattening.h @@ -49,6 +49,8 @@ class HloControlFlowFlattening : public HloModulePass { bool flatten_while_loop = true; bool remove_comm = true; bool remove_host_transfer = false; + // Removes partition-id, replica-id, and slice-id. + bool remove_id = false; }; explicit HloControlFlowFlattening(const Options& options) : while_execution_count_(options.while_execution_count), @@ -57,7 +59,8 @@ class HloControlFlowFlattening : public HloModulePass { remove_infeed_outfeed_(options.remove_infeed_outfeed), flatten_while_loop_(options.flatten_while_loop), remove_host_transfer_(options.remove_host_transfer), - remove_comm_(options.remove_comm) {} + remove_comm_(options.remove_comm), + remove_id_(options.remove_id) {} ~HloControlFlowFlattening() override = default; absl::string_view name() const override { return "control-flow-flattening"; } using HloPassInterface::Run; @@ -102,6 +105,7 @@ class HloControlFlowFlattening : public HloModulePass { HloInstruction* recv_done, absl::flat_hash_set* additional_removed) const; bool remove_comm_; + bool remove_id_; }; // Retrieves the original loop bound. If fail, return a default value. If bounds diff --git a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc index a391a59ebdad34..40d16ca88574cf 100644 --- a/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc +++ b/third_party/xla/xla/tools/hlo_control_flow_flattening_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "xla/service/hlo_verifier.h" #include "xla/service/spmd/spmd_partitioner.h" #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" namespace xla { namespace { @@ -515,6 +515,37 @@ TEST_F(HloControlFlowFlatteningTest, ReplicaIdSucceedsWithChange) { "replica-id.18600"); } +TEST_F(HloControlFlowFlatteningTest, RemoveReplicaIdButKeepAllReduce) { + absl::string_view kHloText = R"( + HloModule RemoveReplicaIdButKeepCollective + +%sum (a: f32[], b: f32[]) -> f32[] { + %a = f32[] parameter(0) + %b = f32[] parameter(1) + ROOT %add = f32[] add(f32[] a, f32[] b) + } + ENTRY ReplicaId { + replica-id.1 = u32[]{:T(128)} replica-id() + ROOT all-reduce.1 = u32[]{:T(128)} all-reduce(replica-id.1), to_apply=sum, replica_groups={} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloText)); + HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{ + /*while_execution_count=*/1, /*max_outer_loop_count=*/1, + /*max_loop_count=*/1, /*remove_infeed_outfeed=*/false, + /*flatten_while_loop=*/false, /*remove_comm=*/false, + /*remove_host_transfer=*/false, /*remove_id=*/true}); + EXPECT_TRUE(flattening.Run(module.get()).value()); + TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true, + /*allow_mixed_precision=*/true) + .Run(module.get()) + .status()); + EXPECT_THAT(module->entry_computation()->root_instruction(), op::AllReduce()); + EXPECT_THAT(module->entry_computation()->root_instruction()->operand(0), + op::Constant()); +} + TEST_F(HloControlFlowFlatteningTest, CollectivePermuteInPlaceUpdate) { absl::string_view hlo_string = R"( HloModule CollectivePermuteInPlaceUpdate diff --git a/third_party/xla/xla/tools/hlo_module_loader_test.cc b/third_party/xla/xla/tools/hlo_module_loader_test.cc index e3916a0ec98ac9..16fbe45e4ae451 100644 --- a/third_party/xla/xla/tools/hlo_module_loader_test.cc +++ b/third_party/xla/xla/tools/hlo_module_loader_test.cc @@ -18,7 +18,7 @@ limitations under the License. #include #include "xla/tests/hlo_test_base.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/test.h" namespace xla { diff --git a/third_party/xla/xla/tools/hlo_opt/BUILD b/third_party/xla/xla/tools/hlo_opt/BUILD index c50fb9a25941f2..d9cefa0eddad8a 100644 --- a/third_party/xla/xla/tools/hlo_opt/BUILD +++ b/third_party/xla/xla/tools/hlo_opt/BUILD @@ -175,6 +175,7 @@ lit_test_suite( cfg = "//xla:lit.cfg.py", data = [":test_utilities"], default_tags = tf_cuda_tests_tags(), + hermetic_cuda_data_dir = "%S/../../../../cuda_nvcc", tags_override = { "gpu_hlo_ptx.hlo": ["no_rocm"], }, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD index 6987e41b923995..f485145a5f286c 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/BUILD +++ b/third_party/xla/xla/tools/multihost_hlo_runner/BUILD @@ -84,6 +84,7 @@ xla_cc_binary( testonly = True, tags = [ "gpu", + "no_rocm", "nomac", ] + tf_gpu_tests_tags(), deps = [ @@ -171,12 +172,9 @@ xla_test( name = "functional_hlo_runner_test", srcs = ["functional_hlo_runner_test.cc"], backend_tags = { - # This test is tagged "manual" because it requires multiple (2) GPUs. "gpu": [ - "manual", - "multi_gpu", + "multi_gpu_h100", "no_oss", - "notap", ], }, backends = [ @@ -198,6 +196,7 @@ xla_test( "//xla:xla_proto_cc", "//xla/pjrt:pjrt_client", "//xla/tests:filecheck", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/util:command_line_flags", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -206,7 +205,6 @@ xla_test( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", "@com_google_googletest//:gtest", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:path", diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc index 1a490a0802ef13..3aabf56650af23 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/create_client.cc @@ -57,43 +57,41 @@ static absl::StatusOr> GetPjRtClient( if (enable_mock_nccl) { CHECK_GT(num_nodes, 1); return CreateMockGpuClient(num_nodes); - } else { - if (num_nodes == 1) { - return CreateGpuClient({}); - } else { - TF_RET_CHECK(!address.empty()); - TF_RET_CHECK(node_id >= 0) - << "Node id is expected to be in range [0, num_nodes)"; - TF_RET_CHECK(node_id < num_nodes) - << "Node id is expected to be in range [0, num_nodes)"; - - CHECK_GT(address.length(), 0); - // Multinode. Start service on task 0. - if (node_id == 0) { - std::string coordinator_bind_address = - "[::]:" + std::string(address).substr(address.rfind(':') + 1); - xla::CoordinationServiceImpl::Options options; - options.num_nodes = num_nodes; - auto status_or = xla::GetDistributedRuntimeService( - coordinator_bind_address, options); - TF_QCHECK_OK(status_or.status()); - service = std::move(status_or.value()); - } - xla::DistributedRuntimeClient::Options options; - options.node_id = node_id; - options.init_timeout = init_timeout; - distributed_client = - GetDistributedRuntimeClient(std::string(address), options); - TF_QCHECK_OK(distributed_client->Connect()); - kv_store = GetDistributedKeyValueStore(distributed_client, - /*key_prefix=*/"gpu:"); - GpuClientOptions gpu_client_options; - gpu_client_options.node_id = node_id; - gpu_client_options.num_nodes = num_nodes; - gpu_client_options.kv_store = kv_store; - return CreateGpuClient(std::move(gpu_client_options)); - } } + + if (num_nodes == 1) { + return CreateGpuClient({}); + } + + TF_RET_CHECK(!address.empty()); + TF_RET_CHECK(node_id >= 0) + << "Node id is expected to be in range [0, num_nodes)"; + TF_RET_CHECK(node_id < num_nodes) + << "Node id is expected to be in range [0, num_nodes)"; + + CHECK_GT(address.length(), 0); + // Multinode. Start service on task 0. + if (node_id == 0) { + std::string coordinator_bind_address = + "[::]:" + std::string(address).substr(address.rfind(':') + 1); + xla::CoordinationServiceImpl::Options options; + options.num_nodes = num_nodes; + TF_ASSIGN_OR_RETURN(service, xla::GetDistributedRuntimeService( + coordinator_bind_address, options)); + } + xla::DistributedRuntimeClient::Options options; + options.node_id = node_id; + options.init_timeout = init_timeout; + distributed_client = + GetDistributedRuntimeClient(std::string(address), options); + TF_QCHECK_OK(distributed_client->Connect()); + kv_store = GetDistributedKeyValueStore(distributed_client, + /*key_prefix=*/"gpu:"); + GpuClientOptions gpu_client_options; + gpu_client_options.node_id = node_id; + gpu_client_options.num_nodes = num_nodes; + gpu_client_options.kv_store = kv_store; + return CreateGpuClient(std::move(gpu_client_options)); } absl::StatusOr GetPjRtClient(absl::string_view device_type, diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo b/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo index c745ee721e486a..c182a59714982b 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo +++ b/third_party/xla/xla/tools/multihost_hlo_runner/data/multiple_gemm_fusions.hlo @@ -1,35 +1,46 @@ f1 { - p0 = f16[720,720,720]{2,1,0} parameter(0) - p1 = s8[720,720,720]{2,1,0} parameter(1) - c = f16[720,720,720]{2,1,0} convert(p1) - ROOT d1 = f16[720,720,720]{2,1,0} dot(p0, c), + p0 = f16[64,64,64] parameter(0) + p1 = s8[64,64,64] parameter(1) + c = f16[64,64,64] convert(p1) + ROOT d1 = f32[64,64,64] dot(p0, c), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} } f2 { - p0 = s8[720,720,720]{2,1,0} parameter(0) - c0 = f32[720,720,720]{2,1,0} convert(p0) - p1 = f16[720,720,720]{2,1,0} parameter(1) - c1 = f32[720,720,720]{2,1,0} convert(p1) - ROOT %dot.1 = f32[720,720,720]{2,1,0} dot(c0, c1), + p0 = s8[64,64,64] parameter(0) + c0 = f32[64,64,64] convert(p0) + p1 = f16[64,64,64] parameter(1) + c1 = f32[64,64,64] convert(p1) + ROOT d2 = f32[64,64,64] dot(c0, c1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1} } +f3 { + p0 = f16[64,64,64] parameter(0) + p1 = f16[64,64,64] parameter(1) + ROOT d3 = f32[64,64,64] dot(p0, p1), + lhs_batch_dims={0}, lhs_contracting_dims={1}, + rhs_batch_dims={0}, rhs_contracting_dims={2} +} + fa { - p1 = f16[720,720,720]{2,1,0} parameter(1) - c = f32[720,720,720]{2,1,0} convert(p1) - p0 = f32[720,720,720]{2,1,0} parameter(0) - ROOT %add.1.1 = f32[720,720,720]{2,1,0} add(c, p0) + p0 = f32[64,64,64] parameter(0) + p1 = f32[64,64,64] parameter(1) + p2 = f32[64,64,64] parameter(2) + a1 = f32[64,64,64] add(p2, p1) + ROOT a = f32[64,64,64] add(p0, a1) } ENTRY e { - p1 = s8[720,720,720]{2,1,0} parameter(1) - p0 = f16[720,720,720]{2,1,0} parameter(0) - f1r = f16[720,720,720]{2,1,0} fusion(p0, p1), kind=kCustom, calls=f1, + p0 = f16[64,64,64] parameter(0) + p1 = s8[64,64,64] parameter(1) + f1r = f32[64,64,64] fusion(p0, p1), kind=kCustom, calls=f1, + backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} + f2r = f32[64,64,64] fusion(p1, p0), kind=kCustom, calls=f2, backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} - f2r = f32[720,720,720]{2,1,0} fusion(p1, p0), kind=kCustom, calls=f2, + f3r = f32[64,64,64] fusion(p0, p0), kind=kCustom, calls=f3, backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}} - ROOT _ = f32[720,720,720]{2,1,0} fusion(f2r, f1r), kind=kLoop, calls=fa + ROOT _ = f32[64,64,64] fusion(f1r, f2r, f3r), kind=kLoop, calls=fa } diff --git a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc index 0d043ad757c1d0..a7f986c8fc3d66 100644 --- a/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc +++ b/third_party/xla/xla/tools/multihost_hlo_runner/functional_hlo_runner_test.cc @@ -30,9 +30,9 @@ limitations under the License. #include "xla/pjrt/pjrt_client.h" #include "xla/tests/filecheck.h" #include "xla/tools/multihost_hlo_runner/create_client.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/command_line_flags.h" #include "xla/xla.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/file_system.h" @@ -256,7 +256,7 @@ TEST_F(FunctionalHloRunnerTest, CanCompileWithoutHavingEnoughGpus) { // Name of the test binary. static const char* binary_name; -constexpr int kNumNodes = 3; +constexpr int kNumNodes = 2; TEST_F(FunctionalHloRunnerTest, ShardedAutotuningWorks) { if (IsTestingCpu()) { @@ -308,13 +308,8 @@ absl::Status ShardedAutotuningWorksTestBody(const int node_id) { env.kv_store->Get("gemm_fusion_autotuning_results_1_1", absl::Seconds(1))); CHECK(absl::StrContains(results1, "run_time")); - // First two nodes autotune two different fusions. + // The nodes autotune different fusions. CHECK_NE(results0, results1); - TF_ASSIGN_OR_RETURN(std::string results2, - env.kv_store->Get("gemm_fusion_autotuning_results_1_2", - absl::Seconds(1))); - // Third node has nothing to autotune. - CHECK(!absl::StrContains(results2, "run_time")); } return absl::OkStatus(); } @@ -360,9 +355,9 @@ int main(int argc, char* argv[]) { xla::AppendDebugOptionsFlags(&flag_list); std::string usage = tsl::Flags::Usage(argv[0], flag_list); tsl::Flags::Parse(&argc, argv, flag_list); + testing::InitGoogleTest(&argc, argv); if (node_id >= 0) { return !xla::ShardedAutotuningWorksTestBody(node_id).ok(); } - testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/third_party/xla/xla/tools/prepare_reference_module.cc b/third_party/xla/xla/tools/prepare_reference_module.cc index 4ce766d12e65dd..82fd57a8f183f5 100644 --- a/third_party/xla/xla/tools/prepare_reference_module.cc +++ b/third_party/xla/xla/tools/prepare_reference_module.cc @@ -34,7 +34,8 @@ absl::StatusOr> PrepareReferenceModule( const HloModule& test_module, HloRunnerInterface* test_runner, const std::function& config_modifier_hook, const std::function& module_modifier_hook) { + HloModule*)>& module_modifier_hook, + bool skip_despecialization) { DebugOptions debug_options = GetDebugOptionsFromFlags(); // The combination of fast math and optimizations leads to unsound code // transformations (see third_party/tensorflow/compiler/xla/xla.proto for @@ -51,7 +52,7 @@ absl::StatusOr> PrepareReferenceModule( if (module_modifier_hook) { TF_RETURN_IF_ERROR( module_modifier_hook(test_module, test_runner, reference_module.get())); - } else { + } else if (!skip_despecialization) { TF_RETURN_IF_ERROR(Despecializer().Run(reference_module.get()).status()); } return std::move(reference_module); diff --git a/third_party/xla/xla/tools/prepare_reference_module.h b/third_party/xla/xla/tools/prepare_reference_module.h index 4a1064d8c9e3c0..f26e84745b40d3 100644 --- a/third_party/xla/xla/tools/prepare_reference_module.h +++ b/third_party/xla/xla/tools/prepare_reference_module.h @@ -37,7 +37,8 @@ absl::StatusOr> PrepareReferenceModule( const HloModule& test_module, HloRunnerInterface* test_runner, const std::function& config_modifier_hook = {}, const std::function& module_modifier_hook = {}); + HloModule*)>& module_modifier_hook = {}, + bool skip_despecialization = false); } // namespace xla diff --git a/third_party/xla/xla/tools/prepare_reference_module_test.cc b/third_party/xla/xla/tools/prepare_reference_module_test.cc new file mode 100644 index 00000000000000..0b2ad0e4c1b6b2 --- /dev/null +++ b/third_party/xla/xla/tools/prepare_reference_module_test.cc @@ -0,0 +1,76 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tools/prepare_reference_module.h" + +#include "xla/hlo/ir/hlo_module.h" +#include "xla/test.h" +#include "xla/tests/hlo_test_base.h" +#include "tsl/platform/statusor.h" + +namespace xla { +namespace { + +const char* const kModuleStr = R"( + HloModule jit_step + + %fused_computation (param_0.2: f32[1,4]) -> f32[1,3] { + %param_0.2 = f32[1,4]{1,0} parameter(0) + ROOT %slice.11 = f32[1,3]{1,0} slice(f32[1,4]{1,0} %param_0.2), + slice={[0:1], [0:3]} + } + + ENTRY %main.3491 (Arg_0.0: f32[1,4]) -> f32[1,3] { + %Arg_0.0 = f32[1,4]{1,0} parameter(0) + ROOT %fusion = f32[1,3]{1,0} fusion(f32[1,4]{1,0} %Arg_0.0), kind=kLoop, + calls=%fused_computation + } +)"; + +using PrepareReferenceModuleTest = HloTestBase; + +// Ideally 'Despecializer' pass should be mocked. Because it is not feasible +// with the current design, despecialization tests in this file are based on +// Despecializer's implementation (Despecializer removes fusion op from the +// module). +TEST_F(PrepareReferenceModuleTest, PerformDespecialization) { + TF_ASSERT_OK_AND_ASSIGN(auto test_module, + ParseAndReturnVerifiedModule(kModuleStr)); + + TF_ASSERT_OK_AND_ASSIGN( + auto reference_module, + PrepareReferenceModule(*test_module, nullptr, {}, {}, + /*skip_despecialization=*/false)); + + // Fusion op should have been removed. + EXPECT_THAT(reference_module->ToString(), + Not(::testing::HasSubstr("fusion"))); +} + +TEST_F(PrepareReferenceModuleTest, SkipDespecialization) { + TF_ASSERT_OK_AND_ASSIGN(auto test_module, + ParseAndReturnVerifiedModule(kModuleStr)); + + TF_ASSERT_OK_AND_ASSIGN( + auto reference_module, + PrepareReferenceModule(*test_module, nullptr, {}, {}, + /*skip_despecialization=*/true)); + + // Fusion op should be there. + EXPECT_THAT(reference_module->ToString(), ::testing::HasSubstr("fusion")); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/tools/run_hlo_module.cc b/third_party/xla/xla/tools/run_hlo_module.cc index 690b50972808d8..22c0c02cafde9b 100644 --- a/third_party/xla/xla/tools/run_hlo_module.cc +++ b/third_party/xla/xla/tools/run_hlo_module.cc @@ -168,6 +168,12 @@ absl::StatusOr ExecuteWithRunner( return std::move(result_status).value(); } +void UseCpuThunkRuntime(HloModule& module) { + auto debug_options = module.config().debug_options(); + debug_options.set_xla_cpu_use_thunk_runtime(true); + module.mutable_config().set_debug_options(debug_options); +} + absl::Status RunAndCompareInternal( std::unique_ptr test_module, const BufferAssignmentProto* buffer_assignment_proto, @@ -255,17 +261,27 @@ absl::Status RunAndCompareInternal( std::unique_ptr reference_module; if (reference_runner != nullptr) { + // If reference platform is the same as test platform, we shouldn't + // deoptimize the reference module. + bool skip_deoptimization = options.reference_platform == options.platform; + // PrepareReferenceModule needs to know the *test* runner, in order to // properly match the test runner's numerics. TF_ASSIGN_OR_RETURN( reference_module, copy_result_on_failure( - PrepareReferenceModule(*test_module, test_runner, - config_modifier_hook, - reference_module_modifier_hook), + PrepareReferenceModule( + *test_module, test_runner, config_modifier_hook, + reference_module_modifier_hook, skip_deoptimization), ModuleResult::kCompilationError, reference_run_result)); } + // Now when reference_module is ready, we can modify test_module without + // impacting the reference run. + if (options.force_use_cpu_thunk_runtime_for_test) { + UseCpuThunkRuntime(*test_module); + } + TF_ASSIGN_OR_RETURN( auto test_result, copy_result_on_failure( diff --git a/third_party/xla/xla/tools/run_hlo_module.h b/third_party/xla/xla/tools/run_hlo_module.h index 3300f1b1b8671d..66afdc551f4981 100644 --- a/third_party/xla/xla/tools/run_hlo_module.h +++ b/third_party/xla/xla/tools/run_hlo_module.h @@ -40,6 +40,7 @@ struct RunHloModuleOptions { bool flatten_control_flow{false}; bool run_test_hlo_passes{true}; bool run_reference_hlo_passes{true}; + bool force_use_cpu_thunk_runtime_for_test{false}; // Using small float range by default, as otherwise all reductions // miscompare vs. the interpreter with inf/nan. bool use_large_float_range{false}; diff --git a/third_party/xla/xla/tools/run_hlo_module_bin_test.cc b/third_party/xla/xla/tools/run_hlo_module_bin_test.cc index bb4d6b32cdcf34..9b6138f45e7246 100644 --- a/third_party/xla/xla/tools/run_hlo_module_bin_test.cc +++ b/third_party/xla/xla/tools/run_hlo_module_bin_test.cc @@ -14,15 +14,17 @@ limitations under the License. ==============================================================================*/ #include +#include #include #include #include "absl/log/log.h" +#include "absl/strings/str_cat.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/hlo_parser.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/path.h" #include "tsl/platform/statusor.h" @@ -32,22 +34,41 @@ limitations under the License. namespace xla { namespace { +std::vector make_args( + const std::string& run_hlo_module_bin, const std::string& file_name, + const std::vector& extra_args = {}, + std::optional input_literals_file = std::nullopt) { + std::string hlo_path = file_name[0] == '/' + ? file_name + : tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), + "tools", "data", file_name); + + std::vector args = {run_hlo_module_bin, hlo_path, + "--platform=Host"}; + + args.insert(args.end(), extra_args.begin(), extra_args.end()); + + if (input_literals_file.has_value()) { + std::string input_path = + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "tools", "data", + input_literals_file.value()); + args.push_back(absl::StrCat("--input_literals_file=", input_path)); + } + + return args; +} + class RunHloModuleTest : public ::testing::Test { protected: void RunHlo(const std::string& file_name, - const std::vector& extra_args = {}) { + const std::vector& extra_args = {}, + std::optional input_literals_file = std::nullopt) { std::string run_hlo_module_bin = tsl::io::JoinPath( tsl::testing::XlaSrcRoot(), "tools", "run_hlo_module"); - std::string hlo_path = file_name[0] == '/' - ? file_name - : tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), - "tools", "data", file_name); - tsl::SubProcess proc; - std::vector args = {run_hlo_module_bin, hlo_path, - "--platform=Host"}; - args.insert(args.end(), extra_args.begin(), extra_args.end()); + auto args = make_args(run_hlo_module_bin, file_name, extra_args, + input_literals_file); proc.SetProgram(run_hlo_module_bin, args); proc.SetChannelAction(tsl::CHAN_STDOUT, tsl::ACTION_PIPE); proc.SetChannelAction(tsl::CHAN_STDERR, tsl::ACTION_PIPE); @@ -131,6 +152,22 @@ TEST_F(RunHloModuleTest, MustAliasWithSharding) { testing::Not(testing::HasSubstr("memory allocation bug"))); } +TEST_F(RunHloModuleTest, ReadInputLiteralsFromFile) { + RunHlo("add.hlo", + /*extra_args=*/{"--print_literals=true", "--reference_platform="}, + /*input_literals_file=*/"input_literal_f32_2_2.pbtxt"); + + EXPECT_TRUE(exited_normally_); + EXPECT_EQ(exit_status_, 0); + + ASSERT_THAT( + stdout_output_, + testing::HasSubstr("{ 0.1, 0.2 },")); // First two values of the input + ASSERT_THAT( + stdout_output_, + testing::HasSubstr("{ 0.2, 0.4 },")); // First two values of the result +} + TEST_F(RunHloModuleTest, AddSnapshot) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnUnverifiedModule(R"( diff --git a/third_party/xla/xla/tools/run_hlo_module_main.cc b/third_party/xla/xla/tools/run_hlo_module_main.cc index 0b3ff5dbda56d8..92e2e23efad6e0 100644 --- a/third_party/xla/xla/tools/run_hlo_module_main.cc +++ b/third_party/xla/xla/tools/run_hlo_module_main.cc @@ -122,7 +122,13 @@ int main(int argc, char** argv) { "other " "than the reference this is necessary because some HLO passes are " "legalization passes which must be run prior to code generation."), - + tsl::Flag( + "force_use_cpu_thunk_runtime_for_test", + &opts.force_use_cpu_thunk_runtime_for_test, + "Use thunk runtime for the test platform. If true, thunks runtime " + "will be used for the test run regardless of the " + "xla_cpu_use_thunk_runtime flag in XLA_FLAGS. This option doesn't " + "impact reference run. It is ignored for platforms other than CPU."), tsl::Flag("random_init_input_literals", &opts.random_init_input_literals, "Initialize input literals with random numbers." "Leave them uninitialized otherwise."), @@ -252,9 +258,9 @@ int main(int argc, char** argv) { &input_literals_proto); } - for (int i = 1; i <= iteration_count; ++i) { + for (int i = 0; i < iteration_count; ++i) { if (iteration_count != 1) { - std::cerr << "\n=== Iteration " << i << "\n"; + std::cerr << "\n=== Iteration " << i + 1 << "\n"; } xla::RunHloModuleIterationLiterals* iteration_literals_proto = nullptr; if (!opts.output_literals_file.empty() || @@ -276,7 +282,7 @@ int main(int argc, char** argv) { opts, iteration_literals_proto, /*reference_module_modifier_hook=*/{}, [&](xla::HloModuleConfig* config) { - config->set_seed(different_random_seeds ? i : 42); + config->set_seed(different_random_seeds ? i + 1 : 42); }); if (result.ok()) { diff --git a/third_party/xla/xla/tools/run_hlo_module_test.cc b/third_party/xla/xla/tools/run_hlo_module_test.cc index 2ac2d2b6f12074..255563a5893657 100644 --- a/third_party/xla/xla/tools/run_hlo_module_test.cc +++ b/third_party/xla/xla/tools/run_hlo_module_test.cc @@ -20,8 +20,8 @@ limitations under the License. #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/tools/run_hlo_module.pb.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/xla_data.pb.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tools/xla_compile_lib.cc b/third_party/xla/xla/tools/xla_compile_lib.cc index ce195d1a925bbd..16d4d0d65e29d3 100644 --- a/third_party/xla/xla/tools/xla_compile_lib.cc +++ b/third_party/xla/xla/tools/xla_compile_lib.cc @@ -68,7 +68,7 @@ limitations under the License. #include "tsl/platform/statusor.h" #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM -#include "xla/service/gpu/autotuner_util.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" #include "xla/service/gpu/executable.pb.h" #include "xla/service/gpu/gpu_symbol_repository.h" #include "xla/stream_executor/gpu/gpu_init.h" @@ -232,37 +232,43 @@ ReadModuleFromSymbolRepo(absl::string_view symbol_repo, return mod; } -static absl::StatusOr LoadAutotuneDataFromModule( +static std::unique_ptr ReadTargetConfigFromModule( HloModuleAndMetadata* mod, BackendType backend) { if (backend == BackendType::kGpu) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (auto* data = static_cast( mod->backend_specific_data.get()); - data != nullptr && data->autotune_results.has_value()) { - TF_RETURN_IF_ERROR( - gpu::AutotunerUtil::LoadAutotuneResults(*data->autotune_results)); - return true; + data != nullptr) { + return std::move(mod->target_config); } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } - return false; + + return nullptr; } -static std::unique_ptr ReadTargetConfigFromModule( - HloModuleAndMetadata* mod, BackendType backend) { +namespace internal { + +absl::StatusOr LoadAutotuneDataFromModule(HloModuleAndMetadata* mod, + BackendType backend) { if (backend == BackendType::kGpu) { #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM if (auto* data = static_cast( mod->backend_specific_data.get()); - data != nullptr) { - return std::move(mod->target_config); + data != nullptr && data->autotune_results.has_value() && + mod->hlo_module->config().debug_options().xla_gpu_autotune_level() > + 0) { + TF_RETURN_IF_ERROR( + gpu::AutotunerUtil::LoadAutotuneResults(*data->autotune_results)); + return true; } #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } - - return nullptr; + return false; } +} // namespace internal + absl::Status XlaCompileMain(const XlaCompileOptions& options) { std::unique_ptr hlo_module; std::unique_ptr target_config; @@ -299,7 +305,7 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { ReadModuleFromSymbolRepo(symbol_repo, optimized_symbol_id, backend)); #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM - TF_ASSIGN_OR_RETURN(found_autotune, LoadAutotuneDataFromModule( + TF_ASSIGN_OR_RETURN(found_autotune, internal::LoadAutotuneDataFromModule( optimized_mod.get(), backend)); #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM } @@ -340,7 +346,8 @@ absl::Status XlaCompileMain(const XlaCompileOptions& options) { if (absl::string_view autotune_results_path = options.gpu_options.autotune_results_path; - !found_autotune && !autotune_results_path.empty()) { + !found_autotune && !autotune_results_path.empty() && + hlo_module->config().debug_options().xla_gpu_autotune_level() > 0) { TF_RETURN_IF_ERROR(gpu::AutotunerUtil::LoadAutotuneResultsFromFile( autotune_results_path)); } diff --git a/third_party/xla/xla/tools/xla_compile_lib.h b/third_party/xla/xla/tools/xla_compile_lib.h index 8d4f9e0dae8e01..3892f156c913be 100644 --- a/third_party/xla/xla/tools/xla_compile_lib.h +++ b/third_party/xla/xla/tools/xla_compile_lib.h @@ -84,6 +84,15 @@ struct XlaCompileOptions { // correspond to fields in XlaCompileOptions. absl::Status XlaCompileMain(const XlaCompileOptions& compile_options); +namespace internal { + +// Loads autotuning data if autotuning is enabled and autotuning results are +// present. Returns true if data was present and successfully loaded, false +// otherwise. +absl::StatusOr LoadAutotuneDataFromModule(HloModuleAndMetadata* mod, + BackendType backend); + +} // namespace internal } // namespace xla #endif // XLA_TOOLS_XLA_COMPILE_LIB_H_ diff --git a/third_party/xla/xla/tools/xla_compile_lib_test.cc b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc similarity index 64% rename from third_party/xla/xla/tools/xla_compile_lib_test.cc rename to third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc index 6bf9051f221c83..62c06734ddb990 100644 --- a/third_party/xla/xla/tools/xla_compile_lib_test.cc +++ b/third_party/xla/xla/tools/xla_cpu_compile_lib_test.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "xla/tools/xla_compile_lib.h" - #include #include #include @@ -23,19 +21,18 @@ limitations under the License. #include "google/protobuf/duration.pb.h" #include #include -#include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/platform_util.h" #include "xla/service/symbol_repository.h" #include "xla/service/xla_compile_result.pb.h" -#include "xla/stream_executor/device_description.pb.h" #include "xla/tests/hlo_test_base.h" -#include "xla/tests/test_macros.h" +#include "xla/tools/xla_compile_lib.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/util.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/env_time.h" +#include "tsl/platform/errors.h" #include "tsl/platform/path.h" #include "tsl/platform/status_matchers.h" #include "tsl/platform/statusor.h" @@ -53,21 +50,10 @@ using ::tsl::testing::IsOk; using ::tsl::testing::IsOkAndHolds; using ::tsl::testing::StatusIs; -#if XLA_TEST_BACKEND_CPU -static constexpr absl::string_view kPlatformName = "Host"; -#elif XLA_TEST_BACKEND_GPU -static constexpr absl::string_view kPlatformName = -#if TENSORFLOW_USE_ROCM - "ROCM"; -#else - "CUDA"; -#endif -#endif // XLA_TEST_BACKEND_CPU - class XlaCompileLibTest : public HloTestBase { protected: XlaCompileLibTest() - : HloTestBase(*PlatformUtil::GetPlatform(std::string(kPlatformName)), + : HloTestBase(*PlatformUtil::GetPlatform("Host"), GetReferencePlatform()) {} void SetUp() override { const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), @@ -80,48 +66,26 @@ class XlaCompileLibTest : public HloTestBase { std::unique_ptr module_; }; -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(CompilesForCpu)) { +TEST_F(XlaCompileLibTest, CompilesForCpu) { CompilationResult result; EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kCpu, std::nullopt, result), IsOkAndHolds(Not(IsEmpty()))); } -TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithDevice)) { - CompilationResult result; - EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu, - std::nullopt, result), - IsOkAndHolds(Not(IsEmpty()))); - EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); -} - -TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(CompilesForGpuWithoutDevice)) { - const std::string target_config_path = - tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", - "xla_aot_compile_test_gpu_target_config.prototxt"); - stream_executor::GpuTargetConfigProto target_config; - TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), target_config_path, - &target_config)); - CompilationResult result; - EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu, - std::nullopt, result), - IsOkAndHolds(Not(IsEmpty()))); - EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); -} - -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(ErrorsOnUnexpectedPlatform)) { +TEST_F(XlaCompileLibTest, ErrorsOnUnexpectedPlatform) { XlaCompileOptions options; options.platform = "tpu"; EXPECT_THAT(XlaCompileMain(options), StatusIs(tsl::error::UNIMPLEMENTED)); } -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFilePropagatesErrors)) { +TEST_F(XlaCompileLibTest, WriteResultFilePropagatesErrors) { TimerStats stats; CompilationResult result; EXPECT_THAT(WriteResultFile("/does/not/exist", stats, result), Not(IsOk())); } -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(WriteResultFileWritesTheFile)) { +TEST_F(XlaCompileLibTest, WriteResultFileWritesTheFile) { std::string result_output_file; ASSERT_TRUE(tsl::Env::Default()->LocalTempFilename(&result_output_file)); @@ -167,7 +131,7 @@ TEST_F(XlaCompileLibTest, LoadModuleLoadsTextFormat) { EXPECT_THAT(LoadModule(module_file), IsOkAndHolds(Not(IsNull()))); } -TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(MainForCpu)) { +TEST_F(XlaCompileLibTest, MainForCpu) { const std::string module_file = tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, @@ -191,29 +155,12 @@ TEST_F(XlaCompileLibTest, DISABLED_ON_GPU(MainForCpu)) { EXPECT_EQ(result.status().code(), tensorflow::error::OK); } -TEST_F(XlaCompileLibTest, DISABLED_ON_CPU(MainForGpu)) { - const std::string module_file = - tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); - TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, - module_->ToString())); - - const std::string output_path = - tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_output"); - const std::string result_file = - tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_result.pb"); +TEST_F(XlaCompileLibTest, LoadAutotuneDataCpu) { + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); - XlaCompileOptions options; - options.module_path = module_file; - options.output_path = output_path; - options.platform = "gpu"; - options.result_output_file = result_file; - options.gpu_options.use_attached_device = true; - TF_EXPECT_OK(XlaCompileMain(options)); - - CompilationResult result; - TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result)); - EXPECT_TRUE(result.has_status()); - EXPECT_EQ(result.status().code(), tensorflow::error::OK); + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kCpu), + IsOkAndHolds(false)); } } // namespace diff --git a/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc new file mode 100644 index 00000000000000..bc34c8790fb14e --- /dev/null +++ b/third_party/xla/xla/tools/xla_gpu_compile_lib_test.cc @@ -0,0 +1,195 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include +#include +#include "xla/hlo/ir/hlo_module.h" +#include "xla/service/gpu/autotuning/autotuner_util.h" +#include "xla/service/gpu/gpu_symbol_repository.h" +#include "xla/service/platform_util.h" +#include "xla/service/symbol_repository.h" +#include "xla/service/xla_compile_result.pb.h" +#include "xla/stream_executor/device_description.pb.h" +#include "xla/tests/hlo_test_base.h" +#include "xla/tools/xla_compile_lib.h" +#include "xla/tsl/lib/core/status_test_util.h" +#include "xla/util.h" +#include "tsl/platform/env.h" +#include "tsl/platform/path.h" +#include "tsl/platform/status_matchers.h" +#include "tsl/platform/statusor.h" +#include "tsl/platform/test.h" +#include "tsl/protobuf/error_codes.pb.h" +#include "tsl/protobuf/status.pb.h" + +namespace xla { +namespace { + +using ::testing::IsEmpty; +using ::testing::Not; +using ::tsl::testing::IsOkAndHolds; + +class XlaCompileLibTest : public HloTestBase { + protected: + XlaCompileLibTest() + : HloTestBase(*PlatformUtil::GetPlatform(std::string("GPU")), + GetReferencePlatform()) {} + void SetUp() override { + const std::string hlo_path = tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), + "tools", "data", "add.hlo"); + std::string hlo; + TF_ASSERT_OK(tsl::ReadFileToString(tsl::Env::Default(), hlo_path, &hlo)); + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo)); + } + + std::unique_ptr module_; +}; + +TEST_F(XlaCompileLibTest, CompilesForGpuWithDevice) { + CompilationResult result; + EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu, + std::nullopt, result), + IsOkAndHolds(Not(IsEmpty()))); + EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); +} + +TEST_F(XlaCompileLibTest, CompilesForGpuWithoutDevice) { + const std::string target_config_path = + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", + "xla_aot_compile_test_gpu_target_config.prototxt"); + stream_executor::GpuTargetConfigProto target_config; + TF_ASSERT_OK(tsl::ReadTextProto(tsl::Env::Default(), target_config_path, + &target_config)); + CompilationResult result; + EXPECT_THAT(CompileExecutable(std::move(module_), BackendType::kGpu, + std::nullopt, result), + IsOkAndHolds(Not(IsEmpty()))); + EXPECT_TRUE(result.has_hlo_module()) << result.DebugString(); +} + +TEST_F(XlaCompileLibTest, MainForGpu) { + const std::string module_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "module.txt"); + TF_ASSERT_OK(tsl::WriteStringToFile(tsl::Env::Default(), module_file, + module_->ToString())); + + const std::string output_path = + tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_output"); + const std::string result_file = + tsl::io::JoinPath(tsl::testing::TmpDir(), "gpu_result.pb"); + + XlaCompileOptions options; + options.module_path = module_file; + options.output_path = output_path; + options.platform = "gpu"; + options.result_output_file = result_file; + options.gpu_options.use_attached_device = true; + TF_EXPECT_OK(XlaCompileMain(options)); + + CompilationResult result; + TF_ASSERT_OK(tsl::ReadBinaryProto(tsl::Env::Default(), result_file, &result)); + EXPECT_TRUE(result.has_status()); + EXPECT_EQ(result.status().code(), tensorflow::error::OK); +} + +TEST_F(XlaCompileLibTest, LoadAutotuneDataGpuDataPresentAndAutotuningEnabled) { + gpu::AutotunerUtil::ClearAutotuneResults(); + + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + auto data = std::make_unique(); + + AutotuneResults autotune_results; + TF_ASSERT_OK(tsl::ReadTextProto( + tsl::Env::Default(), + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", + "gpu_compiler_test_autotune_db.textproto"), + &autotune_results)); + data->autotune_results = autotune_results; + mod.backend_specific_data = std::move(data); + + DebugOptions opts = mod.hlo_module->config().debug_options(); + opts.set_xla_gpu_autotune_level(3); + mod.hlo_module->mutable_config().set_debug_options(opts); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), + IsOkAndHolds(true)); + EXPECT_FALSE(gpu::AutotunerUtil::ResultCacheIsEmpty()); +} + +TEST_F(XlaCompileLibTest, LoadAutotuneDataGpuDataPresentAndAutotuningDisabled) { + gpu::AutotunerUtil::ClearAutotuneResults(); + + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + auto data = std::make_unique(); + + AutotuneResults autotune_results; + TF_ASSERT_OK(tsl::ReadTextProto( + tsl::Env::Default(), + tsl::io::JoinPath(tsl::testing::XlaSrcRoot(), "service", "gpu", + "gpu_compiler_test_autotune_db.textproto"), + &autotune_results)); + data->autotune_results = autotune_results; + mod.backend_specific_data = std::move(data); + + DebugOptions opts = mod.hlo_module->config().debug_options(); + opts.set_xla_gpu_autotune_level(0); + mod.hlo_module->mutable_config().set_debug_options(opts); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), + IsOkAndHolds(false)); + EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); +} + +TEST_F(XlaCompileLibTest, + LoadAutotuneDataGpuDataNotPresentAndAutotuningEnabled) { + gpu::AutotunerUtil::ClearAutotuneResults(); + + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + + DebugOptions opts = mod.hlo_module->config().debug_options(); + opts.set_xla_gpu_autotune_level(3); + mod.hlo_module->mutable_config().set_debug_options(opts); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), + IsOkAndHolds(false)); + EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); +} + +TEST_F(XlaCompileLibTest, + LoadAutotuneDataGpuDataNotPresentAndAutotuningDisabled) { + gpu::AutotunerUtil::ClearAutotuneResults(); + + HloModuleAndMetadata mod; + mod.hlo_module = std::move(module_); + + DebugOptions opts = mod.hlo_module->config().debug_options(); + opts.set_xla_gpu_autotune_level(0); + mod.hlo_module->mutable_config().set_debug_options(opts); + + EXPECT_THAT(internal::LoadAutotuneDataFromModule(&mod, BackendType::kGpu), + IsOkAndHolds(false)); + EXPECT_TRUE(gpu::AutotunerUtil::ResultCacheIsEmpty()); +} + +} // namespace +} // namespace xla diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD index 1257db1ceef463..179ae72588080a 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/BUILD +++ b/third_party/xla/xla/translate/hlo_to_mhlo/BUILD @@ -21,8 +21,30 @@ cc_library( "//xla:xla_data_proto_cc", "//xla/mlir_hlo", "//xla/service:hlo_proto_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "async_importer", + srcs = ["async_importer.cc"], + hdrs = ["async_importer.h"], + deps = [ + ":attribute_importer", + ":hlo_utils", + "//xla:util", + "//xla/hlo/ir:hlo", + "//xla/mlir_hlo", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", ], ) @@ -67,6 +89,7 @@ cc_library( "hlo_module_importer.h", ], deps = [ + ":async_importer", ":attribute_importer", ":custom_call_importer", ":hlo_utils", @@ -83,10 +106,9 @@ cc_library( "//xla:xla_proto_cc", "//xla/hlo/ir:hlo", "//xla/mlir_hlo", - "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:optional", @@ -113,7 +135,11 @@ cc_library( ":hlo_module_importer", "//xla:status_macros", "//xla/mlir/utils:error_util", + "//xla/service/llvm_ir:llvm_util", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:errors", ], ) @@ -128,14 +154,14 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", "//xla/mlir/utils:type_util", "//xla/mlir_hlo", - "//xla/mlir_hlo:convert_op_folder", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:SparseTensorEnums", + "@local_tsl//tsl/platform:statusor", ], ) @@ -149,9 +175,9 @@ xla_cc_test( "//xla:shape_util", "//xla:test", "//xla:types", + "//xla/tsl/lib/core:status_test_util", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc new file mode 100644 index 00000000000000..57bc78a0ead971 --- /dev/null +++ b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.cc @@ -0,0 +1,383 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/translate/hlo_to_mhlo/async_importer.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/translate/hlo_to_mhlo/attribute_importer.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include "xla/util.h" +#include "tsl/platform/errors.h" + +namespace xla { + +namespace { + +constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes"; +constexpr char kShardingAttr[] = "mhlo.sharding"; + +// ============ +// Imports an old-style async start op. E.g. an HLO all-gather-start +// instruction is imported as an async-start associated with an all-gather +// computation. +// +// Eventually, old-style async ops (e.g. all-gather-start) and new-style async +// ops (i.e. async-start, async-update and async-done) will converge on the +// HLO side, so we decided to not introduce new MHLO ops for all-gather-start +// and friends. +// +// In the end, there may be new ops added in the old-style because they're not +// compatible with the new-style async semantics, but those should be handled +// on their own, rather than this function which "upgrades" ops to the +// new-style async API. +// ============ +template +absl::StatusOr ImportOldStyleAsyncStart( + mlir::SymbolTable& symbol_table, + llvm::SmallVectorImpl& attributes, + const llvm::SmallVectorImpl& operands, mlir::Location loc, + mlir::Type result_type, mlir::OpBuilder* builder, std::string func_name, + std::function mutate_op) { + auto context = builder->getContext(); + if (!llvm::isa(result_type)) { + return tsl::errors::InvalidArgument( + "expected async_bundle tuple result type"); + } + auto result_types = result_type.cast().getTypes(); + if (result_types.size() < 2) { + return tsl::errors::InvalidArgument( + "async_bundle must contain at least two values"); + } + auto func_type = mlir::FunctionType::get(context, Untuple(result_types[0]), + Untuple(result_types[1])); + auto function = mlir::func::FuncOp::create(loc, func_name, func_type); + + // The new function doesn't need to be inserted in the beginning but is done + // to make testing easier and preserve the original behavior. + mlir::Block& block = symbol_table.getOp()->getRegion(0).front(); + symbol_table.insert(function, mlir::Block::iterator(block.begin())); + + function.setPrivate(); + auto async_builder = mlir::OpBuilder(function.getBody()); + + llvm::SmallVector async_attributes; + async_attributes.push_back(builder->getNamedAttr( + "called_computation", + mlir::FlatSymbolRefAttr::get(builder->getContext(), function.getName()))); + async_attributes.push_back(builder->getNamedAttr( + "execution_thread", builder->getStringAttr("main"))); + + // Attach the frontend_attributes and sharding attributes to the async op + // instead of the sync op. First, semantically sharding attributes cannot be + // attached to the sync op since the sync op may not produce the same number + // of results as the sharding's tuple element count, e.g., `mhlo.send` vs. HLO + // `send`. Second, `mlir_hlo_to_hlo.cc` imports these attributes from the + // `mhlo.async_start` ops, so attaching them to the sync op will make them + // disappear during MHLO to HLO lowering. + for (auto it = attributes.begin(); it != attributes.end();) { + if (it->getName() == kShardingAttr || + it->getName() == kFrontendAttributesAttr) { + async_attributes.push_back(*it); + it = attributes.erase(it); + } else { + ++it; + } + } + + llvm::SmallVector locs(Untuple(result_types[0]).size(), + loc); + auto sync_operand = + async_builder + .createBlock(&function.getBody(), {}, Untuple(result_types[0]), locs) + ->getArguments(); + auto sync_operation = async_builder.create( + loc, Untuple(result_types[1]), sync_operand, attributes); + async_builder.create(loc, sync_operation->getResults()); + TF_RETURN_IF_ERROR(mutate_op(sync_operation)); + + function->setAttr("execution_thread", builder->getStringAttr("main")); + + auto bundle_result_type = + mlir::mhlo::AsyncBundleType::get(context, result_types); + return builder + ->create(loc, bundle_result_type, operands, + async_attributes) + .getOperation(); +} + +absl::StatusOr ImportOldStyleAsyncDone( + llvm::SmallVectorImpl& attributes, + const llvm::SmallVectorImpl& operands, mlir::Location loc, + mlir::Type result_type, mlir::OpBuilder* builder, + bool useBundleResult = false) { + assert(operands.size() == 1 && + "*-done ops must take only a single async_bundle operand"); + auto async_start = operands[0].getDefiningOp(); + if (!async_start) return InvalidArgument("*-start requires *-done as input"); + attributes.push_back(builder->getNamedAttr( + "called_computation", + mlir::FlatSymbolRefAttr::get(builder->getContext(), + async_start.getCalledComputation()))); + attributes.push_back(builder->getNamedAttr("execution_thread", + builder->getStringAttr("main"))); + + auto async_bundle = llvm::cast( + async_start.getResult().getType()); + + auto start_tuple = + llvm::dyn_cast(async_bundle.getTypes()[1]); + if (start_tuple && llvm::isa(start_tuple.getType(0))) { + auto op = builder->create(loc, result_type, + operands, attributes); + return {op}; + } else { + if (useBundleResult) result_type = async_bundle.getTypes()[1]; + auto op = builder->create( + loc, Untuple(result_type), operands, attributes); + return CreateTupleFromOpResults(builder, loc, op.getOperation(), + result_type); + } +} + +} // namespace + +// Op Converters + +absl::StatusOr ImportSend( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table) { + auto send_op = Cast(instruction); + attributes.push_back(builder->getNamedAttr( + "is_host_transfer", builder->getBoolAttr(send_op->is_host_transfer()))); + if (send_op->channel_id().has_value()) { + ChannelHandle channel_handle; + channel_handle.set_handle(send_op->channel_id().value()); + channel_handle.set_type(send_op->is_host_transfer() + ? ChannelHandle::DEVICE_TO_HOST + : ChannelHandle::DEVICE_TO_DEVICE); + attributes.push_back(ConvertChannelHandle(channel_handle, builder)); + } + + // Return async_start/done for pipelined send. + // + // old-style send returns a bundle of (arg, sync flag, token) to be passed + // along to send-done. + // However, the new-style async ops have a shared bundle + // format of (args, results, scratchpad), so to rewrite the `send` and + // `send-done` ops to use the new-style async API, we need to reorder the + // arguments to be in (args, token, sync flag) order. + auto result_types = result_type.cast().getTypes(); + if (result_types.size() != 3) + return InvalidArgument("send should return a 3-tuple"); + auto async_arg_type = mlir::TupleType::get( + builder->getContext(), {result_types[0], result_types[2]}); + auto async_bundled_tuple = + mlir::TupleType::get(builder->getContext(), + {async_arg_type, result_types[2], result_types[1]}); + return ImportOldStyleAsyncStart( + symbol_table, attributes, operands, loc, async_bundled_tuple, builder, + "send_", [](auto) { return absl::OkStatus(); }); +} + +absl::StatusOr ImportRecv( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table) { + auto recv_op = Cast(instruction); + attributes.push_back(builder->getNamedAttr( + "is_host_transfer", builder->getBoolAttr(recv_op->is_host_transfer()))); + if (recv_op->channel_id().has_value()) { + ChannelHandle channel_handle; + channel_handle.set_handle(recv_op->channel_id().value()); + channel_handle.set_type(recv_op->is_host_transfer() + ? ChannelHandle::HOST_TO_DEVICE + : ChannelHandle::DEVICE_TO_DEVICE); + attributes.push_back(ConvertChannelHandle(channel_handle, builder)); + } + + // Old-style `recv` returns a bundle of (result, sync flag, token) to be + // passed along to recv-done. + // However, the new-style async ops have a shared + // bundle format of (args, results, scratchpad), so to rewrite the `recv` + // and `recv-done` ops to use the new-style async API, we need to reorder + // the arguments to be in (token, (result, token), sync flag) order. + // OR (token, token, sync flag) if no result is received. + auto result_types = result_type.cast().getTypes(); + if (result_types.size() != 3) + return InvalidArgument("recv should return a 3-tuple"); + + // Allow recv of no values, only token. + // b/TODO: Allow recv of no values, only token. + auto async_result_type = mlir::TupleType::get( + builder->getContext(), {result_types[0], result_types[2]}); + auto async_bundled_tuple = mlir::TupleType::get( + builder->getContext(), + {result_types[2], async_result_type, result_types[1]}); + return ImportOldStyleAsyncStart( + symbol_table, attributes, operands, loc, async_bundled_tuple, builder, + "recv_", [](auto) { return absl::OkStatus(); }); +} + +// Async Collectives + +absl::StatusOr ImportAllGatherStart( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table) { + auto all_gather_start = Cast(instruction); + attributes.push_back(builder->getNamedAttr( + "all_gather_dim", + builder->getI64IntegerAttr(all_gather_start->all_gather_dimension()))); + attributes.push_back( + ConvertReplicaGroups(all_gather_start->replica_groups(), builder)); + if (all_gather_start->channel_id().has_value()) + attributes.push_back( + ConvertChannelHandle(all_gather_start->channel_id().value(), builder)); + if (all_gather_start->use_global_device_ids()) + attributes.push_back(ConvertUseGlobalDeviceIds(builder)); + if (all_gather_start->operands().size() > 1) + return InvalidArgument("Async tuple all-gather is not supported in MHLO"); + + if (!llvm::isa(result_type)) { + // Async AllGather's output type is bundle + // There are some instances where the output type is not a tuple, this seems + // to be the more modern case, so we will wrap these in a tuple for MHLO. + result_type = mlir::TupleType::get(builder->getContext(), + {operands[0].getType(), result_type}); + } + + return ImportOldStyleAsyncStart( + symbol_table, attributes, operands, loc, result_type, builder, + "all_gather_", [](auto) { return absl::OkStatus(); }); +} + +absl::StatusOr ImportAllReduceStart( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + std::function mutate_op, + mlir::SymbolTable& symbol_table) { + auto all_reduce_start = Cast(instruction); + attributes.push_back( + ConvertReplicaGroups(all_reduce_start->replica_groups(), builder)); + if (all_reduce_start->channel_id().has_value()) + attributes.push_back( + ConvertChannelHandle(all_reduce_start->channel_id().value(), builder)); + if (all_reduce_start->use_global_device_ids()) + attributes.push_back(ConvertUseGlobalDeviceIds(builder)); + if (all_reduce_start->operands().size() > 1) + return InvalidArgument("Async tuple all-reduce is not supported in MHLO"); + + if (!llvm::isa(result_type)) { + // Async AllReduce's output type is bundle + // There are some instances where the output type is not a tuple, this seems + // to be the more modern case, so we will wrap these in a tuple for MHLO. + result_type = mlir::TupleType::get(builder->getContext(), + {operands[0].getType(), result_type}); + } + + return ImportOldStyleAsyncStart( + symbol_table, attributes, operands, loc, result_type, builder, + "all_reduce_", mutate_op); +} + +// Collective Permute + +absl::StatusOr ImportCollectivePermuteStart( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table) { + attributes.push_back( + ConvertSourceTargetPairs(instruction->source_target_pairs(), builder)); + if (!llvm::isa(result_type)) { + // Async CollectivePermute's output type is bundle + // There are some instances where the output type is not a tuple, this seems + // to be the more modern case, so we will wrap these in a tuple for MHLO. + result_type = mlir::TupleType::get(builder->getContext(), + {operands[0].getType(), result_type}); + } + return ImportOldStyleAsyncStart( + symbol_table, attributes, operands, loc, result_type, builder, + "collective_permute_", [&](auto) { return absl::OkStatus(); }); +} + +absl::StatusOr ImportCopyStart( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table) { + auto context = builder->getContext(); + auto copy_start_instruction = Cast(instruction); + if (auto cross_program_prefetch_index = + copy_start_instruction->cross_program_prefetch_index()) { + attributes.push_back(builder->getNamedAttr( + "cross_program_prefetch_index", + builder->getIntegerAttr(builder->getIntegerType(32), + *cross_program_prefetch_index))); + // Cross-program prefetch allows copy ops to accept tuples, in which + // case, we need to double-wrap inputs and outputs in tuples. + if (operands[0].getType().isa()) { + auto result_types = result_type.cast().getTypes(); + result_type = mlir::TupleType::get( + context, + {mlir::TupleType::get(context, {result_types[0]}), + mlir::TupleType::get(context, {result_types[1]}), result_types[2]}); + } + } + return ImportOldStyleAsyncStart( + symbol_table, attributes, operands, loc, result_type, builder, "copy_", + [](auto) { return absl::OkStatus(); }); +} + +absl::StatusOr ImportAsyncOpDone( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder) { + return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, + builder); +} + +} // namespace xla diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h new file mode 100644 index 00000000000000..efdd487c21f03d --- /dev/null +++ b/third_party/xla/xla/translate/hlo_to_mhlo/async_importer.h @@ -0,0 +1,88 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ +#define XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace xla { + +// Op Converters +absl::StatusOr ImportSend( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table); + +absl::StatusOr ImportRecv( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table); + +// Async Collectives +absl::StatusOr ImportAllGatherStart( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table); + +absl::StatusOr ImportAllReduceStart( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + std::function mutate_op, + mlir::SymbolTable& symbol_table); + +absl::StatusOr ImportCollectivePermuteStart( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table); + +absl::StatusOr ImportCopyStart( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder, + mlir::SymbolTable& symbol_table); + +absl::StatusOr ImportAsyncOpDone( + const HloInstruction* instruction, mlir::Location loc, + const llvm::SmallVectorImpl& operands, + llvm::SmallVectorImpl& attributes, + mlir::Type result_type, mlir::OpBuilder* builder); + +} // namespace xla + +#endif // XLA_TRANSLATE_HLO_TO_MHLO_ASYNC_IMPORTER_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc index 9a2fa06fe19848..fbe1a904cfa21a 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.cc @@ -17,13 +17,24 @@ limitations under the License. #include +#include +#include #include #include #include +#include "absl/algorithm/container.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "xla/layout.h" #include "xla/layout_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/shape.h" #include "xla/shape_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -195,6 +206,62 @@ absl::StatusOr ConvertCustomCallApiVersion( } } +mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel, + mlir::Builder* builder) { + return builder->getNamedAttr( + "channel_handle", + mlir::mhlo::ChannelHandleAttr::get(builder->getContext(), + channel.handle(), channel.type())); +} +mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id, + mlir::Builder* builder) { + ChannelHandle channel_handle; + if (channel_id) channel_handle.set_handle(*channel_id); + return ConvertChannelHandle(channel_handle, builder); +} + +mlir::NamedAttribute ConvertReplicaGroups( + absl::Span replica_groups, mlir::Builder* builder) { + const int64_t num_groups = replica_groups.size(); + // Replica groups in HLO can be non-uniform in size, for example: + // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D + // tensor, pad the smaller sized replica groups with -1. + const int64_t group_size = absl::c_accumulate( + replica_groups, static_cast(0), + [](int64_t current, const ReplicaGroup& g) { + return std::max(current, g.replica_ids_size()); + }); + // Initialize all elements to -1 to support non-uniform replica groups. + std::vector attr(num_groups * group_size, -1); + for (int i = 0; i < num_groups; ++i) { + int index = i * group_size; + for (const int64_t& id : replica_groups[i].replica_ids()) + attr[index++] = id; + } + auto type = mlir::RankedTensorType::get({num_groups, group_size}, + builder->getIntegerType(64)); + return builder->getNamedAttr("replica_groups", + mlir::DenseIntElementsAttr::get(type, attr)); +} + +mlir::NamedAttribute ConvertSourceTargetPairs( + const std::vector>& source_target_pairs, + mlir::Builder* builder) { + std::vector attr(source_target_pairs.size() * 2); + for (const auto& p : llvm::enumerate(source_target_pairs)) { + attr[2 * p.index()] = p.value().first; + attr[2 * p.index() + 1] = p.value().second; + } + auto type = mlir::RankedTensorType::get( + {static_cast(attr.size() / 2), 2}, builder->getIntegerType(64)); + return builder->getNamedAttr("source_target_pairs", + mlir::DenseIntElementsAttr::get(type, attr)); +} + +mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder) { + return builder->getNamedAttr("use_global_device_ids", builder->getUnitAttr()); +} + absl::StatusOr ExtractLayoutsFromShapes( const absl::Span shapes_with_layouts, mlir::Builder* builder) { std::vector layouts; diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h index 4f1ba9e14a68df..f83681429645ec 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/attribute_importer.h @@ -16,10 +16,14 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_ATTRIBUTE_IMPORTER_H_ +#include +#include #include #include #include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" @@ -66,6 +70,20 @@ absl::StatusOr ConvertTranspose( absl::StatusOr ConvertCustomCallApiVersion( xla::CustomCallApiVersion api_version); +mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel, + mlir::Builder* builder); +mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id, + mlir::Builder* builder); + +mlir::NamedAttribute ConvertReplicaGroups( + absl::Span replica_groups, mlir::Builder* builder); + +mlir::NamedAttribute ConvertSourceTargetPairs( + const std::vector>& source_target_pairs, + mlir::Builder* builder); + +mlir::NamedAttribute ConvertUseGlobalDeviceIds(mlir::Builder* builder); + // Extracts layouts from shapes and converts it into layout attributes (array of // rank-1 index tensors). Returns an error if any of the shapes is a tuple. absl::StatusOr ExtractLayoutsFromShapes( diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc index e719db6c4bbeed..64908395b52ba8 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.cc @@ -15,7 +15,6 @@ limitations under the License. #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" -#include #include #include #include @@ -26,8 +25,9 @@ limitations under the License. #include #include -#include "absl/algorithm/container.h" +#include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "llvm/ADT/APInt.h" @@ -35,7 +35,9 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/AsmParser/AsmParser.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -52,25 +54,23 @@ limitations under the License. #include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" #include "xla/comparison_util.h" #include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/hlo/ir/hlo_sharding_metadata.h" #include "xla/layout.h" -#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/protobuf_util.h" #include "xla/service/hlo.pb.h" -#include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/status_macros.h" +#include "xla/translate/hlo_to_mhlo/async_importer.h" #include "xla/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/translate/hlo_to_mhlo/custom_call_importer.h" #include "xla/translate/hlo_to_mhlo/hlo_utils.h" @@ -90,6 +90,8 @@ using mlir::Type; using mlir::Value; using mlir::func::FuncOp; +#define DEBUG_TYPE "xla-translate" + namespace xla { namespace { @@ -100,7 +102,7 @@ constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication"; // Note: This sanitization function causes an irreversible many-to-one mapping // and any solution to mitigate this would cause issues with the reverse -// direction. Longterm solution is to add a function attribute to maintain the +// direction. Long-term solution is to add a function attribute to maintain the // original HLO naming. std::string SanitizeFunctionName(llvm::StringRef name) { std::string output(name); @@ -170,115 +172,6 @@ Operation* createReturnOp(mlir::OpBuilder& builder, mlir::Location loc, } // namespace -mlir::TypeRange Untuple(const mlir::Type& type) { - if (type.isa()) { - return llvm::dyn_cast(type).getTypes(); - } - return type; -} - -template -absl::StatusOr HloFunctionImporter::ImportOldStyleAsyncStart( - llvm::SmallVectorImpl& attributes, - const llvm::SmallVectorImpl& operands, mlir::Location loc, - mlir::Type result_type, mlir::OpBuilder* func_builder, - std::string func_name, std::function mutate_op) { - auto result_types = result_type.cast().getTypes(); - if (result_types.size() < 2) { - return tsl::errors::InvalidArgument( - "async_bundle must contain at least two values"); - } - auto func_type = mlir::FunctionType::get(context_, Untuple(result_types[0]), - Untuple(result_types[1])); - auto function = FuncOp::create(loc, func_name, func_type); - - // The new function doesn't need to be inserted in the beginning but is done - // to make testing easier and preserve the original behavior. - mlir::Block& block = symbol_table_.getOp()->getRegion(0).front(); - symbol_table_.insert(function, mlir::Block::iterator(block.begin())); - - function.setPrivate(); - auto async_builder = mlir::OpBuilder(function.getBody()); - - llvm::SmallVector async_attributes; - async_attributes.push_back(builder_->getNamedAttr( - "called_computation", mlir::FlatSymbolRefAttr::get(builder_->getContext(), - function.getName()))); - async_attributes.push_back(builder_->getNamedAttr( - "execution_thread", builder_->getStringAttr("main"))); - - // Attach the frontend_attributes and sharding attributes to the async op - // instead of the sync op. First, semantically sharding attributes cannot be - // attached to the sync op since the sync op may not produce the same number - // of results as the sharding's tuple element count, e.g., `mhlo.send` vs. HLO - // `send`. Second, `mlir_hlo_to_hlo.cc` imports these attributes from the - // `mhlo.async_start` ops, so attaching them to the sync op will make them - // disappear during MHLO to HLO lowering. - for (auto it = attributes.begin(); it != attributes.end();) { - if (it->getName() == kShardingAttr || - it->getName() == kFrontendAttributesAttr) { - async_attributes.push_back(*it); - it = attributes.erase(it); - } else { - ++it; - } - } - - llvm::SmallVector locs(Untuple(result_types[0]).size(), - loc); - auto sync_operand = - async_builder - .createBlock(&function.getBody(), {}, Untuple(result_types[0]), locs) - ->getArguments(); - auto sync_operation = async_builder.create( - loc, Untuple(result_types[1]), sync_operand, attributes); - async_builder.create(loc, sync_operation->getResults()); - TF_RETURN_IF_ERROR(mutate_op(sync_operation)); - - function->setAttr("execution_thread", builder_->getStringAttr("main")); - - auto bundle_result_type = - mlir::mhlo::AsyncBundleType::get(context_, result_types); - return func_builder - ->create(loc, bundle_result_type, operands, - async_attributes) - .getOperation(); -} - -absl::StatusOr HloFunctionImporter::ImportOldStyleAsyncDone( - llvm::SmallVectorImpl& attributes, - const llvm::SmallVectorImpl& operands, mlir::Location loc, - mlir::Type result_type, mlir::OpBuilder* func_builder) { - if (operands.size() != 1) { - return InvalidArgument( - "async-done must take only a single async_bundle operand"); - } - auto async_start = operands[0].getDefiningOp(); - if (!async_start) return InvalidArgument("*-start requires *-done as input"); - attributes.push_back(builder_->getNamedAttr( - "called_computation", - mlir::FlatSymbolRefAttr::get(builder_->getContext(), - async_start.getCalledComputation()))); - attributes.push_back(builder_->getNamedAttr("execution_thread", - builder_->getStringAttr("main"))); - - auto start_tuple = async_start.getResult() - .getType() - .cast() - .getTypes()[1] - .dyn_cast(); - if (start_tuple && start_tuple.getType(0).isa()) { - auto op = func_builder->create( - loc, result_type, operands, attributes); - return {op}; - } else { - auto op = func_builder->create( - loc, Untuple(result_type), operands, attributes); - return CreateTupleFromOpResults(func_builder, loc, op.getOperation(), - result_type); - } -} - void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands( mlir::Operation* op, llvm::ArrayRef implicit_operands) { assert((mlir::dyn_cast(*op) || @@ -296,20 +189,6 @@ void HloFunctionImporter::ReplaceBlockArgumentsWithImplicitOperands( } } -mlir::Operation* HloFunctionImporter::CreateTupleFromOpResults( - mlir::OpBuilder* func_builder, mlir::Location loc, mlir::Operation* op, - mlir::Type type) { - if (!type.isa()) return op; - - mlir::ValueRange flattened_results_ref(op->getResults()); - auto result = - CreateTupleValue(func_builder, loc, flattened_results_ref, type); - auto defining_tuple_op = result.getDefiningOp(); - assert(defining_tuple_op && "builder didn't return the right type"); - auto tupleOp = defining_tuple_op.getOperation(); - return tupleOp; -} - static bool IsNestedTupleInData(Type type) { auto tuple_type = type.dyn_cast(); if (!tuple_type) return false; @@ -328,36 +207,6 @@ static bool IsNestedTupleInData(Type type) { return false; } -static bool HasCustomLayout(const Shape& shape) { - if (shape.IsTuple()) { - return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); - } - return shape.has_layout() && !shape.layout().minor_to_major().empty() && - shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); -} - -static mlir::Attribute GetLayoutAttribute(mlir::Builder& b, - const Shape& shape) { - if (shape.IsTuple()) { - llvm::SmallVector element_attrs; - for (const auto& tuple_shape : shape.tuple_shapes()) { - element_attrs.push_back(GetLayoutAttribute(b, tuple_shape)); - } - return b.getArrayAttr(element_attrs); - } - - llvm::SmallVector layout; - if (shape.has_layout()) { - layout = {shape.layout().minor_to_major().begin(), - shape.layout().minor_to_major().end()}; - } else { - Layout layout_for_shape = LayoutUtil::GetDefaultLayoutForShape(shape); - layout = {layout_for_shape.minor_to_major().begin(), - layout_for_shape.minor_to_major().end()}; - } - return b.getIndexTensorAttr(layout); -} - mlir::Attribute GetFrontendAttributes(mlir::Builder& b, const FrontendAttributes& attributes) { llvm::SmallVector attrs; @@ -411,27 +260,6 @@ llvm::SmallVector HloFunctionImporter::FlattenTupleValues( return flattened_values; } -Value HloFunctionImporter::CreateTupleValue(mlir::OpBuilder* func_builder, - mlir::Location loc, - mlir::ValueRange& flatten_values, - Type type) { - auto tuple_type = type.dyn_cast(); - if (!tuple_type) { - assert(!flatten_values.empty()); - auto retval = flatten_values.front(); - flatten_values = flatten_values.drop_front(); - return retval; - } - - llvm::SmallVector flatten_sub_values; - for (auto child_type : tuple_type.getTypes()) - flatten_sub_values.push_back( - CreateTupleValue(func_builder, loc, flatten_values, child_type)); - - return func_builder->create(loc, flatten_sub_values) - .getResult(); -} - absl::StatusOr HloFunctionImporter::ImportAsFunc( const HloComputation& computation, mlir::SymbolTable& symbol_table, std::unordered_map* function_map, @@ -591,34 +419,6 @@ absl::StatusOr HloFunctionImporter::ImportAsFunc( builder_->getStringAttr(computation.execution_thread())); } - // The MLIR CPU pipeline assumes default layouts throughout the program. At - // the boundaries, this may not be the case, so layout information needs to - // be propagated to adapt the data layouts. - if (computation.IsEntryComputation()) { - const auto& computation_layout = - computation.parent()->entry_computation_layout(); - if (computation_layout.LayoutIsSet()) { - if (HasCustomLayout(computation_layout.result_layout().shape())) { - function->setAttr( - "xla_entry_computation_result_layout", - GetLayoutAttribute(*builder_, - computation_layout.result_layout().shape())); - } - if (llvm::any_of(computation_layout.parameter_layouts(), - [](const ShapeLayout& shape) { - return HasCustomLayout(shape.shape()); - })) { - llvm::SmallVector parameter_layouts; - for (auto& layout : computation_layout.parameter_layouts()) { - parameter_layouts.push_back( - GetLayoutAttribute(*builder_, layout.shape())); - } - function->setAttr("xla_entry_computation_parameter_layouts", - builder_->getArrayAttr(parameter_layouts)); - } - } - } - symbol_table_.insert(function); // Add to the map right away for function calls if map is set. @@ -819,7 +619,10 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( frontend_attributes.push_back( builder_->getNamedAttr(k, builder_->getStringAttr(v))); } + + int frontend_attributes_index = 0; if (!frontend_attributes.empty()) { + frontend_attributes_index = attributes.size(); attributes.push_back(builder_->getNamedAttr( kFrontendAttributesAttr, builder_->getDictionaryAttr(frontend_attributes))); @@ -984,6 +787,68 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( FuncOp function, ImportAsFunc(*instruction->to_apply(), /*is_main=*/false)); mlir::Operation* new_operation; + if (instruction->is_composite()) { + // TODO: b/354721812 - Support flatten_computation_args_result_ flag + // for composite calls + + mlir::DictionaryAttr frontend_attributes_attr = + builder_->getDictionaryAttr(frontend_attributes); + if (frontend_attributes.empty() || + !frontend_attributes_attr.contains("composite.attributes") || + !frontend_attributes_attr.contains("composite.name") || + !frontend_attributes_attr.contains("composite.version")) { + return InvalidArgument( + "A composite call op must have frontend attributes with the " + "following keys: composite.attributes, composite.name, " + "composite.version"); + } + + llvm::SmallVector fe_attrs_without_composite_attrs; + for (const auto& attr : frontend_attributes) { + if (attr.getName() != "composite.attributes" && + attr.getName() != "composite.name" && + attr.getName() != "composite.version") { + fe_attrs_without_composite_attrs.push_back(attr); + } + } + + // Frontend attributes may have been created by composite related + // attributes. If frontend attributes is empty after removing + // composite related attributes, it is not needed, so we remove it + // entirely. Otherwise, we update it. + if (fe_attrs_without_composite_attrs.empty()) { + attributes.erase(attributes.begin() + frontend_attributes_index); + } else { + attributes[frontend_attributes_index] = builder_->getNamedAttr( + kFrontendAttributesAttr, + builder_->getDictionaryAttr(fe_attrs_without_composite_attrs)); + } + + auto frontend_attributes_map = instruction->frontend_attributes().map(); + mlir::StringAttr name = builder_->getStringAttr( + frontend_attributes_map.find("composite.name")->second); + mlir::Attribute composite_attributes = mlir::parseAttribute( + frontend_attributes_map.find("composite.attributes")->second, + builder_->getContext()); + mlir::FlatSymbolRefAttr decomposition = mlir::SymbolRefAttr::get( + builder_->getContext(), instruction->to_apply()->name()); + mlir::IntegerAttr version = builder_->getIntegerAttr( + builder_->getI32Type(), + std::stoi( + frontend_attributes_map.find("composite.version")->second)); + + new_operation = func_builder->create( + loc, result_type, operands); + new_operation->setAttr("name", name); + new_operation->setAttr("composite_attributes", composite_attributes); + new_operation->setAttr("decomposition", decomposition); + new_operation->setAttr("version", version); + for (const auto& attr : attributes) { + new_operation->setAttr(attr.getName(), attr.getValue()); + } + return new_operation; + } + if (flatten_computation_args_result_) { // Flatten the tuple-typed operands. llvm::SmallVector flattened_operands = FlattenTupleValues( @@ -1004,7 +869,7 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( } else { new_operation = func_builder->create(loc, function, operands); - for (auto attr : attributes) { + for (const auto& attr : attributes) { new_operation->setAttr(attr.getName(), attr.getValue()); } } @@ -1015,8 +880,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( attributes.push_back(ConvertReplicaGroups( collective_broadcast->replica_groups(), builder_)); if (collective_broadcast->channel_id().has_value()) - attributes.push_back( - ConvertChannelHandle(collective_broadcast->channel_id().value())); + attributes.push_back(ConvertChannelHandle( + collective_broadcast->channel_id().value(), builder_)); return func_builder ->create(loc, result_type, operands, attributes) @@ -1028,23 +893,21 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( attributes.push_back(ConvertSourceTargetPairs( collective_permute->source_target_pairs(), builder_)); if (collective_permute->channel_id().has_value()) - attributes.push_back( - ConvertChannelHandle(collective_permute->channel_id().value())); + attributes.push_back(ConvertChannelHandle( + collective_permute->channel_id().value(), builder_)); return func_builder ->create(loc, result_type, operands, attributes) .getOperation(); } case HloOpcode::kCollectivePermuteStart: { - attributes.push_back(ConvertSourceTargetPairs( - instruction->source_target_pairs(), builder_)); - return ImportOldStyleAsyncStart( - attributes, operands, loc, result_type, func_builder, - "collective_permute_", [&](auto) { return absl::OkStatus(); }); + return ImportCollectivePermuteStart(instruction, loc, operands, + attributes, result_type, func_builder, + symbol_table_); } case HloOpcode::kCollectivePermuteDone: { - return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, - func_builder); + return ImportAsyncOpDone(instruction, loc, operands, attributes, + result_type, func_builder); } case HloOpcode::kCustomCall: { auto custom_call = Cast(instruction); @@ -1368,103 +1231,31 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( .getOperation(); } case HloOpcode::kCopyStart: { - auto copy_start_instruction = Cast(instruction); - if (auto cross_program_prefetch_index = - copy_start_instruction->cross_program_prefetch_index()) { - attributes.push_back(builder_->getNamedAttr( - "cross_program_prefetch_index", - builder_->getIntegerAttr(builder_->getIntegerType(32), - *cross_program_prefetch_index))); - // Cross-program prefetch allows copy ops to accept tuples, in which - // case, we need to double-wrap inputs and outputs in tuples. - if (operands[0].getType().isa()) { - auto result_types = result_type.cast().getTypes(); - result_type = mlir::TupleType::get( - context_, {mlir::TupleType::get(context_, {result_types[0]}), - mlir::TupleType::get(context_, {result_types[1]}), - result_types[2]}); - } - } - return ImportOldStyleAsyncStart( - attributes, operands, loc, result_type, func_builder, "copy_", - [](auto) { return absl::OkStatus(); }); + return ImportCopyStart(instruction, loc, operands, attributes, + result_type, func_builder, symbol_table_); } case HloOpcode::kCopyDone: { - return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, - func_builder); + return ImportAsyncOpDone(instruction, loc, operands, attributes, + result_type, func_builder); } case HloOpcode::kSend: { - // old-style send returns a bundle of (arg, sync flag, token) to be passed - // along to send-done. - // However, the new-style async ops have a shared bundle - // format of (args, results, scratchpad), so to rewrite the `send` and - // `send-done` ops to use the new-style async API, we need to reorder the - // arguments to be in (args, token, sync flag) order. - auto result_types = result_type.cast().getTypes(); - if (result_types.size() != 3) - return InvalidArgument("send should return a 3-tuple"); - auto async_arg_type = - mlir::TupleType::get(context_, {result_types[0], result_types[2]}); - auto async_bundled_tuple = mlir::TupleType::get( - context_, {async_arg_type, result_types[2], result_types[1]}); - auto send_op = Cast(instruction); - attributes.push_back(builder_->getNamedAttr( - "is_host_transfer", - builder_->getBoolAttr(send_op->is_host_transfer()))); - if (send_op->channel_id().has_value()) { - ChannelHandle channel_handle; - channel_handle.set_handle(send_op->channel_id().value()); - channel_handle.set_type(send_op->is_host_transfer() - ? ChannelHandle::DEVICE_TO_HOST - : ChannelHandle::DEVICE_TO_DEVICE); - attributes.push_back(ConvertChannelHandle(channel_handle)); - } - return ImportOldStyleAsyncStart( - attributes, operands, loc, async_bundled_tuple, func_builder, "send_", - [](auto) { return absl::OkStatus(); }); + return ImportSend(instruction, loc, operands, attributes, result_type, + func_builder, symbol_table_); } case HloOpcode::kSendDone: { - return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, - func_builder); + return ImportAsyncOpDone(instruction, loc, operands, attributes, + result_type, func_builder); } case HloOpcode::kRecv: { - // Old-style `recv` returns a bundle of (result, sync flag, token) to be - // passed along to recv-done. - // However, the new-style async ops have a shared - // bundle format of (args, results, scratchpad), so to rewrite the `recv` - // and `recv-done` ops to use the new-style async API, we need to reorder - // the arguments to be in (token, (result, token), sync flag) order. - auto result_types = result_type.cast().getTypes(); - if (result_types.size() != 3) - return InvalidArgument("recv should return a 3-tuple"); - auto async_result_type = - mlir::TupleType::get(context_, {result_types[0], result_types[2]}); - auto async_bundled_tuple = mlir::TupleType::get( - context_, {result_types[2], async_result_type, result_types[1]}); - auto recv_op = Cast(instruction); - attributes.push_back(builder_->getNamedAttr( - "is_host_transfer", - builder_->getBoolAttr(recv_op->is_host_transfer()))); - if (recv_op->channel_id().has_value()) { - ChannelHandle channel_handle; - channel_handle.set_handle(recv_op->channel_id().value()); - channel_handle.set_type(recv_op->is_host_transfer() - ? ChannelHandle::HOST_TO_DEVICE - : ChannelHandle::DEVICE_TO_DEVICE); - attributes.push_back(ConvertChannelHandle(channel_handle)); - } - return ImportOldStyleAsyncStart( - attributes, operands, loc, async_bundled_tuple, func_builder, "recv_", - [](auto) { return absl::OkStatus(); }); + return ImportRecv(instruction, loc, operands, attributes, result_type, + func_builder, symbol_table_); } case HloOpcode::kRecvDone: { - return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, - func_builder); + return ImportAsyncOpDone(instruction, loc, operands, attributes, + result_type, func_builder); } case HloOpcode::kConditional: { llvm::SmallVector rets; - - // Flatten the tuple-typed operands. llvm::SmallVector flattened_operands = FlattenTupleValues(func_builder, loc, operands); @@ -1556,9 +1347,9 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( ConvertReplicaGroups(all_gather->replica_groups(), builder_)); if (all_gather->channel_id().has_value()) attributes.push_back( - ConvertChannelHandle(all_gather->channel_id().value())); + ConvertChannelHandle(all_gather->channel_id().value(), builder_)); if (all_gather->use_global_device_ids()) - attributes.push_back(ConvertUseGlobalDeviceIds()); + attributes.push_back(ConvertUseGlobalDeviceIds(builder_)); auto all_gather_op = func_builder->create( loc, result_types, operands, attributes); if (result_tuple_ty) { @@ -1570,28 +1361,12 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( return all_gather_op.getOperation(); } case HloOpcode::kAllGatherStart: { - auto all_gather_start = Cast(instruction); - attributes.push_back(builder_->getNamedAttr( - "all_gather_dim", builder_->getI64IntegerAttr( - all_gather_start->all_gather_dimension()))); - attributes.push_back( - ConvertReplicaGroups(all_gather_start->replica_groups(), builder_)); - if (all_gather_start->channel_id().has_value()) - attributes.push_back( - ConvertChannelHandle(all_gather_start->channel_id().value())); - if (all_gather_start->use_global_device_ids()) - attributes.push_back(ConvertUseGlobalDeviceIds()); - if (all_gather_start->operands().size() > 1) - return InvalidArgument( - "Async tuple all-gather is not supported in MHLO"); - - return ImportOldStyleAsyncStart( - attributes, operands, loc, result_type, func_builder, "all_gather_", - [](auto) { return absl::OkStatus(); }); + return ImportAllGatherStart(instruction, loc, operands, attributes, + result_type, func_builder, symbol_table_); } case HloOpcode::kAllGatherDone: { - return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, - func_builder); + return ImportAsyncOpDone(instruction, loc, operands, attributes, + result_type, func_builder); } case HloOpcode::kAllReduce: { auto all_reduce = Cast(instruction); @@ -1606,9 +1381,9 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( ConvertReplicaGroups(all_reduce->replica_groups(), builder_)); if (all_reduce->channel_id().has_value()) attributes.push_back( - ConvertChannelHandle(all_reduce->channel_id().value())); + ConvertChannelHandle(all_reduce->channel_id().value(), builder_)); if (all_reduce->use_global_device_ids()) - attributes.push_back(ConvertUseGlobalDeviceIds()); + attributes.push_back(ConvertUseGlobalDeviceIds(builder_)); auto all_reduce_op = func_builder->create( loc, result_types, operands, attributes); TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(), @@ -1622,29 +1397,19 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( return all_reduce_op.getOperation(); } case HloOpcode::kAllReduceStart: { - auto all_reduce_start = Cast(instruction); - attributes.push_back( - ConvertReplicaGroups(all_reduce_start->replica_groups(), builder_)); - if (all_reduce_start->channel_id().has_value()) - attributes.push_back( - ConvertChannelHandle(all_reduce_start->channel_id().value())); - if (all_reduce_start->use_global_device_ids()) - attributes.push_back(ConvertUseGlobalDeviceIds()); - if (all_reduce_start->operands().size() > 1) - return InvalidArgument( - "Async tuple all-reduce is not supported in MHLO"); + auto appendRegion = [&](mlir::mhlo::AllReduceOp all_reduce_sync) { + TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->to_apply(), + &all_reduce_sync.getComputation())); + return absl::OkStatus(); + }; - return ImportOldStyleAsyncStart( - attributes, operands, loc, result_type, func_builder, "all_reduce_", - [&](auto all_reduce_sync) { - TF_RETURN_IF_ERROR(ImportAsRegion( - *instruction->to_apply(), &all_reduce_sync.getComputation())); - return absl::OkStatus(); - }); + return ImportAllReduceStart(instruction, loc, operands, attributes, + result_type, func_builder, appendRegion, + symbol_table_); } case HloOpcode::kAllReduceDone: { - return ImportOldStyleAsyncDone(attributes, operands, loc, result_type, - func_builder); + return ImportAsyncOpDone(instruction, loc, operands, attributes, + result_type, func_builder); } case HloOpcode::kAllToAll: { auto all_to_all = Cast(instruction); @@ -1681,7 +1446,8 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( replica_groups_attr); if (all_to_all->channel_id().has_value()) { - auto handle = ConvertChannelHandle(all_to_all->channel_id().value()); + auto handle = + ConvertChannelHandle(all_to_all->channel_id().value(), builder_); result.setChannelHandleAttr( handle.getValue().cast()); } @@ -1887,10 +1653,10 @@ absl::StatusOr HloFunctionImporter::ImportInstructionImpl( attributes.push_back( ConvertReplicaGroups(reduce_scatter->replica_groups(), builder_)); if (reduce_scatter->channel_id().has_value()) - attributes.push_back( - ConvertChannelHandle(reduce_scatter->channel_id().value())); + attributes.push_back(ConvertChannelHandle( + reduce_scatter->channel_id().value(), builder_)); if (reduce_scatter->use_global_device_ids()) - attributes.push_back(ConvertUseGlobalDeviceIds()); + attributes.push_back(ConvertUseGlobalDeviceIds(builder_)); auto reduce_scatter_op = func_builder->create( loc, result_type, operands, attributes); @@ -2250,10 +2016,22 @@ HloFunctionImporter::ImportInstructionWithLayout( const HloInstruction* instruction, const llvm::SmallVectorImpl& operands, mlir::OpBuilder* func_builder, DynamicShapeHandlingMode mode) { + LLVM_DEBUG(llvm::dbgs() << "Importing instruction: " + << HloOpcodeString(instruction->opcode()) << '\n'); + LLVM_DEBUG({ + llvm::dbgs() << " operands: ("; + llvm::interleaveComma(operands, llvm::dbgs(), + [](Value v) { llvm::dbgs() << v.getType(); }); + llvm::dbgs() << ")\n"; + }); TF_ASSIGN_OR_RETURN( mlir::Operation * op, ImportInstructionImpl(instruction, operands, func_builder, mode)); - if (op == nullptr) return op; + if (op == nullptr) { + LLVM_DEBUG(llvm::dbgs() << " instruction skipped.\n"); + return op; + } + LLVM_DEBUG(llvm::dbgs() << " imported: " << *op << '\n'); // See MlirToHloConversionOptions for more about layouts. // @@ -2380,67 +2158,11 @@ mlir::NamedAttribute HloFunctionImporter::ConvertPadding( return builder_->getNamedAttr("padding", attr); } -mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs( - const std::vector>& source_target_pairs, - mlir::Builder* builder) { - std::vector attr(source_target_pairs.size() * 2); - for (const auto& p : llvm::enumerate(source_target_pairs)) { - attr[2 * p.index()] = p.value().first; - attr[2 * p.index() + 1] = p.value().second; - } - auto type = mlir::RankedTensorType::get( - {static_cast(attr.size() / 2), 2}, builder->getIntegerType(64)); - return builder->getNamedAttr("source_target_pairs", - DenseIntElementsAttr::get(type, attr)); -} - -mlir::NamedAttribute HloFunctionImporter::ConvertReplicaGroups( - absl::Span replica_groups, mlir::Builder* builder) { - const int64_t num_groups = replica_groups.size(); - // Replica groups in HLO can be non-uniform in size, for example: - // replica_groups={{0},{1,2},{3}}. Since we are representing them as a 2D - // tensor, pad the smaller sized replica groups with -1. - const int64_t group_size = absl::c_accumulate( - replica_groups, int64_t(0), [](int64_t current, const ReplicaGroup& g) { - return std::max(current, g.replica_ids_size()); - }); - // Initialize all elements to -1 to support non-uniform replica groups. - std::vector attr(num_groups * group_size, -1); - for (int i = 0; i < num_groups; ++i) { - int index = i * group_size; - for (const int64_t& id : replica_groups[i].replica_ids()) - attr[index++] = id; - } - auto type = mlir::RankedTensorType::get({num_groups, group_size}, - builder->getIntegerType(64)); - return builder->getNamedAttr("replica_groups", - DenseIntElementsAttr::get(type, attr)); -} - -mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle( - std::optional channel_id) { - ChannelHandle channel_handle; - if (channel_id) channel_handle.set_handle(*channel_id); - return ConvertChannelHandle(channel_handle); -} - -mlir::NamedAttribute HloFunctionImporter::ConvertChannelHandle( - const ChannelHandle& channel) { - return builder_->getNamedAttr( - "channel_handle", mlir::mhlo::ChannelHandleAttr::get( - context_, channel.handle(), channel.type())); -} - -mlir::NamedAttribute HloFunctionImporter::ConvertUseGlobalDeviceIds() { - return builder_->getNamedAttr("use_global_device_ids", - builder_->getUnitAttr()); -} - void HloFunctionImporter::SetLayoutForMlir(mlir::Operation* op, const Shape& shape, llvm::StringRef attr_name) { mlir::Builder b(op->getContext()); - op->setAttr(attr_name, GetLayoutAttribute(b, shape)); + op->setAttr(attr_name, GetLayoutAttribute(b, shape).first); } absl::Status HloFunctionImporter::ConvertShapeToMlirLayout( @@ -2470,6 +2192,42 @@ absl::Status HloFunctionImporter::ConvertShapeToMlirLayout( return Internal("Couldn't convert layout."); } +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder) { + llvm::SmallVector element_attrs; + alias.ForEachAlias([&](const ShapeIndex& output_index, + const HloInputOutputAliasConfig::Alias& alias) { + std::string kindToString; + switch (alias.kind) { + case HloInputOutputAliasConfig::AliasKind::kMayAlias: + kindToString = "may_alias"; + break; + case HloInputOutputAliasConfig::AliasKind::kMustAlias: + kindToString = "must_alias"; + break; + default: + kindToString = "undefined_alias"; + } + mlir::NamedAttribute alias_named_attributes[3] = { + builder->getNamedAttr( + "parameter_index", + builder->getDenseI64ArrayAttr(ArrayRef( + alias.parameter_index.begin(), alias.parameter_index.end()))), + builder->getNamedAttr("parameter_number", builder->getI64IntegerAttr( + alias.parameter_number)), + builder->getNamedAttr("kind", builder->getStringAttr(kindToString))}; + + mlir::NamedAttribute named_attributes[2] = { + builder->getNamedAttr("output_index", + builder->getDenseI64ArrayAttr(ArrayRef( + output_index.begin(), output_index.end()))), + builder->getNamedAttr( + "alias", builder->getDictionaryAttr(alias_named_attributes))}; + element_attrs.push_back(builder->getDictionaryAttr(named_attributes)); + }); + return builder->getArrayAttr(element_attrs); +} + mlir::Attribute ConvertSharding(const HloSharding& sharding, mlir::Builder* builder) { return builder->getStringAttr(sharding.ToString(/*include_metadata=*/true)); diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h index cb3953990f4030..fa22a6d11f1086 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_function_importer.h @@ -16,23 +16,31 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_FUNCTION_IMPORTER_H_ -#include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/optional.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "xla/comparison_util.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/hlo/ir/hlo_sharding.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" @@ -89,30 +97,12 @@ class HloFunctionImporter { static void SetLayoutForMlir(mlir::Operation* op, const Shape& shape, llvm::StringRef attr_name); - // TODO(b/179166199): move this to attribute_importer.h. - // Converts XLA instruction source target pairs to MLIR attribute. - static mlir::NamedAttribute ConvertSourceTargetPairs( - const std::vector>& source_target_pairs, - mlir::Builder* builder); - - // TODO(b/179166199): move this to attribute_importer.h. - // Converts replica groups to attribute - static mlir::NamedAttribute ConvertReplicaGroups( - absl::Span replica_groups, mlir::Builder* builder); - // For mlir::IfOp or mlir::CaseOp, replace the uses of their region's block // arguments with 'implicit_operands'. Here | implicit_operands | == sum of // the number of arguments in all the regions in IfOp or CaseOp. void ReplaceBlockArgumentsWithImplicitOperands( mlir::Operation* op, llvm::ArrayRef implicit_operands); - // Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType. - // Otherwise, return 'op'. - mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, - mlir::Location loc, - mlir::Operation* op, - mlir::Type type); - // FlattenTupleType flattens the types in (nested) tuple-type 'type' and // stores them in 'flattened_types'. static void FlattenTupleType( @@ -130,23 +120,6 @@ class HloFunctionImporter { mlir::OpBuilder* func_builder, mlir::Location loc, mlir::ValueRange values, std::optional reserve_size = std::nullopt); - // CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using - // the non-tuple-typed values in 'flatten_values'. - // - // e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple>, - // The function returns %t2 such that: - // %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple - // %t2 = mhlo.tuple(V1,%t1): (T1,tuple) -> tuple> - // - // Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to - // resp. flatten and create tuples in the exact same order. - // 2. `flatten_values`, initially storing the flattened values, will be - // mutated to a 0-length array by the end of function invocation. - static mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, - mlir::Location loc, - mlir::ValueRange& flatten_values, - mlir::Type type); - private: HloFunctionImporter(mlir::SymbolTable& symbol_table, std::unordered_map GetMlirValue(const HloInstruction* instruction); + // TODO(b/179166199): Move attribute converters to attribute_importer. // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. mlir::NamedAttribute ConvertComparisonDirection( ComparisonDirection direction); @@ -244,43 +218,6 @@ class HloFunctionImporter { // padding low and padding high for each of the spatial dimensions. mlir::NamedAttribute ConvertPadding(llvm::ArrayRef padding); - // Converts channel id to attribute - mlir::NamedAttribute ConvertChannelHandle(std::optional channel_id); - - // Convert use global device ids flag to attribute - mlir::NamedAttribute ConvertUseGlobalDeviceIds(); - - // Converts channel handle to attribute - mlir::NamedAttribute ConvertChannelHandle(const ChannelHandle& channel); - - // ============ - // Imports an old-style async start op. E.g. an HLO all-gather-start - // instruction is imported as an async-start associated with an all-gather - // computation. - // - // Eventually, old-style async ops (e.g. all-gather-start) and new-style async - // ops (i.e. async-start, async-update and async-done) will converge on the - // HLO side, so we decided to not introduce new MHLO ops for all-gather-start - // and friends. - // - // In the end, there may be new ops added in the old-style because they're not - // compatible with the new-style async semantics, but those should be handled - // on their own, rather than this function which "upgrades" ops to the - // new-style async API. - // ============ - template - absl::StatusOr ImportOldStyleAsyncStart( - llvm::SmallVectorImpl& attributes, - const llvm::SmallVectorImpl& operands, mlir::Location loc, - mlir::Type result_type, mlir::OpBuilder* func_builder, - std::string func_name, std::function mutate_op); - - // Imports an old-style async done op - absl::StatusOr ImportOldStyleAsyncDone( - llvm::SmallVectorImpl& attributes, - const llvm::SmallVectorImpl& operands, mlir::Location loc, - mlir::Type result_type, mlir::OpBuilder* func_builder); - mlir::MLIRContext* context_; // SymbolTable to which new functions should be inserted. @@ -297,6 +234,12 @@ class HloFunctionImporter { bool flatten_computation_args_result_; }; +// Returns a StringAttr that carries a prettyprinted representation of the +// given HLO C++ input_output_alias_config. +// Always succeeds and returns a non-empty attribute. +mlir::Attribute ConvertInputOutputAlias(const HloInputOutputAliasConfig& alias, + mlir::Builder* builder); + // Returns a StringAttr that carries a prettyprinted representation of the // given HLO C++ sharding. // Always succeeds and returns a non-empty attribute. diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc index 1f2ea997c81e8a..d7bd8404b9adaa 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_module_importer.cc @@ -18,10 +18,12 @@ limitations under the License. #include #include #include +#include #include "absl/container/flat_hash_map.h" #include "absl/status/status.h" #include "absl/types/span.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -35,9 +37,12 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/shape_layout.h" #include "xla/shape_util.h" #include "xla/translate/hlo_to_mhlo/hlo_function_importer.h" +#include "xla/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/translate/hlo_to_mhlo/module_config_importer.h" #include "xla/xla.pb.h" #include "tsl/platform/errors.h" @@ -122,6 +127,10 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { ConvertSharding(hlo_module.spmd_output_sharding(), &builder_)); } + module->setAttr("mhlo.input_output_alias", + ConvertInputOutputAlias( + hlo_module.input_output_alias_config(), &builder_)); + if (hlo_module.has_spmd_parameters_shardings()) { llvm::SmallVector parameter_shardings; parameter_shardings.reserve(hlo_module.spmd_parameters_shardings().size()); @@ -147,6 +156,45 @@ absl::Status HloModuleImporter::Import(const HloModule& hlo_module) { /*is_main*/ true, flatten_computation_args_result_) .status(); + // The MLIR CPU pipeline assumes default layouts throughout the program. At + // the boundaries, this may not be the case, so layout information needs to + // be propagated to adapt the data layouts. + if (const auto& computation_layout = hlo_module.entry_computation_layout(); + computation_layout.LayoutIsSet() && + !computation_layout.result_layout().shape().IsTuple()) { + if (HasCustomLayout(computation_layout.result_layout().shape())) { + std::pair layout_attrs = + GetLayoutAttribute(builder_, + computation_layout.result_layout().shape(), + computation_layout.result_layout().layout()); + module->setAttr("mhlo.xla_entry_computation_result_layout", + layout_attrs.first); + module->setAttr("mhlo.xla_entry_computation_result_tiles", + layout_attrs.second); + } + if (llvm::any_of(computation_layout.parameter_layouts(), + [](const ShapeLayout& shape) { + return HasCustomLayout(shape.shape()); + })) { + llvm::SmallVector parameter_layouts; + llvm::SmallVector parameter_tiles; + for (auto& layout : computation_layout.parameter_layouts()) { + std::pair layout_attrs = + GetLayoutAttribute( + builder_, layout.shape(), + (layout.LayoutIsSet() && !layout.shape().IsTuple()) + ? std::optional(layout.layout()) + : std::nullopt); + parameter_layouts.push_back(layout_attrs.first); + parameter_tiles.push_back(layout_attrs.second); + } + module->setAttr("mhlo.xla_entry_computation_parameter_layouts", + builder_.getArrayAttr(parameter_layouts)); + module->setAttr("mhlo.xla_entry_computation_parameter_tiles", + builder_.getArrayAttr(parameter_tiles)); + } + } + auto* module_entry_computation = hlo_module.entry_computation(); for (const auto* computation : hlo_module.computations()) TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc( diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc index e8d81dcc4a92d2..d6dafe01300c82 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.cc @@ -15,12 +15,29 @@ limitations under the License. #include "xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" +#include "absl/status/statusor.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" #include "xla/mlir/utils/error_util.h" +#include "xla/service/llvm_ir/llvm_util.h" #include "xla/status_macros.h" #include "xla/translate/hlo_to_mhlo/hlo_module_importer.h" +#include "tsl/platform/errors.h" namespace xla { +absl::StatusOr> ConvertHloToMlirHlo( + mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module, + bool import_all_computations, bool flatten_computation_args_result) { + mlir::OwningOpRef module = + llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); + TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, hlo_module, + import_all_computations, + flatten_computation_args_result)); + return module; +} + absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModuleProto const* hlo_module_proto, bool import_all_computation, @@ -32,7 +49,7 @@ absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, } absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, - xla::HloModule* hlo_module, + const xla::HloModule* hlo_module, bool import_all_computation, bool flatten_computation_args_result) { mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext()); @@ -41,4 +58,15 @@ absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, .Import(*hlo_module); } +absl::StatusOr> ConvertHloToMlirHlo( + mlir::MLIRContext& ctx, const xla::HloModule* hlo_module, + bool import_all_computations, bool flatten_computation_args_result) { + mlir::OwningOpRef module = + llvm_ir::CreateMlirModuleOp(mlir::UnknownLoc::get(&ctx)); + TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(*module, hlo_module, + import_all_computations, + flatten_computation_args_result)); + return module; +} + } // namespace xla diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h index 161823a102c28c..775d6367dc8fc9 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h @@ -19,6 +19,10 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OwningOpRef.h" namespace mlir { class ModuleOp; @@ -35,6 +39,11 @@ class HloModuleProto; // // If `flatten_computation_args_result` is set to true, flattens all tuple // arguments and result of every computation when importing them as func ops. +absl::StatusOr> ConvertHloToMlirHlo( + mlir::MLIRContext& ctx, xla::HloModuleProto const* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModuleProto const* hlo_module, bool import_all_computations = false, @@ -47,8 +56,13 @@ absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, // // If `flatten_computation_args_result` is set to true, flattens all tuple // arguments and result of every computation when importing them as func ops. +absl::StatusOr> ConvertHloToMlirHlo( + mlir::MLIRContext& ctx, const xla::HloModule* hlo_module, + bool import_all_computations = false, + bool flatten_computation_args_result = false); + absl::Status ConvertHloToMlirHlo(mlir::ModuleOp module, - xla::HloModule* hlo_module, + const xla::HloModule* hlo_module, bool import_all_computations = false, bool flatten_computation_args_result = false); diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc index 468c29aa4ffc56..e6004cfe5291d6 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.cc @@ -17,20 +17,34 @@ limitations under the License. #include "xla/translate/hlo_to_mhlo/hlo_utils.h" +#include #include -#include +#include #include +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/ValueRange.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/mlir/utils/type_util.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/primitive_util.h" +#include "xla/shape.h" #include "xla/types.h" #include "xla/util.h" #include "xla/xla_data.pb.h" +#include "tsl/platform/statusor.h" namespace xla { namespace { @@ -139,4 +153,46 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector( vector); } +mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, + mlir::ValueRange& flatten_values, + mlir::Type type) { + auto tuple_type = type.dyn_cast(); + if (!tuple_type) { + assert(!flatten_values.empty()); + auto retval = flatten_values.front(); + flatten_values = flatten_values.drop_front(); + return retval; + } + + llvm::SmallVector flatten_sub_values; + for (auto child_type : tuple_type.getTypes()) + flatten_sub_values.push_back( + CreateTupleValue(func_builder, loc, flatten_values, child_type)); + + return func_builder->create(loc, flatten_sub_values) + .getResult(); +} + +mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, + mlir::Location loc, + mlir::Operation* op, + mlir::Type type) { + if (!type.isa()) return op; + + mlir::ValueRange flattened_results_ref(op->getResults()); + auto result = + CreateTupleValue(func_builder, loc, flattened_results_ref, type); + auto defining_tuple_op = result.getDefiningOp(); + assert(defining_tuple_op && "builder didn't return the right type"); + auto tupleOp = defining_tuple_op.getOperation(); + return tupleOp; +} + +mlir::TypeRange Untuple(const mlir::Type& type) { + if (llvm::isa(type)) { + return llvm::dyn_cast(type).getTypes(); + } + return type; +} + } // namespace xla diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h index 72c30be491e767..dd7f68aa09ea9c 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils.h @@ -18,17 +18,31 @@ limitations under the License. #ifndef XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ #define XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ +#include +#include +#include +#include + +#include "absl/status/statusor.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/SparseTensor/IR/Enums.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" -#include "xla/hlo/ir/hlo_instruction.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "xla/layout.h" +#include "xla/layout_util.h" +#include "xla/literal.h" #include "xla/mlir/utils/type_util.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/utils/convert_op_folder.h" #include "xla/util.h" +#include "tsl/platform/statusor.h" namespace xla { @@ -159,6 +173,75 @@ static absl::StatusOr ConvertShapeToType(const Shape& shape, return ConvertTensorShapeToType(shape, builder); } +// CreateTupleValue creates a root TupleOp of (nested) tuple-type 'type' using +// the non-tuple-typed values in 'flatten_values'. +// +// e.g., Given 'flatten_values': [V1, V2, V3] &'type': tuple>, +// The function returns %t2 such that: +// %t1 = mhlo.tuple(V2,V3) : (T2,T3) -> tuple +// %t2 = mhlo.tuple(V1,%t1): (T1,tuple) -> tuple> +// +// Note: 1. FlattenTupleValue and CreateTupleValue is a pair of functions to +// resp. flatten and create tuples in the exact same order. +// 2. `flatten_values`, initially storing the flattened values, will be +// mutated to a 0-length array by the end of function invocation. +mlir::Value CreateTupleValue(mlir::OpBuilder* func_builder, mlir::Location loc, + mlir::ValueRange& flatten_values, mlir::Type type); + +// Create a TupleOp using the results of 'op' if 'type' is a mlir::TupleType. +// Otherwise, return 'op'. +mlir::Operation* CreateTupleFromOpResults(mlir::OpBuilder* func_builder, + mlir::Location loc, + mlir::Operation* op, mlir::Type type); + +mlir::TypeRange Untuple(const mlir::Type& type); + +static std::pair GetLayoutAttribute( + mlir::Builder& b, const Shape& shape, + std::optional maybe_layout = std::nullopt) { + if (shape.IsTuple()) { + llvm::SmallVector element_attrs; + llvm::SmallVector tile_attrs; + for (const auto& tuple_shape : shape.tuple_shapes()) { + // TODO here we do not dissect the layout of a tuple into sublayouts. + // Presently ShapeLayout cannot represent an explicit layout for a tuple + // type so this should never occur. However, if this function were to + // be used in another context where this assumption were to be lifted. + // users should be aware of this limitation which will use the default + // layout for tuple subshapes. + std::pair inner = + GetLayoutAttribute(b, tuple_shape); + element_attrs.push_back(inner.first); + tile_attrs.push_back(inner.second); + } + return std::make_pair((mlir::Attribute)b.getArrayAttr(element_attrs), + b.getArrayAttr(tile_attrs)); + } + + Layout layout = maybe_layout.value_or( + shape.has_layout() ? shape.layout() + : LayoutUtil::GetDefaultLayoutForShape(shape)); + + llvm::SmallVector vec_of_tiles; + for (const Tile& tile : layout.tiles()) { + llvm::SmallVector tile_vec = {tile.dimensions().begin(), + tile.dimensions().end()}; + vec_of_tiles.push_back(b.getIndexTensorAttr(tile_vec)); + } + llvm::SmallVector layout_vec = {layout.minor_to_major().begin(), + layout.minor_to_major().end()}; + return std::make_pair(b.getIndexTensorAttr(layout_vec), + b.getArrayAttr(vec_of_tiles)); +} + +static bool HasCustomLayout(const Shape& shape) { + if (shape.IsTuple()) { + return llvm::any_of(shape.tuple_shapes(), HasCustomLayout); + } + return shape.has_layout() && !shape.layout().minor_to_major().empty() && + shape.layout() != LayoutUtil::GetDefaultLayoutForShape(shape); +} + } // namespace xla #endif // XLA_TRANSLATE_HLO_TO_MHLO_HLO_UTILS_H_ diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc index c5a4e1a6c0e5d6..b16e5870e99d79 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc +++ b/third_party/xla/xla/translate/hlo_to_mhlo/hlo_utils_test.cc @@ -27,8 +27,8 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/shape_util.h" #include "xla/test.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/types.h" -#include "tsl/lib/core/status_test_util.h" namespace xla { namespace { diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD b/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD index 9c3500cdc7f49e..fd980a09c9dc4f 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/BUILD @@ -11,6 +11,7 @@ lit_test_suite( [ "bool_compare.hlo", "case_conditional.hlo", + "composite_call.hlo", "custom_call.hlo", "dynamic_param.hlo", "entry_computation_layout.hlo", @@ -20,11 +21,11 @@ lit_test_suite( "if_conditional.hlo", "import.hlo", "import_async.hlo", + "import_async2.hlo", "layouts_and_names.hlo", "location.hlo", "module_attributes.hlo", "module_config.hlo", - "send_recv.hlo", "simple.hlo", "spmd_module_sharding.hlo", "stacktrace_to_location.hlo", diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo new file mode 100644 index 00000000000000..ad3dc7031e9e64 --- /dev/null +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/composite_call.hlo @@ -0,0 +1,186 @@ +// RUN: xla-translate -split-input-file -hlo-text-to-mlir-hlo %s | FileCheck %s + +// dictionary-like frontend_attributes +HloModule composite, entry_computation_layout={()->f32[]} + +// CHECK: func.func @main() -> tensor { +// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor +// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor}, decomposition = @add.2, version = 1 : i32} : (tensor) -> tensor +// CHECK: return %1 : tensor +// CHECK: } + +// CHECK: func.func private @add.2(%arg0: tensor) -> tensor { +// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %1 = mhlo.add %arg0, %0 : tensor +// CHECK: return %1 : tensor +// CHECK: } +%add.2 (Arg_0.3: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %constant.4 = f32[] constant(2) + ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) +} + +ENTRY %main.7 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} +} + +// ----- + +// string-like frontend_attributes +HloModule composite, entry_computation_layout={()->f32[]} + +// CHECK: func.func @main() -> tensor { +// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor +// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor}, decomposition = @add.2, version = 1 : i32} : (tensor) -> tensor +// CHECK: return %1 : tensor +// CHECK: } + +// CHECK: func.func private @add.2(%arg0: tensor) -> tensor { +// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %1 = mhlo.add %arg0, %0 : tensor +// CHECK: return %1 : tensor +// CHECK: } +%add.2 (Arg_0.3: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %constant.4 = f32[] constant(2) + ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) +} + +ENTRY %main.7 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes="{n = 1 : i32, tensor = dense<1> : tensor}",composite.name="foo.bar",composite.version="1"} +} + +// ----- + +// zero-output composite +HloModule composite, entry_computation_layout={()->()} + +// CHECK: func.func @main() -> tuple<> { +// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor +// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor}, decomposition = @return.2, version = 1 : i32, xla_shape = "()"} : (tensor) -> tuple<> +// CHECK: return %1 : tuple<> +// CHECK: } +// CHECK: func.func private @return.2(%arg0: tensor) -> tuple<> { +// CHECK: %0 = mhlo.tuple {xla_shape = "()"} : tuple<> +// CHECK: return %0 : tuple<> +// CHECK: } +%return.2 (Arg_0.3: f32[]) -> () { + %Arg_0.3 = f32[] parameter(0) + ROOT %tuple.4 = () tuple() +} + +ENTRY %main.7 () -> () { + %constant.1 = f32[] constant(42) + ROOT %call.5 = () call(f32[] %constant.1), to_apply=%return.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} +} + +// ----- + +// multi-output composite +HloModule composite, entry_computation_layout={()->(f32[], f32[])} + +// CHECK: func.func @main() -> tuple, tensor> { +// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor +// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor}, decomposition = @add.2, version = 1 : i32, xla_shape = "(f32[], f32[])"} : (tensor) -> tuple, tensor> +// CHECK: return %1 : tuple, tensor> +// CHECK: } +// CHECK: func.func private @add.2(%arg0: tensor) -> tuple, tensor> { +// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %1 = mhlo.add %arg0, %0 : tensor +// CHECK: %2 = mhlo.tuple %1, %1 {xla_shape = "(f32[], f32[])"} : tuple, tensor> +// CHECK: return %2 : tuple, tensor> +// CHECK: } +%add.2 (Arg_0.3: f32[]) -> (f32[], f32[]) { + %Arg_0.3 = f32[] parameter(0) + %constant.4 = f32[] constant(2) + %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) + ROOT %tuple.6 = (f32[], f32[]) tuple(f32[] %add.5, f32[] %add.5) +} + +ENTRY %main.9 () -> (f32[], f32[]) { + %constant.1 = f32[] constant(42) + ROOT %call.7 = (f32[], f32[]) call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} +} + +// ----- + +// optional composite attributes +HloModule composite, entry_computation_layout={()->f32[]} + +// CHECK: func.func @main() -> tensor { +// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor +// CHECK: %1 = mhlo.composite "foo.bar" %0 {decomposition = @add.2, version = 1 : i32} : (tensor) -> tensor +// CHECK: return %1 : tensor +// CHECK: } + +// CHECK: func.func private @add.2(%arg0: tensor) -> tensor { +// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %1 = mhlo.add %arg0, %0 : tensor +// CHECK: return %1 : tensor +// CHECK: } +%add.2 (Arg_0.3: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %constant.4 = f32[] constant(2) + ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) +} + +ENTRY %main.7 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"} +} + +// ----- + +// optional composite version +HloModule composite, entry_computation_layout={()->f32[]} + +// CHECK: func.func @main() -> tensor { +// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor +// CHECK: %1 = mhlo.composite "foo.bar" %0 {composite_attributes = {n = 1 : i32, tensor = dense<1> : tensor}, decomposition = @add.2} : (tensor) -> tensor +// CHECK: return %1 : tensor +// CHECK: } + +// CHECK: func.func private @add.2(%arg0: tensor) -> tensor { +// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %1 = mhlo.add %arg0, %0 : tensor +// CHECK: return %1 : tensor +// CHECK: } +%add.2 (Arg_0.3: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %constant.4 = f32[] constant(2) + ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) +} + +ENTRY %main.7 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes="{n = 1 : i32, tensor = dense<1> : tensor}",composite.name="foo.bar",composite.version="0"} +} + +// ----- + +// optional composite attributes and version +HloModule composite, entry_computation_layout={()->f32[]} + +// CHECK: func.func @main() -> tensor { +// CHECK: %0 = mhlo.constant dense<4.200000e+01> : tensor +// CHECK: %1 = mhlo.composite "foo.bar" %0 {decomposition = @add.2} : (tensor) -> tensor +// CHECK: return %1 : tensor +// CHECK: } + +// CHECK: func.func private @add.2(%arg0: tensor) -> tensor { +// CHECK: %0 = mhlo.constant dense<2.000000e+00> : tensor +// CHECK: %1 = mhlo.add %arg0, %0 : tensor +// CHECK: return %1 : tensor +// CHECK: } +%add.2 (Arg_0.3: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %constant.4 = f32[] constant(2) + ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) +} + +ENTRY %main.7 () -> f32[] { + %constant.1 = f32[] constant(42) + ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"} +} diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo index fa99b77174cb53..a8b6707dcc6278 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/entry_computation_layout.hlo @@ -1,18 +1,30 @@ // RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s -HloModule entry, entry_computation_layout={(f32[2,3,4]{0,1,2}, f32[2,3,4]{1,2,0}, (f32[1,2]{1,0}, f32[1,2]{0,1}))->f32[2,3,4]{2,0,1}} +HloModule entry, entry_computation_layout={(f32[2,3,4]{0,1,2}, f32[2,3,4]{1,2,0}, (f32[1,2]{1,0}, f32[1,2]{0,1}), s32[]{:T(128)})->f32[2,3,4]{2,0,1}} ENTRY entry { p0 = f32[2,3,4]{2,1,0} parameter(0) p1 = f32[2,3,4]{2,1,0} parameter(1) p2 = (f32[1,2]{1,0}, f32[1,2]{0,1}) parameter(2) + p3 = s32[]{:T(128)} parameter(3) ROOT add = f32[2,3,4]{2,1,0} add(p0, p1) } -// CHECK: func.func @main( -// CHECK-SAME: xla_entry_computation_parameter_layouts -// CHECK-SAME: dense<[0, 1, 2]> -// CHECK-SAME: dense<[1, 2, 0]> -// CHECK-SAME: [dense<[1, 0]> -// CHECK-SAME: , dense<[0, 1]> -// CHECK-SAME: xla_entry_computation_result_layout = dense<[2, 0, 1]> +// CHECK: module @entry +// CHECK-SAME: mhlo.xla_entry_computation_parameter_layouts = [ +// CHECK-SAME: dense<[0, 1, 2]> : tensor<3xindex>, +// CHECK-SAME: dense<[1, 2, 0]> : tensor<3xindex>, +// CHECK-SAME: [dense<[1, 0]> : tensor<2xindex>, +// CHECK-SAME: dense<[0, 1]> : tensor<2xindex>], +// CHECK-SAME: dense<> : tensor<0xindex>] +// CHECK-SAME: mhlo.xla_entry_computation_parameter_tiles = [ +// CHECK-SAME: [], +// CHECK-SAME: [], +// CHECK-SAME: [ +// CHECK-SAME: [], +// CHECK-SAME: [] +// CHECK-SAME: ], +// CHECK-SAME: [dense<128> : tensor<1xindex>] +// CHECK-SAME: ], +// CHECK-SAME: mhlo.xla_entry_computation_result_layout = dense<[2, 0, 1]> : tensor<3xindex> +// CHECK-SAME: mhlo.xla_entry_computation_result_tiles = [] diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo index a8ce57c90f5d3d..0c175bc850e32e 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import.hlo @@ -2021,7 +2021,7 @@ add { } // CHECK-LABEL: func private @test_topk // CHECK-SAME: ([[ARG:%.*]]: tensor<4x4xf32>) -> tuple, tensor<4x2xi32>> -// CHECK: mhlo.topk([[ARG]], k = 2, largest = true) : tensor<4x4xf32> -> (tensor<4x2xf32>, tensor<4x2xi32>) +// CHECK: mhlo.topk([[ARG]], k = 2) : tensor<4x4xf32> -> (tensor<4x2xf32>, tensor<4x2xi32>) // FLATTEN-CHECK-LABEL: func private @test_topk // FLATTEN-CHECK-SAME: ([[ARG:%.*]]: tensor<4x4xf32>) -> (tensor<4x2xf32>, tensor<4x2xi32>) diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo index 7dcd16a1f9b615..4e9633014b332b 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async.hlo @@ -1,142 +1,162 @@ -// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s -// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION +// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations -split-input-file %s -o - | FileCheck %s -// NO_DEAD_FUNCTION-NOT: @test +// CHECK-LABEL: func.func private @recv_ +// CHECK: %0:2 = "mhlo.recv"(%arg0) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> : (!mhlo.token) -> (tensor, !mhlo.token) -// CHECK: module @foobar +// CHECK-LABEL: func.func private @send_ +// CHECK: %0 = "mhlo.send"(%arg0, %arg1) <{channel_handle = #mhlo.channel_handle, is_host_transfer = true}> : (tensor, !mhlo.token) -> !mhlo.token + +// CHECK-LABEL: func.func @main +// CHECK-LITERAL: %0 = "mhlo.async_start"(%arg0, %arg1) <{called_computation = @send_, execution_thread = "main"}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}", xla_shape = "(s32[], u32[], token[])"} : (tensor, !mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor> +// CHECK-NEXT-LITERAL: %1 = "mhlo.async_done"(%0) {called_computation = @send_, execution_thread = "main", mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"}, mhlo.sharding = "{maximal device=0}", xla_shape = "token[]"} : (!mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor>) -> !mhlo.token +// CHECK-NEXT-LITERAL: %2 = "mhlo.async_start"(%1) <{called_computation = @recv_, execution_thread = "main"}> {mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}", xla_shape = "(s32[], u32[], token[])"} : (!mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, tensor> +// CHECK-NEXT-LITERAL: %3:2 = "mhlo.async_done"(%2) {called_computation = @recv_, execution_thread = "main", mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"}, mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle, !mhlo.token>, tensor>) -> (tensor, !mhlo.token) HloModule foobar -// Compiler-generated functions - -// CHECK: func private [[RECV_DTD_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.recv"([[TOK]] - // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = false} - -// CHECK: func private [[RECV_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.recv"([[TOK]] - // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = true} - -// CHECK: func private [[SEND_GENSYM:@.*send.*]]([[INPUT:%.*]]: tensor<128x32xf32>, %arg1: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.send"([[INPUT]] - // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = true} - -// CHECK: func private [[COPY_GENSYM:@.*copy.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { - // CHECK-NEXT: mhlo.copy [[INPUT]] - // CHECK-SAME: cross_program_prefetch_index - -// CHECK: func private [[CP_GENSYM:@.*collective_permute_.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.collective_permute"([[INPUT]]) - // CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> - -// CHECK: func private [[AR_GENSYM:@.*all_reduce.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) - // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - // CHECK-SAME: use_global_device_ids - // CHECK: [[BLOCK:^.*]]([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): - // CHECK: mhlo.add [[LHS]], [[RHS]] - -// CHECK: func private [[AG_GENSYM:@.*all_gather.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} { - // CHECK-NEXT: "mhlo.all_gather"([[INPUT]]) - // CHECK-SAME: all_gather_dim = 1 : i64 - // CHECK-SAME: channel_handle = #mhlo.channel_handle - // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> - // CHECK-SAME: use_global_device_ids - -// CHECK: func @main(%arg0: tensor) -> tensor { -ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { - ROOT %Arg_0.1 = f32[] parameter(0) -} +ENTRY %async_send_recv_test (arg_0: s32[], arg_1: token[]) -> (s32[], token[]) { + %arg_0 = s32[] parameter(0) + %arg_1 = token[] parameter(1) + + %send.0 = (s32[], u32[], token[]) send(s32[] %arg_0, token[] %arg_1), channel_id=3, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"} + %send-done.1 = token[] send-done((s32[], u32[], token[]) %send.0), channel_id=3, is_host_transfer=true, sharding={maximal device=0}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"} -// Tests - -// CHECK: func private @test_all_gather_start -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) -%test_all_gather_start { - input = f32[128,32] parameter(0) - // CHECK-NEXT: [[AG_START:%.*]] = "mhlo.async_start"([[INPUT]]) - // CHECK-SAME: called_computation = [[AG_GENSYM]], execution_thread = "main" - ag-start = (f32[128,32], f32[128,128]) all-gather-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true - // CHECK-NEXT: "mhlo.async_done"([[AG_START]]) - ROOT ag-done = f32[128,128] all-gather-done(ag-start) + %recv.2 = (s32[], u32[], token[]) recv(token[] %send-done.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"} + %recv-done.3 = (s32[], token[]) recv-done((s32[], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"} + + %get-tuple-element.4 = s32[] get-tuple-element((s32[], token[]) %recv-done.3), index=0, sharding={maximal device=0} + %get-tuple-element.5 = token[] get-tuple-element((s32[], token[]) %recv-done.3), index=1, sharding={maximal device=0} + ROOT %tuple.6 = (s32[], token[]) tuple(s32[] %get-tuple-element.4, token[] %get-tuple-element.5) } -add { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT add = f32[] add(lhs, rhs) +// ----- + +HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,128]{1,0}} + +// CHECK-LABEL: func.func private @all_gather_ +// CHECK: mhlo.all_gather + +// CHECK-LABEL: func.func @main +// CHECK: mhlo.async_start{{.*}}called_computation = @all_gather_ +// CHECK: mhlo.async_done + +ENTRY %async_all_gather_test (Arg_0.1: f32[128,32]) -> f32[128,128] { + %Arg_0.1 = f32[128,32] parameter(0) + %all-gather-start.2 = f32[128,128] all-gather-start(f32[128,32] %Arg_0.1), channel_id=1, replica_groups={{0,2,4,6},{1,3,5,7}}, constrain_layout=true, dimensions={1}, use_global_device_ids=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=16} + ROOT %all-gather-done.3 = f32[128,128] all-gather-done(f32[128,128] %all-gather-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:1 offset " source_line=17} } -// CHECK: func private @test_all_reduce_start -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) -%test_all_reduce_start { - input = f32[128,32] parameter(0) - // CHECK-NEXT: [[AR_START:%.*]] = "mhlo.async_start"([[INPUT]]) - // CHECK-SAME: called_computation = [[AR_GENSYM]], execution_thread = "main" - ar-start = (f32[128,32], f32[128,32]) all-reduce-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, to_apply=add, use_global_device_ids=true - // CHECK-NEXT: "mhlo.async_done"([[AR_START]]) - ROOT ar-done = f32[128,32] all-reduce-done(ar-start) +// ----- + +HloModule main, entry_computation_layout={(f32[10]{0})->f32[10]{0}} + +%region_1.2 (Arg_0.3: f32[], Arg_1.4: f32[]) -> f32[] { + %Arg_0.3 = f32[] parameter(0) + %Arg_1.4 = f32[] parameter(1) + ROOT %maximum.5 = f32[] maximum(f32[] %Arg_0.3, f32[] %Arg_1.4), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=7} } -// CHECK: func private @test_collective_permute -// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> -%test_collective_permute (input: f32[128,32]) -> f32[128,32] { - %input = f32[128,32]{1,0} parameter(0) - // CHECK-NEXT: [[CP_START:%.*]] = "mhlo.async_start"([[ARG]]) - // CHECK-SAME: called_computation = [[CP_GENSYM]], execution_thread = "main" - %cp-start = (f32[128,32]{1,0}, f32[128,32]) collective-permute-start(%input), source_target_pairs={{0,1},{1,2},{2,3}} - // CHECK-NEXT: "mhlo.async_done"([[CP_START]]) - ROOT %cp-done = f32[128,32]{1,0} collective-permute-done(%cp-start) +// CHECK-LABEL: func.func private @all_reduce_ +// CHECK: mhlo.all_reduce + +// CHECK-LABEL: func.func @main +// CHECK: mhlo.async_start{{.*}}called_computation = @all_reduce_ +// CHECK: mhlo.async_done +ENTRY %async_all_reduce_test (Arg_0.1: f32[10]) -> f32[10] { + %Arg_0.1 = f32[10] parameter(0) + %all-reduce-start.6 = f32[10] all-reduce-start(f32[10] %Arg_0.1), channel_id=5, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=%region_1.2, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=22} + ROOT %all-reduce-done.7 = f32[10] all-reduce-done(f32[10] %all-reduce-start.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:30 offset " source_line=23} } -// CHECK: func private @test_copy_start -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) -%test_copy_start { - input = f32[128,32] parameter(0) - // CHECK-NEXT: [[COPY_START:%.*]] = "mhlo.async_start"([[INPUT]]) - // CHECK-SAME: called_computation = [[COPY_GENSYM]], execution_thread = "main" - copy-start = (f32[128,32], f32[128,32], u32[]) copy-start(input), cross_program_prefetch_index=0 - // CHECK-NEXT: "mhlo.async_done"([[COPY_START]]) - ROOT copy-done = f32[128,32] copy-done(copy-start) +// ----- + +HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} + +// CHECK-LABEL: func.func private @collective_permute_ +// CHECK: mhlo.collective_permute + +// CHECK-LABEL: func.func @main +// CHECK: mhlo.async_start{{.*}}called_computation = @collective_permute_ +// CHECK: mhlo.async_done +ENTRY %async_collective_permute_test (Arg_0.1: f32[128,32]) -> f32[128,32] { + %Arg_0.1 = f32[128,32] parameter(0) + %collective-permute-start.2 = f32[128,32] collective-permute-start(f32[128,32] %Arg_0.1), channel_id=1, source_target_pairs={{0,1},{1,2},{2,3}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=13} + ROOT %collective-permute-done.3 = f32[128,32] collective-permute-done(f32[128,32] %collective-permute-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:109 offset " source_line=14} } -// CHECK: func private @test_send -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) -%test_send_start { - input = f32[128,32] parameter(0) - tok = token[] parameter(1) - // CHECK-NEXT: [[SEND_START:%.*]] = "mhlo.async_start"([[INPUT]], [[TOK]]) - // CHECK-SAME: called_computation = [[SEND_GENSYM]], execution_thread = "main" - // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor> - send-start = (f32[128,32], u32[], token[]) send(input, tok), channel_id=5, is_host_transfer=true - // CHECK-NEXT: "mhlo.async_done"([[SEND_START]]) - ROOT send-done = token[] send-done(send-start), channel_id=5, is_host_transfer=true +// ----- + +HloModule main, entry_computation_layout={(f32[128,32]{1,0})->f32[128,32]{1,0}} + +ENTRY %async_copy_test (Arg_0.1: f32[128,32]) -> f32[128,32] { + %Arg_0.1 = f32[128,32] parameter(0) + %copy-start.2 = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %Arg_0.1), cross_program_prefetch_index=0, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=10} + ROOT %copy-done.3 = f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %copy-start.2), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:133 offset " source_line=11} } -// CHECK: func private @test_recv -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) -%test_recv_start { - input = f32[128,32] parameter(0) - tok = token[] parameter(1) - // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]]) - // CHECK-SAME: called_computation = [[RECV_GENSYM]], execution_thread = "main" - // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, tensor> - recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5, is_host_transfer=true - // CHECK-NEXT: "mhlo.async_done"([[RECV_START]]) - recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5, is_host_transfer=true - ROOT gte = get-tuple-element(recv-done), index=0 +// ----- + +HloModule main, entry_computation_layout={(token[])->(s32[3,4]{1,0}, token[])} + +ENTRY %async_recv_test_tuple (Arg_0.1: token[]) -> (s32[3,4], token[]) { + %Arg_0.1 = token[] parameter(0) + %recv.2 = (s32[3,4], u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=16} + %recv-done.3 = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %get-tuple-element.4 = s32[3,4] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=0, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + %get-tuple-element.5 = token[] get-tuple-element((s32[3,4], token[]) %recv-done.3), index=1, sharding={maximal device=0}, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:179 offset " source_line=17} + ROOT %tuple.6 = (s32[3,4], token[]) tuple(s32[3,4] %get-tuple-element.4, token[] %get-tuple-element.5) } -// CHECK: func private @test_recv_dtd -// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) -%test_recv_dtd_start { - input = f32[128,32] parameter(0) - tok = token[] parameter(1) - // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]]) - // CHECK-SAME: called_computation = [[RECV_DTD_GENSYM]], execution_thread = "main" - // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, tensor> - recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5 - // CHECK-NEXT: "mhlo.async_done"([[RECV_START]]) - recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5 - ROOT gte = get-tuple-element(recv-done), index=0 +// ----- + +HloModule main, entry_computation_layout={(s32[3,4]{1,0}, token[])->token[]} + +ENTRY %async_send_test (Arg_0.1: s32[3,4], Arg_1.2: token[]) -> token[] { + %Arg_0.1 = s32[3,4] parameter(0) + %Arg_1.2 = token[] parameter(1) + %send.3 = (s32[3,4], u32[], token[]) send(s32[3,4] %Arg_0.1, token[] %Arg_1.2), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=16} + ROOT %send-done.4 = token[] send-done((s32[3,4], u32[], token[]) %send.3), channel_id=5, is_host_transfer=true, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:213 offset " source_line=17} } + + +// BROKEN: b/TODO: Async custom calls? + +// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})} + +// ENTRY %async_custom_call_test2 (Arg_0.1: f32[10]) -> (f32[20]) { +// %Arg_0.1 = f32[10] parameter(0) +// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="bar", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=21} +// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=22} +// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:288 offset " source_line=23} +// } + +// HloModule main, entry_computation_layout={(f32[10]{0})->(f32[20]{0})} + +// ENTRY %async_custom_call_test (Arg_0.1: f32[10]) -> (f32[20]) { +// %Arg_0.1 = f32[10] parameter(0) +// %async-start.5 = ((f32[10]), f32[20], s32[]) custom-call-start(f32[10] %Arg_0.1), async_execution_thread="thread", custom_call_target="foo", metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=16} +// %async-update.6 = ((f32[10]), f32[20], s32[]) custom-call-update(((f32[10]), f32[20], s32[]) %async-start.5), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=18} +// ROOT %async-done.7 = (f32[20]) custom-call-done(((f32[10]), f32[20], s32[]) %async-update.6), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:265 offset " source_line=20} +// } + + +/////////// + +// BROKEN: b/TODO: Empty arg send/recv don't roundtrip + +// HloModule main, entry_computation_layout={(token[])->token[]} + +// ENTRY %async_send_test_empty (Arg_0.1: token[]) -> token[] { +// %tuple.2 = () tuple(), metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} +// %Arg_0.1 = token[] parameter(0) +// %send.3 = ((), u32[], token[]) send(() %tuple.2, token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=15} +// ROOT %send-done.4 = token[] send-done(((), u32[], token[]) %send.3), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:240 offset " source_line=16} +// } + +// HloModule main, entry_computation_layout={(token[])->((), token[])} + +// ENTRY %async_recv_test (Arg_0.1: token[]) -> ((), token[]) { +// %Arg_0.1 = token[] parameter(0) +// %recv.2 = ((), u32[], token[]) recv(token[] %Arg_0.1), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=17} +// ROOT %recv-done.3 = ((), token[]) recv-done(((), u32[], token[]) %recv.2), channel_id=5, metadata={source_file="within split at third_party/tensorflow/compiler/xla/translate/mhlo_to_hlo/tests/export_async.mlir:153 offset " source_line=18} +// } + diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo new file mode 100644 index 00000000000000..7493c958776950 --- /dev/null +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/import_async2.hlo @@ -0,0 +1,146 @@ +// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s +// RUN: xla-translate --print-sugar=false -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION + +// It would be great to consolidate this test with `import_async.hlo`, but +// this test is very fragile and doesn't run properly in a `-split-input-file` +// mode. + +// NO_DEAD_FUNCTION-NOT: @test + +// CHECK: module @foobar +HloModule foobar + +// Compiler-generated functions + +// CHECK: func private [[RECV_DTD_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} { + // CHECK-NEXT: "mhlo.recv"([[TOK]] + // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = false} + +// CHECK: func private [[RECV_GENSYM:@.*recv.*]]([[TOK:%.*]]: !mhlo.token) -> (tensor<128x32xf32>, !mhlo.token) attributes {execution_thread = "main"} { + // CHECK-NEXT: "mhlo.recv"([[TOK]] + // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = true} + +// CHECK: func private [[SEND_GENSYM:@.*send.*]]([[INPUT:%.*]]: tensor<128x32xf32>, %arg1: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} { + // CHECK-NEXT: "mhlo.send"([[INPUT]] + // CHECK-SAME{LITERAL}: {channel_handle = #mhlo.channel_handle, is_host_transfer = true} + +// CHECK: func private [[COPY_GENSYM:@.*copy.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { + // CHECK-NEXT: mhlo.copy [[INPUT]] + // CHECK-SAME: cross_program_prefetch_index + +// CHECK: func private [[CP_GENSYM:@.*collective_permute_.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { + // CHECK-NEXT: "mhlo.collective_permute"([[INPUT]]) + // CHECK-SAME{LITERAL}: <{source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>}> : (tensor<128x32xf32>) -> tensor<128x32xf32> + +// CHECK: func private [[AR_GENSYM:@.*all_reduce.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { + // CHECK-NEXT: "mhlo.all_reduce"([[INPUT]]) + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + // CHECK-SAME: use_global_device_ids + // CHECK: [[BLOCK:^.*]]([[LHS:%.*]]: tensor, [[RHS:%.*]]: tensor): + // CHECK: mhlo.add [[LHS]], [[RHS]] + +// CHECK: func private [[AG_GENSYM:@.*all_gather.*]]([[INPUT:%.*]]: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} { + // CHECK-NEXT: "mhlo.all_gather"([[INPUT]]) + // CHECK-SAME: all_gather_dim = 1 : i64 + // CHECK-SAME: channel_handle = #mhlo.channel_handle + // CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> + // CHECK-SAME: use_global_device_ids + +// CHECK: func @main(%arg0: tensor) -> tensor { +ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] { + ROOT %Arg_0.1 = f32[] parameter(0) +} + +// Tests + +// CHECK: func private @test_all_gather_start +// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) +%test_all_gather_start { + input = f32[128,32] parameter(0) + // CHECK-NEXT: [[AG_START:%.*]] = "mhlo.async_start"([[INPUT]]) + // CHECK-SAME: called_computation = [[AG_GENSYM]], execution_thread = "main" + ag-start = (f32[128,32], f32[128,128]) all-gather-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}, use_global_device_ids=true + // CHECK-NEXT: "mhlo.async_done"([[AG_START]]) + ROOT ag-done = f32[128,128] all-gather-done(ag-start) +} + +add { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +// CHECK: func private @test_all_reduce_start +// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) +%test_all_reduce_start { + input = f32[128,32] parameter(0) + // CHECK-NEXT: [[AR_START:%.*]] = "mhlo.async_start"([[INPUT]]) + // CHECK-SAME: called_computation = [[AR_GENSYM]], execution_thread = "main" + ar-start = (f32[128,32], f32[128,32]) all-reduce-start(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, to_apply=add, use_global_device_ids=true + // CHECK-NEXT: "mhlo.async_done"([[AR_START]]) + ROOT ar-done = f32[128,32] all-reduce-done(ar-start) +} + +// CHECK: func private @test_collective_permute +// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32> +%test_collective_permute (input: f32[128,32]) -> f32[128,32] { + %input = f32[128,32]{1,0} parameter(0) + // CHECK-NEXT: [[CP_START:%.*]] = "mhlo.async_start"([[ARG]]) + // CHECK-SAME: called_computation = [[CP_GENSYM]], execution_thread = "main" + %cp-start = (f32[128,32]{1,0}, f32[128,32]) collective-permute-start(%input), source_target_pairs={{0,1},{1,2},{2,3}} + // CHECK-NEXT: "mhlo.async_done"([[CP_START]]) + ROOT %cp-done = f32[128,32]{1,0} collective-permute-done(%cp-start) +} + +// CHECK: func private @test_copy_start +// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>) +%test_copy_start { + input = f32[128,32] parameter(0) + // CHECK-NEXT: [[COPY_START:%.*]] = "mhlo.async_start"([[INPUT]]) + // CHECK-SAME: called_computation = [[COPY_GENSYM]], execution_thread = "main" + copy-start = (f32[128,32], f32[128,32], u32[]) copy-start(input), cross_program_prefetch_index=0 + // CHECK-NEXT: "mhlo.async_done"([[COPY_START]]) + ROOT copy-done = f32[128,32] copy-done(copy-start) +} + +// CHECK: func private @test_send +// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) +%test_send_start { + input = f32[128,32] parameter(0) + tok = token[] parameter(1) + // CHECK-NEXT: [[SEND_START:%.*]] = "mhlo.async_start"([[INPUT]], [[TOK]]) + // CHECK-SAME: called_computation = [[SEND_GENSYM]], execution_thread = "main" + // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor> + send-start = (f32[128,32], u32[], token[]) send(input, tok), channel_id=5, is_host_transfer=true + // CHECK-NEXT: "mhlo.async_done"([[SEND_START]]) + ROOT send-done = token[] send-done(send-start), channel_id=5, is_host_transfer=true +} + +// CHECK: func private @test_recv +// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) +%test_recv_start { + input = f32[128,32] parameter(0) + tok = token[] parameter(1) + // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]]) + // CHECK-SAME: called_computation = [[RECV_GENSYM]], execution_thread = "main" + // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, tensor> + recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5, is_host_transfer=true + // CHECK-NEXT: "mhlo.async_done"([[RECV_START]]) + recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5, is_host_transfer=true + ROOT gte = get-tuple-element(recv-done), index=0 +} + +// CHECK: func private @test_recv_dtd +// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>, [[TOK:%.*]]: !mhlo.token) +%test_recv_dtd_start { + input = f32[128,32] parameter(0) + tok = token[] parameter(1) + // CHECK-NEXT: [[RECV_START:%.*]] = "mhlo.async_start"([[TOK]]) + // CHECK-SAME: called_computation = [[RECV_DTD_GENSYM]], execution_thread = "main" + // CHECK-SAME{LITERAL}: -> !mhlo.async_bundle, !mhlo.token>, tensor> + recv-start = (f32[128,32], u32[], token[]) recv(tok), channel_id=5 + // CHECK-NEXT: "mhlo.async_done"([[RECV_START]]) + recv-done = (f32[128,21], token[]) recv-done(recv-start), channel_id=5 + ROOT gte = get-tuple-element(recv-done), index=0 +} diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo index 74eaaea5a0e8fe..d3433dce372cbf 100644 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo +++ b/third_party/xla/xla/translate/hlo_to_mhlo/tests/module_attributes.hlo @@ -5,6 +5,18 @@ # FLATTEN-CHECK-LABEL: module @main attributes { hlo_module { name: "main" + input_output_alias { + entries { + output_shape_index: 0 + parameter_number: 0 + kind: MAY_ALIAS + } + entries { + output_shape_index: 1 + parameter_number: 1 + kind: MAY_ALIAS + } + } entry_computation_name: "main.5" computations { name: "main.5" @@ -217,6 +229,7 @@ hlo_module { value: "attr_value" } } +# CHECK-SAME: mhlo.input_output_alias = [{alias = {kind = "may_alias", parameter_index = array, parameter_number = 0 : i64}, output_index = array}, {alias = {kind = "may_alias", parameter_index = array, parameter_number = 1 : i64}, output_index = array}] # CHECK-SAME: mhlo.is_dynamic = true is_dynamic: true # CHECK-SAME: mhlo.use_auto_spmd_partitioning = true diff --git a/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo b/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo deleted file mode 100644 index ef40699925dc8a..00000000000000 --- a/third_party/xla/xla/translate/hlo_to_mhlo/tests/send_recv.hlo +++ /dev/null @@ -1,55 +0,0 @@ -// RUN: xla-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s - -HloModule foo - -// CHECK: func private @[[RECV_FUNC:[^(]*]] -// CHECK: mhlo.recv -// CHECK-SAME: channel_handle = #mhlo.channel_handle -// CHECK-NOT: mhlo.sharding - -// CHECK: func private @[[SEND_FUNC:[^(]*]] -// CHECK: mhlo.send -// CHECK-SAME: channel_handle = #mhlo.channel_handle - -// CHECK: func @main -// CHECK: mhlo.async_start -// CHECK-SAME: called_computation = @[[SEND_FUNC]] -// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"} -// CHECK-SAME: mhlo.sharding = "{ -// CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0} -// CHECK-SAME: }" -// CHECK-SAME: (tensor, !mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor> -// CHECK: mhlo.async_done -// CHECK-SAME: called_computation = @[[SEND_FUNC]] -// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_dtoh_0"} -// CHECK-SAME: mhlo.sharding = "{maximal device=0}" -// CHECK-SAME: (!mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor>) -> !mhlo.token -// CHECK: mhlo.async_start -// CHECK-SAME: called_computation = @[[RECV_FUNC]] -// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"} -// CHECK-SAME: mhlo.sharding = "{ -// CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0} -// CHECK-SAME: }" -// CHECK-SAME: (!mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, tensor> -// CHECK: mhlo.async_done -// CHECK-SAME: called_computation = @[[RECV_FUNC]] -// CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_handler_name = "tf_rendezvous", _xla_host_transfer_rendezvous = "_foo_htod_0"} -// CHECK-SAME: mhlo.sharding = "{ -// CHECK-SAME: {maximal device=0}, {maximal device=0} -// CHECK-SAME: }" -// CHECK-SAME: (!mhlo.async_bundle, !mhlo.token>, tensor>) -> (tensor, !mhlo.token) - -ENTRY %foo (arg_0: s32[], arg_1: token[]) -> (s32[], token[]) { - %arg_0 = s32[] parameter(0) - %arg_1 = token[] parameter(1) - - %send.0 = (s32[], u32[], token[]) send(s32[] %arg_0, token[] %arg_1), channel_id=3, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"} - %send-done.1 = token[] send-done((s32[], u32[], token[]) %send.0), channel_id=3, is_host_transfer=true, sharding={maximal device=0}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_dtoh_0"} - - %recv.2 = (s32[], u32[], token[]) recv(token[] %send-done.1), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"} - %recv-done.3 = (s32[], token[]) recv-done((s32[], u32[], token[]) %recv.2), channel_id=5, is_host_transfer=true, sharding={{maximal device=0}, {maximal device=0}}, frontend_attributes={_xla_host_transfer_handler_name="tf_rendezvous", _xla_host_transfer_rendezvous="_foo_htod_0"} - - %get-tuple-element.4 = s32[] get-tuple-element((s32[], token[]) %recv-done.3), index=0, sharding={maximal device=0} - %get-tuple-element.5 = token[] get-tuple-element((s32[], token[]) %recv-done.3), index=1, sharding={maximal device=0} - ROOT %tuple.6 = (s32[], token[]) tuple(s32[] %get-tuple-element.4, token[] %get-tuple-element.5) -} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD index c8b74d0d22a9e8..40e05c8873d229 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/BUILD +++ b/third_party/xla/xla/translate/mhlo_to_hlo/BUILD @@ -23,6 +23,7 @@ cc_library( "//xla:types", "//xla:util", "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", "//xla/mlir_hlo", "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", @@ -98,6 +99,7 @@ cc_library( ":type_to_shape", "//xla:array", "//xla:comparison_util", + "//xla:debug_options_flags", "//xla:literal", "//xla:literal_util", "//xla:shape_util", @@ -105,6 +107,7 @@ cc_library( "//xla:types", "//xla:xla_data_proto_cc", "//xla/client:xla_builder", + "//xla/client:xla_computation", "//xla/client/lib:approx_topk", "//xla/client/lib:approx_topk_shape", "//xla/client/lib:matrix", @@ -119,8 +122,11 @@ cc_library( "//xla/service:hlo_parser", "//xla/service:hlo_proto_cc", "//xla/service/gpu:backend_configs_cc", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", @@ -132,8 +138,11 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:ml_dtypes", "@local_tsl//tsl/platform:statusor", + "@local_tsl//tsl/platform:types", + "@stablehlo//:base", "@stablehlo//:stablehlo_ops", ], ) @@ -176,18 +185,26 @@ cc_library( deps = [ ":mlir_hlo_to_hlo", ":type_to_shape", + "//xla:debug_options_flags", + "//xla:shape_util", + "//xla/client:xla_builder", + "//xla/client:xla_computation", "//xla/hlo/ir:hlo", "//xla/mlir_hlo:hlo_dialect_registration", "//xla/service:hlo_module_config", "//xla/service:hlo_proto_cc", "//xla/service:hlo_proto_util", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:errors", + "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc index a492861b28d831..73a5c8b994e57e 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.cc @@ -18,6 +18,8 @@ limitations under the License. #include #include "mlir/Support/LLVM.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo_parser.h" #include "xla/shape_util.h" @@ -185,4 +187,99 @@ std::optional ConvertSharding(llvm::StringRef sharding) { return std::nullopt; } +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing) { + if (aliasing.empty()) return std::nullopt; + + xla::HloInputOutputAliasProto input_output_alias_proto; + for (auto attr : aliasing) { + auto entry_attr = mlir::cast(attr); + auto alias_attr = mlir::cast(entry_attr.get("alias")); + mlir::ArrayRef output_index = + mlir::cast(entry_attr.get("output_index")) + .asArrayRef(); + mlir::ArrayRef parameter_index = + mlir::cast(alias_attr.get("parameter_index")) + .asArrayRef(); + HloInputOutputAliasProto::AliasEntryProto entry; + entry.mutable_output_shape_index()->Add(output_index.begin(), + output_index.end()); + entry.set_parameter_number( + mlir::cast(alias_attr.get("parameter_number")) + .getInt()); + entry.mutable_parameter_shape_index()->Add(parameter_index.begin(), + parameter_index.end()); + mlir::StringRef kind = + mlir::cast(alias_attr.get("kind")).getValue(); + if (kind == "may_alias") + entry.set_kind(xla::Kind::MAY_ALIAS); + else if (kind == "must_alias") + entry.set_kind(xla::Kind::MUST_ALIAS); + else + entry.set_kind(xla::Kind::UNDEFINED_ALIAS); + input_output_alias_proto.add_entries()->Swap(&entry); + } + return input_output_alias_proto; +} + +DotDimensionNumbers ConvertDotDimensionNumbers( + mlir::mhlo::DotDimensionNumbersAttr input) { + DotDimensionNumbers output; + + for (auto v : input.getLhsBatchingDimensions()) { + output.add_lhs_batch_dimensions(v); + } + + for (auto v : input.getRhsBatchingDimensions()) { + output.add_rhs_batch_dimensions(v); + } + + for (auto v : input.getLhsContractingDimensions()) { + output.add_lhs_contracting_dimensions(v); + } + + for (auto v : input.getRhsContractingDimensions()) { + output.add_rhs_contracting_dimensions(v); + } + + return output; +} + +DotDimensionNumbers ConvertDotDimensionNumbers( + absl::Span lhs_batch, absl::Span lhs_contract, + absl::Span rhs_batch, + absl::Span rhs_contract) { + DotDimensionNumbers output; + for (auto v : lhs_batch) { + output.add_lhs_batch_dimensions(v); + } + + for (auto v : rhs_batch) { + output.add_rhs_batch_dimensions(v); + } + + for (auto v : lhs_contract) { + output.add_lhs_contracting_dimensions(v); + } + + for (auto v : rhs_contract) { + output.add_rhs_contracting_dimensions(v); + } + + return output; +} + +absl::StatusOr> ConvertMlirArrayAttrToInt64Array( + const mlir::ArrayAttr& array) { + int rank = array.size(); + std::vector converted_array(rank); + for (int i = 0; i < rank; i++) { + mlir::IntegerAttr attr = mlir::dyn_cast(array[i]); + if (!attr) { + return Internal("Type Error: Expected layout integer attribute"); + } + converted_array[i] = attr.getInt(); + } + return converted_array; +} } // namespace xla diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h index e0e0dc9821d21e..49daefe6935650 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/attribute_exporter.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/statusor.h" #include "mlir/IR/Attributes.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/shape_util.h" @@ -59,5 +60,8 @@ ConvertOutputOperandAliasing(mlir::ArrayAttr aliasArrayAttr); // Will fail if both attempts at parsing failed. std::optional ConvertSharding(mlir::StringRef sharding); +std::optional ConvertInputOutputAlias( + llvm::ArrayRef aliasing); + } // namespace xla #endif // XLA_TRANSLATE_MHLO_TO_HLO_ATTRIBUTE_EXPORTER_H_ diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h index 2c85a82680345a..2ecd4e3ef3ba3d 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h +++ b/third_party/xla/xla/translate/mhlo_to_hlo/layout_util.h @@ -19,8 +19,9 @@ limitations under the License. #define XLA_TRANSLATE_MHLO_TO_HLO_LAYOUT_UTIL_H_ #include -#include +#include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "xla/client/xla_builder.h" #include "xla/hlo/ir/hlo_sharding.h" @@ -30,10 +31,10 @@ limitations under the License. namespace mlir { // XLA Layout preferences. Currently, when it comes to TPU, there are two -// primary layout choices for any XLA argumetns (parameter or resource): (1) +// primary layout choices for any XLA arguments (parameter or resource): (1) // CompactChunkPadded and (2) Linear. CompactChunkPadded is the native TPU // layout while Linear is native host (CPU) layout. -// This enum allows the caller of XLA to progogate layout preference to the XLA +// This enum allows the caller of XLA to propagate layout preference to the XLA // compiler. // kNoPreference: the generic layout where the XLA compiler has the freedom // to assign any layout. diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc index 623080e11fd60d..3e965fdaeb89a1 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc @@ -19,26 +19,25 @@ limitations under the License. #include #include #include -#include #include #include #include #include #include +#include "absl/log/check.h" #include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/SMLoc.h" -#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -46,9 +45,12 @@ limitations under the License. #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Matchers.h" @@ -56,25 +58,26 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/UseDefLists.h" #include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "mlir/Support/DebugStringHelper.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/RegionUtils.h" -#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/dialect/Base.h" #include "xla/array.h" #include "xla/client/lib/approx_topk.h" #include "xla/client/lib/approx_topk_shape.h" -#include "xla/client/lib/matrix.h" -#include "xla/client/lib/quantize.h" +#include "xla/client/lib/matrix.h" // IWYU pragma: keep #include "xla/client/lib/slicing.h" #include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" #include "xla/comparison_util.h" +#include "xla/debug_options_flags.h" #include "xla/hlo/ir/dynamic_parameter_binding.h" -#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_sharding.h" +#include "xla/layout.h" +#include "xla/layout_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/mlir/utils/error_util.h" @@ -88,16 +91,16 @@ limitations under the License. #include "xla/service/hlo_parser.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/translate/mhlo_to_hlo/attribute_exporter.h" +#include "xla/translate/mhlo_to_hlo/layout_util.h" #include "xla/translate/mhlo_to_hlo/location_exporter.h" #include "xla/translate/mhlo_to_hlo/module_config_exporter.h" #include "xla/translate/mhlo_to_hlo/stack_frame_index_builder.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/types.h" #include "xla/xla_data.pb.h" -#include "tsl/platform/ml_dtypes.h" +#include "tsl/platform/errors.h" #include "tsl/platform/statusor.h" +#include "tsl/platform/types.h" using ::int64_t; using ::tsl::int16; @@ -109,18 +112,54 @@ using ::tsl::uint32; using ::tsl::uint64; using ::tsl::uint8; -constexpr char kShapeIndicesAttr[] = "shape_indices"; -constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices"; -constexpr char kShardingAttr[] = "mhlo.sharding"; -constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes"; -constexpr char kReplicationAttr[] = "mhlo.is_same_data_across_replicas"; -constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication"; -constexpr char kLiteralAttr[] = "mhlo.literal"; - +// Boolean attribute. +constexpr char kJaxBufferDonor[] = "jax.buffer_donor"; + +// BitcastOp lowering strings. +constexpr char kResultLayout[] = "result_layout"; +constexpr char kSourceLayout[] = "source_layout"; + +// CustomCallOp lowering strings. +constexpr char kAggregateToTopk[] = "aggregate_to_topk"; +constexpr char kApiVersion[] = "api_version"; +constexpr char kApproxTopK[] = "ApproxTopK"; +constexpr char kBackendConfig[] = "backend_config"; +constexpr char kCallTargetName[] = "call_target_name"; +constexpr char kCalledComputations[] = "called_computations"; +constexpr char kHasSideEffect[] = "has_side_effect"; +constexpr char kIsFallback[] = "is_fallback"; +constexpr char kRecallTarget[] = "recall_target"; +constexpr char kReductionDim[] = "reduction_dim"; +constexpr char kReductionInputSizeOverride[] = "reduction_input_size_override"; +constexpr char kTopK[] = "top_k"; + +// MHLO attributes. Module level attributes require namespacing. +constexpr char kMhloCrossProgramPrefetches[] = "mhlo.cross_program_prefetches"; +constexpr char kMhloFrontendAttributes[] = "mhlo.frontend_attributes"; +constexpr char kMhloInputOutputAlias[] = "mhlo.input_output_alias"; +constexpr char kMhloIsDynamic[] = "mhlo.is_dynamic"; +constexpr char kMhloLiteral[] = "mhlo.literal"; +constexpr char kMhloParameterReplication[] = "mhlo.parameter_replication"; +constexpr char kMhloReplication[] = "mhlo.is_same_data_across_replicas"; +constexpr char kMhloSharding[] = "mhlo.sharding"; +constexpr char kMhloSpmdOutputSharding[] = "mhlo.spmd_output_sharding"; +constexpr char kMhloSpmdParametersShardings[] = + "mhlo.spmd_parameters_shardings"; +constexpr char kMhloUseAutoSpmdPartitioning[] = + "mhlo.use_auto_spmd_partitioning"; + +// Miscellaneous string literals. +constexpr char kArgEmptyTuple[] = "arg_empty_tuple"; +constexpr char kArgPrefix[] = "Arg_"; +constexpr char kArgTuple[] = "arg_tuple"; +constexpr char kDefaultLayoutAttrName[] = "xla_shape"; +constexpr char kExecutionThread[] = "execution_thread"; // Array attribute. Same shape as infeed result, but contains a // minor_to_major array for every tensor. -constexpr char kLayoutAttr[] = "layout"; -constexpr char kDefaultLayoutAttrName[] = "xla_shape"; +constexpr char kLayout[] = "layout"; +constexpr char kMain[] = "main"; +constexpr char kRegionPrefix[] = "region_"; +constexpr char kTfAliasingOutput[] = "tf.aliasing_output"; // Passes through everything except for unique_ptr, on which it calls get(). // This exists to allow the generated code to call XLA functions that take a raw @@ -585,7 +624,7 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( // returns std::nullopt. static std::optional CreateOpShardingFromAttribute( mlir::Operation* op) { - auto shardingAttr = op->getAttrOfType(kShardingAttr); + auto shardingAttr = op->getAttrOfType(kMhloSharding); if (!shardingAttr) return std::nullopt; return xla::ConvertSharding(shardingAttr.getValue()); } @@ -606,7 +645,7 @@ static xla::FrontendAttributes CreateXlaFrontendAttributesFromOp( mlir::Operation* op) { xla::FrontendAttributes frontend_attributes; auto frontend_attributes_dict = - op->getAttrOfType(kFrontendAttributesAttr); + op->getAttrOfType(kMhloFrontendAttributes); if (!frontend_attributes_dict) return frontend_attributes; ConstructFrontendAttributesFromAttribute(frontend_attributes_dict, frontend_attributes); @@ -619,7 +658,7 @@ static void ExtractFrontendAttributesFromFunction( fe_attrs->resize(function.getNumArguments(), std::nullopt); for (int i = 0, end = function.getNumArguments(); i < end; ++i) if (auto fe_attr = function.getArgAttrOfType( - i, kFrontendAttributesAttr)) { + i, kMhloFrontendAttributes)) { xla::FrontendAttributes frontend_attributes; ConstructFrontendAttributesFromAttribute(fe_attr, frontend_attributes); (*fe_attrs)[i] = frontend_attributes; @@ -643,14 +682,14 @@ static void ExtractShardingsFromFunction( std::optional()); for (int i = 0, end = function.getNumArguments(); i < end; ++i) if (auto sharding = - function.getArgAttrOfType(i, kShardingAttr)) + function.getArgAttrOfType(i, kMhloSharding)) (*arg_shardings)[i] = xla::ConvertSharding(sharding.getValue()); ret_shardings->resize(function.getNumResults(), std::optional()); for (int i = 0, end = function.getNumResults(); i < end; ++i) if (auto sharding = - function.getResultAttrOfType(i, kShardingAttr)) + function.getResultAttrOfType(i, kMhloSharding)) (*ret_shardings)[i] = xla::ConvertSharding(sharding.getValue()); } @@ -753,7 +792,7 @@ class ConvertToHloModule { // // TODO(hinsu): Check for dynamic shapes and exit instead of crashing. LogicalResult Run() { - auto main = module_.lookupSymbol("main"); + auto main = module_.lookupSymbol(kMain); if (!main) return module_.emitError( "conversion requires module with `main` function"); @@ -771,8 +810,8 @@ class ConvertToHloModule { // Lower a `mlir::Region` to a `XlaComputation` LogicalResult LowerRegionAsComputation( mlir::Region* region, xla::XlaComputation* func, - std::optional> implicit_operands = - std::nullopt, + llvm::ArrayRef implicit_operands = {}, + llvm::ArrayRef implicit_results = {}, bool ensure_single_arg = false, llvm::ArrayRef> arg_shardings = {}, llvm::ArrayRef> ret_shardings = {}); @@ -786,11 +825,11 @@ class ConvertToHloModule { llvm::ArrayRef> ret_shardings, llvm::ArrayRef> fe_attrs, xla::XlaComputation* result, - std::optional> implicit_operands = - std::nullopt); + llvm::ArrayRef implicit_operands = {}, + llvm::ArrayRef implicit_results = {}); ::xla::HloModuleProto ConsumeMainProto() { - auto main = module_.lookupSymbol("main"); + auto main = module_.lookupSymbol(kMain); // This is an invariant check as Run returns failure if there is no main // function and so the main proto shouldn't be consumed in that case. CHECK(main) << "requires module to have main function"; // Crash Ok. @@ -816,7 +855,7 @@ class ConvertToHloModule { LogicalResult Lower( mlir::Operation* inst, bool is_entry_function, llvm::ArrayRef> ret_shardings, - xla::XlaBuilder* builder, + llvm::ArrayRef implicit_results, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, xla::XlaOp* return_value); @@ -916,12 +955,13 @@ bool SimplyReturnedOp(mlir::Operation* op) { } void BuildGetTupleElementsForTupleResults(mlir::Operation* op, xla::XlaOp tuple, - OpLoweringContext ctx) { + OpLoweringContext ctx, + unsigned num_implicit_results = 0) { const std::optional& sharding = ctx.builder->sharding(); if (sharding.has_value()) { bool is_tuple_sharding = sharding->type() == xla::OpSharding::TUPLE; - assert(!is_tuple_sharding || - op->getNumResults() == sharding->tuple_shardings_size()); + assert(!is_tuple_sharding || (op->getNumResults() + num_implicit_results == + sharding->tuple_shardings_size())); for (auto [index, result] : llvm::enumerate(op->getResults())) { // If `sharding` is not a tuple sharding, then every `get-tuple-element` // gets the same sharding. @@ -956,7 +996,8 @@ LogicalResult ExportXlaOp(CollectiveBroadcastOp op, OpLoweringContext ctx) { } LogicalResult ExportXlaOp(CompositeOp, OpLoweringContext) { - // TODO: b/328526226 - Implement MHLO export for CompositeOp. + // Failure on purpose because `mhlo::CompositeOp` will be handled by + // special purpose logic in `ConvertToHloModule::Lower`. return failure(); } @@ -1544,6 +1585,8 @@ LogicalResult ExportXlaOp(DotOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(DotGeneralOp op, OpLoweringContext ctx) { auto& value_map = *ctx.values; xla::XlaOp lhs, rhs; + // TODO: Support algorithm lowering in followup. + if (op.getAlgorithm().has_value()) return mlir::failure(); if (failed(GetXlaOp(op.getLhs(), value_map, &lhs, op))) return mlir::failure(); if (failed(GetXlaOp(op.getRhs(), value_map, &rhs, op))) @@ -1610,7 +1653,7 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { xla::XlaComputation false_branch; auto& value_map = *ctx.values; - // mhlo.IfOp does not have any operands or blocks-arguments. The computation + // mhlo.IfOp does not have any operands or blocks arguments. The computation // inside the region-blocks use implicit captures of values defined above. // In order to create the xla parameters for functions corresponding to // IfOp regions, we need to infer the a region-block's arguments, using all @@ -1628,10 +1671,10 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { getUsedValuesDefinedAbove(op.getFalseBranch(), op.getFalseBranch(), implicit_false_operand_set); - llvm::SmallVector implicit_true_operands( - implicit_true_operand_set.begin(), implicit_true_operand_set.end()); - llvm::SmallVector implicit_false_operands( - implicit_false_operand_set.begin(), implicit_false_operand_set.end()); + llvm::SmallVector implicit_true_operands = + implicit_true_operand_set.takeVector(); + llvm::SmallVector implicit_false_operands = + implicit_false_operand_set.takeVector(); llvm::SmallVector> ret_shardings = GetResultShardings(ctx.builder->sharding(), op->getNumResults()); @@ -1657,13 +1700,13 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) { // implicit captures operands. Also export the instructions within those // regions. if (failed(ctx.converter->LowerRegionAsComputation( - &op.getTrueBranch(), &true_branch, - llvm::ArrayRef(implicit_true_operands), - /*ensure_single_arg*/ true, true_arg_shardings, ret_shardings)) || + &op.getTrueBranch(), &true_branch, implicit_true_operands, + /*implicit_results=*/{}, /*ensure_single_arg=*/true, + true_arg_shardings, ret_shardings)) || failed(ctx.converter->LowerRegionAsComputation( - &op.getFalseBranch(), &false_branch, - llvm::ArrayRef(implicit_false_operands), - /*ensure_single_arg*/ true, false_arg_shardings, ret_shardings))) { + &op.getFalseBranch(), &false_branch, implicit_false_operands, + /*implicit_results=*/{}, /*ensure_single_arg=*/true, + false_arg_shardings, ret_shardings))) { return failure(); } @@ -1701,7 +1744,7 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { std::vector computations(branches.size()); std::vector computations_p(branches.size()); - // mhlo.CaseOp does not have any operands or blocks-arguments. The computation + // mhlo.CaseOp does not have any operands or blocks arguments. The computation // inside the region-blocks use implicit captures of values defined above. // In order to create the xla parameters for functions corresponding to // CaseOp regions, we need to infer the a region-block's arguments, using all @@ -1715,8 +1758,8 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { for (unsigned i = 0; i < branches.size(); ++i) { llvm::SetVector implicit_operand_set; getUsedValuesDefinedAbove(branches[i], branches[i], implicit_operand_set); - llvm::SmallVector implicit_operands( - implicit_operand_set.begin(), implicit_operand_set.end()); + llvm::SmallVector implicit_operands = + implicit_operand_set.takeVector(); llvm::SmallVector> ret_shardings = GetResultShardings(ctx.builder->sharding(), op->getNumResults()); @@ -1740,8 +1783,9 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) { // that region. computations_p[i] = &computations[i]; if (failed(ctx.converter->LowerRegionAsComputation( - &branches[i], computations_p[i], llvm::ArrayRef(implicit_operands), - /*ensure_single_arg*/ true, arg_shardings, ret_shardings))) + &branches[i], computations_p[i], implicit_operands, + /*implicit_results=*/{}, /*ensure_single_arg=*/true, arg_shardings, + ret_shardings))) return failure(); } @@ -1905,12 +1949,12 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { // This feature is at time of writing only used by JAX, and is tested in the // jax2tf backwards compatibility tests. - if (op.getCallTargetName() == "ApproxTopK") { + if (op.getCallTargetName() == kApproxTopK) { auto isSupportedAttrName = [](NamedAttribute attr) { auto name = attr.getName(); - return name == "call_target_name" || name == "backend_config" || - name == "api_version" || name == "called_computations" || - name == "has_side_effect"; + return name == kCallTargetName || name == kBackendConfig || + name == kApiVersion || name == kCalledComputations || + name == kHasSideEffect; }; for (const auto& attr : op->getAttrs()) { if (!isSupportedAttrName(attr)) @@ -1925,9 +1969,9 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { for (auto attr : backend_config) { auto name = attr.getName(); - if (!(name == "top_k" || name == "reduction_dim" || - name == "recall_target" || name == "aggregate_to_topk" || - name == "reduction_input_size_override" || name == "is_fallback")) + if (!(name == kTopK || name == kReductionDim || name == kRecallTarget || + name == kAggregateToTopk || name == kReductionInputSizeOverride || + name == kIsFallback)) return op.emitOpError() << name.getValue() << " is not a supported backend_config" << " attribute for ApproxTopK"; @@ -1969,29 +2013,28 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { << " attribute in backend_config must be of bool type"; return success(); }; - if (failed(checkI64Attr("top_k"))) return failure(); - if (failed(checkI64Attr("reduction_dim"))) return failure(); - if (failed(checkF32Attr("recall_target"))) return failure(); - if (failed(checkBoolAttr("aggregate_to_topk"))) return failure(); - if (failed(checkI64Attr("reduction_input_size_override"))) return failure(); - bool has_is_fallback = backend_config.contains("is_fallback"); - if (has_is_fallback && !backend_config.getAs("is_fallback")) + if (failed(checkI64Attr(kTopK))) return failure(); + if (failed(checkI64Attr(kReductionDim))) return failure(); + if (failed(checkF32Attr(kRecallTarget))) return failure(); + if (failed(checkBoolAttr(kAggregateToTopk))) return failure(); + if (failed(checkI64Attr(kReductionInputSizeOverride))) return failure(); + bool has_is_fallback = backend_config.contains(kIsFallback); + if (has_is_fallback && !backend_config.getAs(kIsFallback)) return op.emitOpError() << "is_fallback attribute in backend_config must be of bool type"; - int64_t top_k = backend_config.getAs("top_k").getInt(); + int64_t top_k = backend_config.getAs(kTopK).getInt(); int64_t reduction_dim = - backend_config.getAs("reduction_dim").getInt(); - float recall_target = backend_config.getAs("recall_target") + backend_config.getAs(kReductionDim).getInt(); + float recall_target = backend_config.getAs(kRecallTarget) .getValue() .convertToFloat(); bool aggregate_to_topk = - backend_config.getAs("aggregate_to_topk").getValue(); + backend_config.getAs(kAggregateToTopk).getValue(); int64_t reduction_input_size_override = - backend_config.getAs("reduction_input_size_override") - .getInt(); + backend_config.getAs(kReductionInputSizeOverride).getInt(); bool is_fallback = has_is_fallback && - backend_config.getAs("is_fallback").getValue(); + backend_config.getAs(kIsFallback).getValue(); // (C1) if (args.size() % 2 != 0) { @@ -2151,7 +2194,7 @@ LogicalResult ExportXlaOp(CustomCallOp op, OpLoweringContext ctx) { absl::StatusOr literal; const xla::Literal* literal_ptr = nullptr; - auto literal_attr = op->getAttrOfType(kLiteralAttr); + auto literal_attr = op->getAttrOfType(kMhloLiteral); if (literal_attr) { literal = CreateArrayLiteralFromAttr(literal_attr, {}); if (!literal.ok()) return failure(); @@ -2712,15 +2755,57 @@ LogicalResult ExportXlaOp(TraceOp op, OpLoweringContext ctx) { LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { xla::XlaComputation condition; xla::XlaComputation body; + // If the results of the while op have a sharding, we use those shardings for // the corresponding arguments and return shardings in the body and condition. llvm::SmallVector> res_shardings = GetResultShardings(ctx.builder->sharding(), op->getNumResults()); + + // mhlo.WhileOp has operands and corresponding blocks arguments, but the + // computation inside its region-blocks can also use implicit captures of + // values defined above. + // In order to create the xla parameters for functions corresponding to + // WhileOp regions, we need to infer the implicit region-block's arguments, + // using all the values used in the region but defined above. + // + // Note that the body and cond regions of WhileOp share the same block + // arguments, so we collect the implicit values for both in a single set. + llvm::SetVector implicit_operand_set; + getUsedValuesDefinedAbove(op->getRegions(), implicit_operand_set); + llvm::SmallVector implicit_operands = + implicit_operand_set.takeVector(); + + llvm::SmallVector implicit_args; + if (failed(GetXlaOps(op, implicit_operands, ctx, implicit_args))) + return failure(); + + // We need to append the shardings of the implicit values to the result + // shardings, since the HLO While will have those implcit values as additional + // operands and results. + llvm::SmallVector> implicit_shardings; + if (!implicit_args.empty() && !res_shardings.empty()) { + // We only add implicit arg shardings if there are result shardings, + // otherwise it means sharding propagation hasn't been done yet. + implicit_shardings = GetXlaOpShardings(implicit_args); + + res_shardings.append(implicit_shardings.begin(), implicit_shardings.end()); + if (std::optional new_sharding = + CreateTupleSharding(res_shardings)) { + ctx.builder->SetSharding(*new_sharding); + } + } + + // The body of the While needs to return the same number of values as its + // arguments, as they are carried over to the next iteration. Thus, we pass + // the `implicit_operands` as `implicit_results`, to carry them over as is. if (failed(ctx.converter->LowerRegionAsComputation( - &op.getBody(), &body, std::nullopt, /*ensure_single_arg=*/true, - /*arg_shardings=*/res_shardings, /*ret_shardings=*/res_shardings)) || + &op.getBody(), &body, implicit_operands, + /*implicit_results=*/implicit_operands, + /*ensure_single_arg=*/true, /*arg_shardings=*/res_shardings, + /*ret_shardings=*/res_shardings)) || failed(ctx.converter->LowerRegionAsComputation( - &op.getCond(), &condition, std::nullopt, + &op.getCond(), &condition, implicit_operands, + /*implicit_results=*/{}, /*ensure_single_arg=*/true, /*arg_shardings=*/res_shardings))) { return failure(); } @@ -2729,11 +2814,12 @@ LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { // those operands, to be used as sole operand of xla::While. llvm::SmallVector operands; if (failed(GetTuple(op, op.getOperands(), ctx, operands))) return failure(); + operands.append(implicit_args.begin(), implicit_args.end()); xla::XlaOp operand = operands[0]; if (operands.size() > 1) operand = Tuple(ctx.builder, operands); - auto whileop = xla::While(condition, body, operand); + xla::XlaOp whileop = xla::While(condition, body, operand); auto& value_map = *ctx.values; auto shape_or = whileop.builder()->GetShape(whileop); @@ -2748,7 +2834,8 @@ LogicalResult ExportXlaOp(WhileOp op, OpLoweringContext ctx) { } // mhlo.WhileOp supports multiple returns, untuple all the results of XLA's. - BuildGetTupleElementsForTupleResults(op, whileop, ctx); + BuildGetTupleElementsForTupleResults( + op, whileop, ctx, /*num_implicit_results=*/implicit_args.size()); return success(); } @@ -2824,11 +2911,11 @@ LogicalResult ExportXlaOp(BitcastOp op, OpLoweringContext ctx) { xla::internal::XlaBuilderFriend::GetInstruction(operand); xla::LayoutProto result_layout = ExtractLayout(op, bitcast_proto->shape().dimensions_size(), - "result_layout") + kResultLayout) .ToProto(); xla::LayoutProto source_layout = ExtractLayout(op, operand_proto->shape().dimensions_size(), - "source_layout") + kSourceLayout) .ToProto(); xla::gpu::BitcastBackendConfig bitcast_config; *bitcast_config.mutable_source_layout() = source_layout; @@ -2879,7 +2966,7 @@ namespace { LogicalResult ConvertLayout(mlir::Operation* op, const mlir::ArrayAttr& layout, xla::ShapeProto* shape) { - // In the case of tuples, ShapeProtos can be nested, and so can the mlir + // In the case of tuples, Shape protos can be nested, and so can the mlir // attribute describing the layout. So recurse into the subshapes in both data // structures in parallel. if (shape->element_type() == xla::TUPLE) { @@ -3045,7 +3132,7 @@ LogicalResult ExportXlaOperatorWrapped(mlir::Operation* inst, LogicalResult ConvertToHloModule::Lower( mlir::Operation* inst, bool is_entry_function, llvm::ArrayRef> ret_shardings, - xla::XlaBuilder* builder, + llvm::ArrayRef implicit_results, xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, xla::XlaOp* return_value) { // Explicitly fail for ops that are not supported for export. @@ -3092,8 +3179,7 @@ LogicalResult ConvertToHloModule::Lower( // For infeed ops stemming back to InfeedDequeueTuple, respect the // layout attribute, and create the corresponding layout in hlo. if (isa(inst)) { - mlir::ArrayAttr layout = - inst->getAttrOfType(kLayoutAttr); + mlir::ArrayAttr layout = inst->getAttrOfType(kLayout); if (layout) { // We propagate layout to the following three ops: @@ -3222,40 +3308,51 @@ LogicalResult ConvertToHloModule::Lower( if (isa(inst)) { // Construct the return value for the function. If there is a single value // returned, then return it directly, else create a tuple and return. - unsigned num_return_values = inst->getNumOperands(); + unsigned num_return_values = + inst->getNumOperands() + implicit_results.size(); std::optional ret_tuple_sharding = CreateTupleSharding(ret_shardings); if ((options_.return_tuple && is_entry_function) || num_return_values != 1) { - std::vector returns(num_return_values); - for (OpOperand& ret : inst->getOpOperands()) { - unsigned index = ret.getOperandNumber(); - xla::XlaOp operand; - if (failed(GetXlaOp(ret.get(), value_map, &operand, inst))) - return failure(); - - returns[index] = operand; - if (!is_entry_function || !ret_tuple_sharding) continue; - - xla::Shape return_shape = xla::TypeToShape(ret.get().getType()); - absl::StatusOr reshape = - ReshapeWithCorrectRepresentationAndSharding( - builder, returns[index], return_shape, - options_.layout_preference_fn, options_.shape_representation_fn, - ret_shardings[index], /*fast_mem=*/false); - if (!reshape.ok()) - return inst->emitError() << reshape.status().message(); - - returns[index] = reshape.value(); + std::vector returns; + returns.reserve(num_return_values); + // NOTE: we can't use operand_range in llvm::concat. + for (Value ret : inst->getOperands()) { + xla::XlaOp& operand = returns.emplace_back(); + if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure(); + } + for (Value ret : implicit_results) { + xla::XlaOp& operand = returns.emplace_back(); + if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure(); + } + if (is_entry_function && ret_tuple_sharding) { + assert(implicit_results.empty() && + "entry functions shouldn't have implicit results"); + for (OpOperand& ret : inst->getOpOperands()) { + unsigned index = ret.getOperandNumber(); + + xla::Shape return_shape = xla::TypeToShape(ret.get().getType()); + absl::StatusOr reshape = + ReshapeWithCorrectRepresentationAndSharding( + builder, returns[index], return_shape, + options_.layout_preference_fn, + options_.shape_representation_fn, ret_shardings[index], + /*fast_mem=*/false); + if (!reshape.ok()) + return inst->emitError() << reshape.status().message(); + + returns[index] = reshape.value(); + } } xla::XlaScopedShardingAssignment scoped_sharding(builder, ret_tuple_sharding); *return_value = xla::Tuple(builder, returns); } else if (num_return_values == 1) { + Value ret = implicit_results.empty() ? inst->getOperand(0) + : implicit_results.front(); xla::XlaOp operand; - if (failed(GetXlaOp(inst->getOperand(0), value_map, &operand, inst))) - return failure(); + if (failed(GetXlaOp(ret, value_map, &operand, inst))) return failure(); if (ret_tuple_sharding) { auto tuple = Tuple(builder, {operand}); @@ -3270,6 +3367,59 @@ LogicalResult ConvertToHloModule::Lower( return success(); } + if (auto composite_op = dyn_cast(inst)) { + SmallVector operands; + for (const Value& val : inst->getOperands()) { + xla::XlaOp operand; + if (failed(GetXlaOp(val, value_map, &operand, inst))) { + return failure(); + } + operands.push_back(operand); + } + + xla::XlaComputation computation; + if (failed(LowerBasicBlockAsFunction( + /*block=*/&module_ + .lookupSymbol( + composite_op.getDecomposition()) + .getBody() + .front(), + /*builder=*/ + module_builder_ + .CreateSubBuilder(composite_op.getDecomposition().str()) + .get(), + /*is_entry_function=*/false, + /*ensure_single_arg=*/false, + /*entry_args_same_across_replicas=*/{}, + /*arg_shardings=*/{}, /*ret_shardings=*/{}, + /*fe_attrs=*/{}, /*result=*/&computation, + /*implicit_operands=*/{}))) { + return failure(); + } + + std::string composite_attributes; + llvm::raw_string_ostream(composite_attributes) + << composite_op.getCompositeAttributes(); + + xla::XlaOp composite_call = xla::CompositeCall( + builder, computation, operands, composite_op.getName().str(), + composite_attributes, composite_op.getVersion()); + + // Use GetTupleElement for multiple outputs + unsigned num_results = composite_op.getNumResults(); + if (num_results > 1) { + for (unsigned i = 0; i != num_results; ++i) { + value_map[composite_op.getResult(i)] = + xla::GetTupleElement(composite_call, i); + } + } else if (num_results == 1) { + value_map[composite_op.getResult(0)] = composite_call; + } + *return_value = composite_call; + + return success(); + } + inst->emitOpError() << "can't be translated to XLA HLO"; return failure(); } @@ -3318,7 +3468,7 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) { // Create a sub-builder if this is not the main function. std::unique_ptr builder_up; - bool entry_function = f.getName() == "main"; + bool entry_function = f.getName() == kMain; if (!entry_function) builder_up = module_builder_.CreateSubBuilder(f.getName().str()); auto& builder = entry_function ? module_builder_ : *builder_up; @@ -3332,14 +3482,14 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) { bool any_arg_replicated = false; entry_args_same_across_replicas.reserve(f.getNumArguments()); for (int64_t i = 0; i < f.getNumArguments(); ++i) { - auto attr = f.getArgAttrOfType(i, kReplicationAttr); + auto attr = f.getArgAttrOfType(i, kMhloReplication); entry_args_same_across_replicas.push_back(attr != nullptr && attr.getValue()); any_arg_replicated |= entry_args_same_across_replicas.back(); // Pass the alias info to the builder so that it will build the alias info // into the resulting HloModule. auto buffer_donor = - f.getArgAttrOfType(i, "jax.buffer_donor"); + f.getArgAttrOfType(i, kJaxBufferDonor); if (buffer_donor) { if (options_.use_tuple_args) { builder.AddBufferDonor(/*param_number=*/0, /*param_index=*/{i}); @@ -3348,7 +3498,7 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) { } } auto aliasing_output = - f.getArgAttrOfType(i, "tf.aliasing_output"); + f.getArgAttrOfType(i, kTfAliasingOutput); if (!aliasing_output) continue; xla::ShapeIndex output_index; if ((options_.return_tuple && entry_function) || f.getNumResults() != 1) { @@ -3383,13 +3533,13 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::func::FuncOp f) { return failure(); } if (auto execution_thread = - f->getAttrOfType("execution_thread")) { + f->getAttrOfType(kExecutionThread)) { computation.mutable_proto()->mutable_computations(0)->set_execution_thread( execution_thread.str()); } for (int i = 0; i < f.getNumArguments(); ++i) { if (auto pr = - f.getArgAttrOfType(i, kParameterReplicationAttr)) { + f.getArgAttrOfType(i, kMhloParameterReplication)) { for (auto b : pr.getValue()) for (auto& instr : *computation.mutable_proto() ->mutable_computations(0) @@ -3494,8 +3644,8 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( llvm::ArrayRef> arg_shardings, llvm::ArrayRef> ret_shardings, llvm::ArrayRef> fe_attrs, - xla::XlaComputation* result, - std::optional> implicit_operands) { + xla::XlaComputation* result, llvm::ArrayRef implicit_operands, + llvm::ArrayRef implicit_results) { // Mapping from the Value to lowered XlaOp. ValueLoweringMap lowering; @@ -3519,7 +3669,7 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( // fuse all the `mlir::Location`s or join the operation name strings with // ";" (which is essentially the same). auto tuple = - xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication); + xla::Parameter(builder, 0, input_shape, kArgTuple, leaf_replication); builder->ClearSharding(); for (BlockArgument& arg : block->getArguments()) { @@ -3533,17 +3683,16 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( // Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp. llvm::SmallVector arg_shapes; - auto args_size = block->getNumArguments(); - if (implicit_operands) args_size = implicit_operands->size(); + // Lowering supports mix of block args and implicit operands + // Block args must be added before implicit capture operands + + auto args_size = block->getNumArguments() + implicit_operands.size(); arg_shapes.reserve(args_size); - if (implicit_operands) { - for (auto implicit_operand : *implicit_operands) - arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType())); - } else { - for (BlockArgument& arg : block->getArguments()) - arg_shapes.push_back(xla::TypeToShape(arg.getType())); - } + for (BlockArgument& arg : block->getArguments()) + arg_shapes.push_back(xla::TypeToShape(arg.getType())); + for (Value implicit_operand : implicit_operands) + arg_shapes.push_back(xla::TypeToShape(implicit_operand.getType())); if (args_size > 1) { xla::XlaScopedShardingAssignment scoped_sharding( @@ -3554,26 +3703,23 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( // but not tuple params. Do the same for tuple params. To do so, either // fuse all the `mlir::Location`s or join the operation name strings // with ";" (which is essentially the same). - auto tuple = xla::Parameter(builder, 0, - xla::ShapeUtil::MakeTupleShape(arg_shapes), - "arg_tuple"); - - if (implicit_operands) { - for (auto [arg_index, implicit_operand] : - llvm::enumerate(*implicit_operands)) { - xla::XlaScopedShardingAssignment scoped_sharding( - builder, arg_shardings.empty() ? std::nullopt - : arg_shardings[arg_index]); - lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index); - } - } else { - for (BlockArgument& arg : block->getArguments()) { - auto num = arg.getArgNumber(); - xla::XlaScopedShardingAssignment scoped_sharding( - builder, - arg_shardings.empty() ? std::nullopt : arg_shardings[num]); - lowering[arg] = xla::GetTupleElement(tuple, num); - } + auto tuple = xla::Parameter( + builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes), kArgTuple); + + for (BlockArgument& arg : block->getArguments()) { + auto num = arg.getArgNumber(); + xla::XlaScopedShardingAssignment scoped_sharding( + builder, + arg_shardings.empty() ? std::nullopt : arg_shardings[num]); + lowering[arg] = xla::GetTupleElement(tuple, num); + } + for (auto [implicit_index, implicit_operand] : + llvm::enumerate(implicit_operands)) { + int64_t arg_index = block->getNumArguments() + implicit_index; + xla::XlaScopedShardingAssignment scoped_sharding( + builder, + arg_shardings.empty() ? std::nullopt : arg_shardings[arg_index]); + lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index); } } else if (args_size == 1) { // Save the location information as a name. For example JAX will set the @@ -3581,23 +3727,17 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( xla::XlaScopedShardingAssignment scoped_sharding( builder, arg_shardings.empty() ? std::nullopt : arg_shardings.front()); - if (implicit_operands) { - mlir::Value arg = (*implicit_operands)[0]; - xla::XlaScopedOpMetadataAssignment op_metadata( - builder, GetOpNameMetadataFromLocation(arg)); - lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_"); - } else { - mlir::BlockArgument arg = block->getArgument(0); - xla::XlaScopedOpMetadataAssignment op_metadata( - builder, GetOpNameMetadataFromLocation(arg)); - lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_"); - } + mlir::Value arg = implicit_operands.empty() ? block->getArgument(0) + : implicit_operands.front(); + xla::XlaScopedOpMetadataAssignment op_metadata( + builder, GetOpNameMetadataFromLocation(arg)); + lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], kArgPrefix); } else { // Applicable only for IfOp or CaseOp. No implicit operands implies no // xla parameters. In this case, we create an empty tuple as the // block-parameter. xla::Parameter(builder, 0, xla::ShapeUtil::MakeTupleShape(arg_shapes), - "arg_empty_tuple"); + kArgEmptyTuple); } } else { for (BlockArgument& arg : block->getArguments()) { @@ -3616,11 +3756,11 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( xla::XlaScopedOpMetadataAssignment op_metadata( builder, GetOpNameMetadataFromLocation(arg)); if (entry_args_same_across_replicas.empty()) { - lowering[arg] = - xla::Parameter(builder, num, shape, absl::StrCat("Arg_", num)); + lowering[arg] = xla::Parameter(builder, num, shape, + absl::StrCat(kArgPrefix, num)); } else { lowering[arg] = xla::Parameter( - builder, num, shape, absl::StrCat("Arg_", num), + builder, num, shape, absl::StrCat(kArgPrefix, num), std::vector(entry_args_same_across_replicas[num], xla::ShapeUtil::GetLeafCount(shape))); } @@ -3631,8 +3771,8 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( xla::XlaOp return_value; for (auto& inst : *block) - if (failed(Lower(&inst, is_entry_function, ret_shardings, builder, - &lowering, &return_value))) + if (failed(Lower(&inst, is_entry_function, ret_shardings, implicit_results, + builder, &lowering, &return_value))) return failure(); // Build the XlaComputation and check for failures. @@ -3648,18 +3788,18 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( LogicalResult ConvertToHloModule::LowerRegionAsComputation( mlir::Region* region, xla::XlaComputation* func, - std::optional> implicit_operands, - bool ensure_single_arg, + llvm::ArrayRef implicit_operands, + llvm::ArrayRef implicit_results, bool ensure_single_arg, llvm::ArrayRef> arg_shardings, llvm::ArrayRef> ret_shardings) { - std::unique_ptr builder = - module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++)); - return LowerBasicBlockAsFunction(®ion->front(), builder.get(), - /*is_entry_function=*/false, - /*ensure_single_arg*/ ensure_single_arg, - /*entry_args_same_across_replicas=*/{}, - arg_shardings, ret_shardings, - /*fe_attrs=*/{}, func, implicit_operands); + std::unique_ptr builder = module_builder_.CreateSubBuilder( + absl::StrCat(kRegionPrefix, region_id_++)); + return LowerBasicBlockAsFunction( + ®ion->front(), builder.get(), + /*is_entry_function=*/false, + /*ensure_single_arg*/ ensure_single_arg, + /*entry_args_same_across_replicas=*/{}, arg_shardings, ret_shardings, + /*fe_attrs=*/{}, func, implicit_operands, implicit_results); } // Runs the PrepareForExport pass on the ModuleOp. @@ -3704,40 +3844,46 @@ absl::Status ConvertMlirHloToHlo(mlir::ModuleOp module, TF_RETURN_IF_ERROR(PrepareForExport(module)); mlir::BaseScopedDiagnosticHandler diag_handler(module.getContext()); - xla::XlaBuilder module_builder("main"); + xla::XlaBuilder module_builder(kMain); ConvertToHloModule converter(module, module_builder, options); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); - StringRef module_name = module.getName() ? *module.getName() : "main"; + StringRef module_name = module.getName() ? *module.getName() : kMain; hlo_module.set_name(module_name.str()); - if (auto cross_program_prefetches = module->getAttrOfType( - "mhlo.cross_program_prefetches")) { + if (auto cross_program_prefetches = + module->getAttrOfType(kMhloCrossProgramPrefetches)) { for (const auto& prefetch : Convert_cross_program_prefetches(cross_program_prefetches)) { *hlo_module.add_cross_program_prefetches() = std::move(prefetch); } } - if (auto is_dynamic = - module->getAttrOfType("mhlo.is_dynamic")) { + if (auto is_dynamic = module->getAttrOfType(kMhloIsDynamic)) { hlo_module.set_is_dynamic(is_dynamic.getValue()); } if (auto frontend_attributes = - module->getAttrOfType(kFrontendAttributesAttr)) { + module->getAttrOfType(kMhloFrontendAttributes)) { ConstructFrontendAttributesFromAttribute( frontend_attributes, *hlo_module.mutable_frontend_attributes()); } - if (auto use_auto_spmd_partitioning = module->getAttrOfType( - "mhlo.use_auto_spmd_partitioning")) { + if (auto use_auto_spmd_partitioning = + module->getAttrOfType(kMhloUseAutoSpmdPartitioning)) { hlo_module.set_use_auto_spmd_partitioning( use_auto_spmd_partitioning.getValue()); } - if (auto spmd_output_sharding = module->getAttrOfType( - "mhlo.spmd_output_sharding")) { + if (auto spmd_output_sharding = + module->getAttrOfType(kMhloSpmdOutputSharding)) { *hlo_module.mutable_spmd_output_sharding() = *xla::ConvertSharding(spmd_output_sharding.getValue()); } + if (auto input_output_alias = + module->getAttrOfType(kMhloInputOutputAlias)) { + if (std::optional input_output_alias_proto = + xla::ConvertInputOutputAlias(input_output_alias.getValue())) { + *hlo_module.mutable_input_output_alias() = *input_output_alias_proto; + } + } if (auto spmd_parameters_sharding = module->getAttrOfType( - "mhlo.spmd_parameters_shardings")) { + kMhloSpmdParametersShardings)) { for (const auto& sharding : spmd_parameters_sharding.getValue()) { *hlo_module.add_spmd_parameters_shardings() = *xla::ConvertSharding( mlir::cast(sharding).getValue()); @@ -3805,7 +3951,8 @@ absl::Status BuildHloFromMlirHlo(mlir::Block& block, xla::XlaBuilder& builder, } else { xla::XlaOp return_value; if (failed(converter.Lower(&inst, /*is_entry_function=*/true, - /*ret_shardings=*/{}, &builder, &lowering, + /*ret_shardings=*/{}, + /*implicit_results=*/{}, &builder, &lowering, &return_value))) return diag_handler.ConsumeStatus(); } diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc b/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc index 88f05238846b7a..7dad7c322e2d84 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/module_config_exporter.cc @@ -22,18 +22,20 @@ limitations under the License. namespace mlir { namespace mhlo { namespace { -constexpr char kConfigNumPartitions[] = "mhlo.num_partitions"; -constexpr char kConfigNumReplicas[] = "mhlo.num_replicas"; + +constexpr char kMhloNumPartitions[] = "mhlo.num_partitions"; +constexpr char kMhloNumReplicas[] = "mhlo.num_replicas"; + } // namespace void ExportHloModuleConfig(xla::HloModuleConfig& config, mlir::ModuleOp module) { if (auto num_partitions = - module->getAttrOfType(kConfigNumPartitions)) { + module->getAttrOfType(kMhloNumPartitions)) { config.set_num_partitions(num_partitions.getInt()); } if (auto num_replicas = - module->getAttrOfType(kConfigNumReplicas)) { + module->getAttrOfType(kMhloNumReplicas)) { config.set_replica_count(num_replicas.getInt()); } } diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD b/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD index 37354bae635287..f94730765b7d13 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/BUILD @@ -12,9 +12,11 @@ lit_test_suite( [ "add.mlir", "case.mlir", + "composite.mlir", "dynamic.mlir", "export-with-layouts.mlir", "export.mlir", + "export_async.mlir", "export_and_check_layouts.mlir", "export_large_constants.mlir", "export_replicas.mlir", @@ -36,6 +38,7 @@ lit_test_suite( "simple.mlir", "unsupported_type.mlir", "while.mlir", + "while_free_vars.mlir", ], include = [ "*.mlir", diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir new file mode 100644 index 00000000000000..60c5548587a677 --- /dev/null +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/composite.mlir @@ -0,0 +1,190 @@ +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s | FileCheck %s + +module @composite { + // CHECK: HloModule composite, entry_computation_layout={()->f32[]} + // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] { + // CHECK: %Arg_0.3 = f32[] parameter(0) + // CHECK: %constant.4 = f32[] constant(2) + // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) + // CHECK: } + // CHECK: ENTRY %main.7 () -> f32[] { + // CHECK: %constant.1 = f32[] constant(42) + // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + // CHECK: } + func.func @main() -> tensor { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %1 = mhlo.composite "foo.bar" %0 { + composite_attributes = { + n = 1 : i32, + tensor = dense<1> : tensor + }, + decomposition = @add, + version = 1 : i32 + } : (tensor) -> tensor + return %1 : tensor + } + func.func @add(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1 : tensor + } +} + +// ----- + +// zero-output composite +module @composite { + //CHECK: HloModule composite, entry_computation_layout={()->()} + //CHECK: %return.2 (Arg_0.3: f32[]) -> () { + //CHECK: %Arg_0.3 = f32[] parameter(0) + //CHECK: ROOT %tuple.4 = () tuple() + //CHECK: } + //CHECK: ENTRY %main.7 () -> () { + //CHECK: %constant.1 = f32[] constant(42) + //CHECK: %call.5 = () call(f32[] %constant.1), to_apply=%return.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + //CHECK: ROOT %tuple.6 = () tuple() + //CHECK: } + func.func @main() -> () { + %0 = mhlo.constant dense<4.200000e+01> : tensor + "mhlo.composite"(%0) { + name = "foo.bar", + composite_attributes = { + n = 1 : i32, + tensor = dense<1> : tensor + }, + decomposition = @return, + version = 1 : i32 + } : (tensor) -> () + return + } + func.func @return(%arg0: tensor) -> () { + return + } +} + +// ----- + +// multi-output composite +module @composite { + //CHECK: HloModule composite, entry_computation_layout={()->(f32[], f32[])} + //CHECK: %add.2 (Arg_0.3: f32[]) -> (f32[], f32[]) { + //CHECK: %Arg_0.3 = f32[] parameter(0) + //CHECK: %constant.4 = f32[] constant(2) + //CHECK: %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) + //CHECK: ROOT %tuple.6 = (f32[], f32[]) tuple(f32[] %add.5, f32[] %add.5) + //CHECK: } + //CHECK: ENTRY %main.11 () -> (f32[], f32[]) { + //CHECK: %constant.1 = f32[] constant(42) + //CHECK: %call.7 = (f32[], f32[]) call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="1"} + //CHECK: %get-tuple-element.8 = f32[] get-tuple-element((f32[], f32[]) %call.7), index=0 + //CHECK: %get-tuple-element.9 = f32[] get-tuple-element((f32[], f32[]) %call.7), index=1 + //CHECK: ROOT %tuple.10 = (f32[], f32[]) tuple(f32[] %get-tuple-element.8, f32[] %get-tuple-element.9) + //CHECK: } + func.func @main() -> (tensor, tensor) { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %result:2 = "mhlo.composite"(%0) { + name = "foo.bar", + composite_attributes = { + n = 1 : i32, + tensor = dense<1> : tensor + }, + decomposition = @add, + version = 1 : i32 + } : (tensor) -> (tensor, tensor) + return %result#0, %result#1 : tensor, tensor + } + func.func @add(%arg0: tensor) -> (tensor, tensor) { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1, %1 : tensor, tensor + } +} + +// ----- + +// optional composite attributes +module @composite { + // CHECK: HloModule composite, entry_computation_layout={()->f32[]} + // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] { + // CHECK: %Arg_0.3 = f32[] parameter(0) + // CHECK: %constant.4 = f32[] constant(2) + // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) + // CHECK: } + // CHECK: ENTRY %main.7 () -> f32[] { + // CHECK: %constant.1 = f32[] constant(42) + // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="1"} + // CHECK: } + func.func @main() -> tensor { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %1 = mhlo.composite "foo.bar" %0 { + decomposition = @add, + version = 1 : i32 + } : (tensor) -> tensor + return %1 : tensor + } + func.func @add(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1 : tensor + } +} + +// ----- + +// optional composite version +module @composite { + // CHECK: HloModule composite, entry_computation_layout={()->f32[]} + // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] { + // CHECK: %Arg_0.3 = f32[] parameter(0) + // CHECK: %constant.4 = f32[] constant(2) + // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) + // CHECK: } + // CHECK: ENTRY %main.7 () -> f32[] { + // CHECK: %constant.1 = f32[] constant(42) + // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={n = 1 : i32, tensor = dense<1> : tensor},composite.name="foo.bar",composite.version="0"} + // CHECK: } + func.func @main() -> tensor { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %1 = mhlo.composite "foo.bar" %0 { + composite_attributes = { + n = 1 : i32, + tensor = dense<1> : tensor + }, + decomposition = @add + } : (tensor) -> tensor + return %1 : tensor + } + func.func @add(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1 : tensor + } +} + +// ----- + +// optional composite attributes and version +module @composite { + // CHECK: HloModule composite, entry_computation_layout={()->f32[]} + // CHECK: %add.2 (Arg_0.3: f32[]) -> f32[] { + // CHECK: %Arg_0.3 = f32[] parameter(0) + // CHECK: %constant.4 = f32[] constant(2) + // CHECK: ROOT %add.5 = f32[] add(f32[] %Arg_0.3, f32[] %constant.4) + // CHECK: } + // CHECK: ENTRY %main.7 () -> f32[] { + // CHECK: %constant.1 = f32[] constant(42) + // CHECK: ROOT %call.6 = f32[] call(f32[] %constant.1), to_apply=%add.2, is_composite=true, frontend_attributes={composite.attributes={},composite.name="foo.bar",composite.version="0"} + // CHECK: } + func.func @main() -> tensor { + %0 = mhlo.constant dense<4.200000e+01> : tensor + %1 = mhlo.composite "foo.bar" %0 { + decomposition = @add + } : (tensor) -> tensor + return %1 : tensor + } + func.func @add(%arg0: tensor) -> tensor { + %0 = mhlo.constant dense<2.000000e+00> : tensor + %1 = mhlo.add %arg0, %0 : tensor + return %1 : tensor + } +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir index 680341f0d899ac..3d44aff99a7226 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export-with-layouts.mlir @@ -1,5 +1,5 @@ // RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts %s | FileCheck %s -// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts --via-builder=true %s | FileCheck %s +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text -with-layouts -print-layouts --via-builder=true %s | FileCheck %s #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir index dec3e5dcae858f..6672e62daf04de 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export.mlir @@ -1,6 +1,16 @@ // RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s // RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics --via-builder=true %s | FileCheck %s +// CHECK: HloModule foo +// CHECK: ENTRY %main +module @foo { + func.func @main(%arg: tensor) -> tensor { + func.return %arg : tensor + } +} + +// ----- + // CHECK: HloModule func.func @main(%arg0: tensor<2xi1>) -> tensor<2xi1> { %0 = "mhlo.add"(%arg0, %arg0) : (tensor<2xi1>, tensor<2xi1>) -> tensor<2xi1> @@ -109,114 +119,6 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> { // ----- -// CHECK: HloModule -func.func @all_gather_0(%arg1: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} { - %0 = "mhlo.all_gather"(%arg1) { - all_gather_dim = 1 : i64, - channel_handle = #mhlo.channel_handle, - shard_count = 4, - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, - use_global_device_ids - } : (tensor<128x32xf32>) -> tensor<128x128xf32> - return %0 : tensor<128x128xf32> -} - -func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> { - %0 = "mhlo.async_start"(%arg0) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle, tensor<128x128xf32>> - %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, tensor<128x128xf32>>) -> tensor<128x128xf32> - return %1 : tensor<128x128xf32> -} - -// CHECK: ENTRY -// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) -// CHECK: %[[OUTPUT:.*]] = f32[128,128] all-gather-start(f32[128,32] %[[INPUT]]) -// CHECK-SAME: channel_id=1 -// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} -// CHECK-SAME: dimensions={1} -// CHECK-SAME: use_global_device_ids=true -// CHECK: ROOT {{.*}} f32[128,128] all-gather-done(f32[128,128] %[[OUTPUT]] - -// ----- - -// CHECK: HloModule -func.func @all_reduce_0(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes {execution_thread = "main"} { - %0 = "mhlo.all_reduce"(%arg0) ({ - ^bb0(%lhs: tensor, %rhs: tensor): - %max = mhlo.maximum %lhs, %rhs : tensor - "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, - channel_handle = #mhlo.channel_handle< - handle = 5, - type = 2 - >, - use_global_device_ids - } : (tensor<10xf32>) -> tensor<10xf32> - func.return %0 : tensor<10xf32> -} - -func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { - %0 = "mhlo.async_start"(%arg0) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>) -> !mhlo.async_bundle, tensor<10xf32>> - %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, tensor<10xf32>>) -> tensor<10xf32> - return %1 : tensor<10xf32> -} - -// CHECK: ENTRY -// CHECK: %[[INPUT:.*]] = f32[10] parameter(0) -// CHECK: %[[OUTPUT:.*]] = f32[10] all-reduce-start(f32[10] %[[INPUT]]) -// CHECK-SAME: channel_id=5 -// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} -// CHECK-SAME: use_global_device_ids=true -// CHECK: ROOT {{.*}} f32[10] all-reduce-done(f32[10] %[[OUTPUT]] - -// ----- - -// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}} -func.func @all_reduce_0(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) attributes {execution_thread = "main"} { - %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({ - ^bb0(%lhs: tensor, %rhs: tensor): - %max = mhlo.maximum %lhs, %rhs : tensor - "mhlo.return"(%max) : (tensor) -> () - }) - { - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, - channel_handle = #mhlo.channel_handle< - handle = 5, - type = 2 - >, - use_global_device_ids - } : (tensor<10xf32>, tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) - func.return %0#0, %0#1 : tensor<10xf32>, tensor<1xf32> -} - -func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) { - %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>, tensor<1xf32>) -> !mhlo.async_bundle,tensor<1xf32>>, tuple,tensor<1xf32>>> - %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle,tensor<1xf32>>, tuple,tensor<1xf32>>>) -> (tensor<10xf32>, tensor<1xf32>) - return %1#0, %1#1 : tensor<10xf32>, tensor<1xf32> -} - -// ----- - -// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}} -func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} { - %0:2 = "mhlo.all_gather"(%arg0, %arg1) { - all_gather_dim = 1 : i64, - channel_handle = #mhlo.channel_handle, - replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, - use_global_device_ids - } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) - func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32> -} - -func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) { - %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle,tensor<8x4xf32>>, tuple,tensor<8x4xf32>>> - %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle,tensor<8x4xf32>>, tuple,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>) - return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32> -} - -// ----- - func.func private @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> tuple, tensor<8x16xf32>> { // CHECK: %[[ARG0:.*]] = f32[8,2] parameter(0) // CHECK-NEXT: %[[ARG1:.*]] = f32[8,4] parameter(1) @@ -624,30 +526,6 @@ func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { // ----- -// CHECK: HloModule -func.func @collective_permute_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { - %0 = "mhlo.collective_permute"(%arg0) { - source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, - channel_handle = #mhlo.channel_handle - } : (tensor<128x32xf32>) -> tensor<128x32xf32> - func.return %0 : tensor<128x32xf32> -} - -func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { - %0 = "mhlo.async_start"(%arg0) {called_computation = @collective_permute_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle, tensor<128x32xf32>> - %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, tensor<128x32xf32>>) -> tensor<128x32xf32> - return %1 : tensor<128x32xf32> -} - -// CHECK: ENTRY -// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) -// CHECK: %[[OUTPUT:.*]] = f32[128,32] collective-permute-start(f32[128,32] %[[INPUT]]) -// CHECK-SAME: channel_id=1 -// CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}} -// CHECK: ROOT {{.*}} f32[128,32] collective-permute-done(f32[128,32] %[[OUTPUT]] - -// ----- - // CHECK: HloModule func.func @main(%arg0 : tensor<5x2xf32>, %arg1 : tensor<5x5xf32>, @@ -890,27 +768,6 @@ func.func @main(%arg0: tensor<2xi32>) -> tensor<2xi32> { // CHECK: [[ARG:%.*]] = s32[2] parameter(0) // CHECK: ROOT %[[RESULT:.*]] = s32[2] copy(s32[2] [[ARG]]) -// ----- - -// CHECK: HloModule -func.func @copy_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { - %0 = "mhlo.copy"(%arg0) {cross_program_prefetch_index = 0 : i32} : (tensor<128x32xf32>) -> tensor<128x32xf32> - func.return %0 : tensor<128x32xf32> -} - -func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { - %0 = "mhlo.async_start"(%arg0) {called_computation = @copy_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle, tensor<128x32xf32>> - %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, tensor<128x32xf32>>) -> tensor<128x32xf32> - return %1 : tensor<128x32xf32> -} - -// CHECK: ENTRY -// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) -// CHECK: %[[OUTPUT:.*]] = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %[[INPUT]]) -// CHECK-SAME: cross_program_prefetch_index=0 -// CHECK: ROOT {{.*}} f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %[[OUTPUT]] - - // ----- // CHECK: HloModule @@ -2130,67 +1987,6 @@ func.func @main(%token: !mhlo.token) -> !mhlo.token { // ----- -// CHECK: HloModule - -func.func @recv_0(%token: !mhlo.token) -> (!mhlo.token) attributes {execution_thread = "main"} { - %0 = "mhlo.recv"(%token) { - channel_handle = #mhlo.channel_handle< - handle = 5, - type = 1 // Device to device channel - >, - is_host_transfer = false - } : (!mhlo.token) -> (!mhlo.token) - func.return %0 : !mhlo.token -} - -func.func @main(%token: !mhlo.token) -> (!mhlo.token) { - %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle> - %2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle>) -> (!mhlo.token) - return %2 : !mhlo.token -} - -// CHECK: ENTRY -// CHECK: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[RECV:%.*]] = ((), u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5 -// CHECK: ((), token[]) recv-done(((), u32[], token[]) [[RECV]]), channel_id=5 - -// ----- - -// CHECK: HloModule -func.func @recv_0(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) attributes {execution_thread = "main"} { - %0:2 = "mhlo.recv"(%token) { - channel_handle = #mhlo.channel_handle< - handle = 5, - type = 3 // Host to device channel - >, - is_host_transfer = true - } : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) - func.return %0#0, %0#1 : tensor<3x4xi32>, !mhlo.token -} - -func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) { - %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main", mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}"} : (!mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, tensor> - %1, %2 = "mhlo.async_done"(%0) {mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle, !mhlo.token>, tensor>) -> (tensor<3x4xi32>, !mhlo.token) - return %1, %2 : tensor<3x4xi32>, !mhlo.token -} - -// CHECK: ENTRY -// CHECK: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5, is_host_transfer -// CHECK-SAME: sharding={ -// CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0} -// CHECK-SAME: } -// CHECK: [[RECV_DONE:%.*]] = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5, is_host_transfer -// CHECK-SAME: sharding={ -// CHECK-SAME: {maximal device=0}, {maximal device=0} -// CHECK-SAME: } -// CHECK: [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0, sharding={maximal device=0} -// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1, sharding={maximal device=0} -// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]]) - -// ----- - - // CHECK: HloModule func.func @main(%arg0 : tensor<1x10xf32>, %arg1 : tensor<1x10xi32>, %arg2 : tensor, %arg3 : tensor) -> (tensor<1xf32>, tensor<1xi32>) { %result0, %result1 = "mhlo.reduce"(%arg0, %arg1, %arg2, %arg3) ({ @@ -2465,58 +2261,6 @@ func.func @main(%token: !mhlo.token) -> !mhlo.token { // ----- -// CHECK: HloModule -func.func @send_0(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} { - %0 = "mhlo.send"(%arg, %token) { - channel_handle = #mhlo.channel_handle< - handle = 5, - type = 2 // Device to host channel - >, - is_host_transfer = true - } : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token - func.return %0 : !mhlo.token -} - -func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token { - %0 = "mhlo.async_start"(%arg, %token) {called_computation = @send_0, execution_thread = "main"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor> - %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor>) -> !mhlo.token - return %1 : !mhlo.token -} - -// CHECK: ENTRY -// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) -// CHECK: [[TOKEN:%.*]] = token[] parameter(1) -// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer -// CHECK: ROOT -// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5 - -// ----- - -// CHECK: HloModule -func.func @send_0(%token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} { - %0 = "mhlo.send"(%token) { - channel_handle = #mhlo.channel_handle< - handle = 5, - type = 1 // Device to device channel - > - } : (!mhlo.token) -> !mhlo.token - func.return %0 : !mhlo.token -} - -func.func @main(%token: !mhlo.token) -> !mhlo.token { - %0 = "mhlo.async_start"(%token) {called_computation = @send_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle> - %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle>) -> !mhlo.token - return %1 : !mhlo.token -} - -// CHECK: ENTRY -// CHECK: [[TOKEN:%.*]] = token[] parameter(0) -// CHECK: [[SEND:%.*]] = ((), u32[], token[]) send(() [[UNIT:%.*]], token[] [[TOKEN]]), channel_id=5 -// CHECK: ROOT -// CHECK-SAME: token[] send-done(((), u32[], token[]) [[SEND]]), channel_id=5 - -// ----- - // CHECK: HloModule func.func @main(%arg: tensor<4x4xf32>, %size: tensor) -> tensor<4x4xf32> { %0 = "mhlo.set_dimension_size"(%arg, %size) {dimension = 1 : i64} : (tensor<4x4xf32>, tensor) -> tensor<4x4xf32> @@ -2935,55 +2679,6 @@ func.func @main(%arg0: tensor<4x2xf32>, %arg1: tensor<4x2xi32>, %init0: tensor : tensor<2xi64> } : (tensor<4x2xf32>, tensor<4x2xi32>, tensor, tensor) -> (tensor<2x2xf32>, tensor<2x2xi32>) func.return %0#0, %0#1 : tensor<2x2xf32>, tensor<2x2xi32> } -// ----- - -// CHECK: HloModule -// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] { -func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32> - attributes {execution_thread = "thread"} { - %0 = "mhlo.custom_call"(%arg0) {call_target_name = "foo"} : (tensor<10xf32>) -> tensor<20xf32> - return %0 : tensor<20xf32> -} - -// CHECK: ENTRY -func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { - // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) - // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]) - // CHECK-SAME: calls=[[CALLED_COMPUTATION]] - %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> - // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]]) - %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> - // CHECK: ROOT %{{.*}} = (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]]) - %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> tensor<20xf32> - return %2 : tensor<20xf32> -} - -// ----- - -// CHECK: HloModule -// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] { -func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32> - attributes {execution_thread = "thread"} { - %1 = "mhlo.custom_call"(%arg0) {call_target_name = "bar"} : (tensor<10xf32>) -> tensor<20xf32> - // CHECK: custom-call - // CHECK-SAME: custom_call_target="bar" - return %1 : tensor<20xf32> -} - -// CHECK: ENTRY -func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { - // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) - // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]], - // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]]) - // CHECK: ROOT - // CHECK-SAME: (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]]) - - %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread"} : (tensor<10xf32>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> - %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> - %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> tensor<20xf32> - return %2 : tensor<20xf32> -} - // ----- diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir new file mode 100644 index 00000000000000..70bf10c8d045c8 --- /dev/null +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/export_async.mlir @@ -0,0 +1,312 @@ +// RUN: xla-translate --print-sugar=false -split-input-file -mlir-hlo-to-hlo-text -verify-diagnostics %s | FileCheck %s + +// CHECK: HloModule +func.func @all_gather_0(%arg1: tensor<128x32xf32>) -> tensor<128x128xf32> attributes {execution_thread = "main"} { + %0 = "mhlo.all_gather"(%arg1) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + shard_count = 4, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + use_global_device_ids + } : (tensor<128x32xf32>) -> tensor<128x128xf32> + return %0 : tensor<128x128xf32> +} + +func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x128xf32> { + %0 = "mhlo.async_start"(%arg0) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle, tensor<128x128xf32>> + %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, tensor<128x128xf32>>) -> tensor<128x128xf32> + return %1 : tensor<128x128xf32> +} + +// CHECK: ENTRY +// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) +// CHECK: %[[OUTPUT:.*]] = f32[128,128] all-gather-start(f32[128,32] %[[INPUT]]) +// CHECK-SAME: channel_id=1 +// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} +// CHECK-SAME: dimensions={1} +// CHECK-SAME: use_global_device_ids=true +// CHECK: ROOT {{.*}} f32[128,128] all-gather-done(f32[128,128] %[[OUTPUT]] + +// ----- + +// CHECK: HloModule +func.func @all_reduce_0(%arg0: tensor<10xf32>) -> tensor<10xf32> attributes {execution_thread = "main"} { + %0 = "mhlo.all_reduce"(%arg0) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + channel_handle = #mhlo.channel_handle< + handle = 5, + type = 2 + >, + use_global_device_ids + } : (tensor<10xf32>) -> tensor<10xf32> + func.return %0 : tensor<10xf32> +} + +func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { + %0 = "mhlo.async_start"(%arg0) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>) -> !mhlo.async_bundle, tensor<10xf32>> + %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, tensor<10xf32>>) -> tensor<10xf32> + return %1 : tensor<10xf32> +} + +// CHECK: ENTRY +// CHECK: %[[INPUT:.*]] = f32[10] parameter(0) +// CHECK: %[[OUTPUT:.*]] = f32[10] all-reduce-start(f32[10] %[[INPUT]]) +// CHECK-SAME: channel_id=5 +// CHECK-SAME{LITERAL}: replica_groups={{0,2,4,6},{1,3,5,7}} +// CHECK-SAME: use_global_device_ids=true +// CHECK: ROOT {{.*}} f32[10] all-reduce-done(f32[10] %[[OUTPUT]] + +// ----- + +// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}} +func.func @all_reduce_0(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) attributes {execution_thread = "main"} { + %0:2 = "mhlo.all_reduce"(%arg0, %arg1) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %max = mhlo.maximum %lhs, %rhs : tensor + "mhlo.return"(%max) : (tensor) -> () + }) + { + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + channel_handle = #mhlo.channel_handle< + handle = 5, + type = 2 + >, + use_global_device_ids + } : (tensor<10xf32>, tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) + func.return %0#0, %0#1 : tensor<10xf32>, tensor<1xf32> +} + +func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<1xf32>) -> (tensor<10xf32>, tensor<1xf32>) { + %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_reduce_0, execution_thread = "main"} : (tensor<10xf32>, tensor<1xf32>) -> !mhlo.async_bundle,tensor<1xf32>>, tuple,tensor<1xf32>>> + %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle,tensor<1xf32>>, tuple,tensor<1xf32>>>) -> (tensor<10xf32>, tensor<1xf32>) + return %1#0, %1#1 : tensor<10xf32>, tensor<1xf32> +} + +// ----- + +// expected-error@-3 {{'mhlo.async_start' op can't be translated to XLA HLO}} +func.func @all_gather_0(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) attributes {execution_thread = "main"} { + %0:2 = "mhlo.all_gather"(%arg0, %arg1) { + all_gather_dim = 1 : i64, + channel_handle = #mhlo.channel_handle, + replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>, + use_global_device_ids + } : (tensor<8x2xf32>, tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) + func.return %0#0, %0#1 : tensor<8x2xf32>, tensor<8x4xf32> +} + +func.func @main(%arg0: tensor<8x2xf32>, %arg1: tensor<8x4xf32>) -> (tensor<8x2xf32>, tensor<8x4xf32>) { + %0 = "mhlo.async_start"(%arg0, %arg1) {called_computation = @all_gather_0, execution_thread = "main"} : (tensor<8x2xf32>, tensor<8x4xf32>) -> !mhlo.async_bundle,tensor<8x4xf32>>, tuple,tensor<8x4xf32>>> + %1:2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle,tensor<8x4xf32>>, tuple,tensor<8x4xf32>>>) -> (tensor<8x2xf32>, tensor<8x4xf32>) + return %1#0, %1#1 : tensor<8x2xf32>, tensor<8x4xf32> +} + +// ----- + +// CHECK: HloModule +func.func @collective_permute_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { + %0 = "mhlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #mhlo.channel_handle + } : (tensor<128x32xf32>) -> tensor<128x32xf32> + func.return %0 : tensor<128x32xf32> +} + +func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + %0 = "mhlo.async_start"(%arg0) {called_computation = @collective_permute_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle, tensor<128x32xf32>> + %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, tensor<128x32xf32>>) -> tensor<128x32xf32> + return %1 : tensor<128x32xf32> +} + +// CHECK: ENTRY +// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) +// CHECK: %[[OUTPUT:.*]] = f32[128,32] collective-permute-start(f32[128,32] %[[INPUT]]) +// CHECK-SAME: channel_id=1 +// CHECK-SAME{LITERAL}: source_target_pairs={{0,1},{1,2},{2,3}} +// CHECK: ROOT {{.*}} f32[128,32] collective-permute-done(f32[128,32] %[[OUTPUT]] + +// ----- + +// CHECK: HloModule +func.func @copy_0(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> attributes {execution_thread = "main"} { + %0 = "mhlo.copy"(%arg0) {cross_program_prefetch_index = 0 : i32} : (tensor<128x32xf32>) -> tensor<128x32xf32> + func.return %0 : tensor<128x32xf32> +} + +func.func @main(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + %0 = "mhlo.async_start"(%arg0) {called_computation = @copy_0, execution_thread = "main"} : (tensor<128x32xf32>) -> !mhlo.async_bundle, tensor<128x32xf32>> + %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, tensor<128x32xf32>>) -> tensor<128x32xf32> + return %1 : tensor<128x32xf32> +} + +// CHECK: ENTRY +// CHECK: %[[INPUT:.*]] = f32[128,32] parameter(0) +// CHECK: %[[OUTPUT:.*]] = (f32[128,32], f32[128,32], u32[]) copy-start(f32[128,32] %[[INPUT]]) +// CHECK-SAME: cross_program_prefetch_index=0 +// CHECK: ROOT {{.*}} f32[128,32] copy-done((f32[128,32], f32[128,32], u32[]) %[[OUTPUT]] + +// ----- + +// CHECK: HloModule + +func.func @recv_0(%token: !mhlo.token) -> (!mhlo.token) attributes {execution_thread = "main"} { + %0 = "mhlo.recv"(%token) { + channel_handle = #mhlo.channel_handle< + handle = 5, + type = 1 // Device to device channel + >, + is_host_transfer = false + } : (!mhlo.token) -> (!mhlo.token) + func.return %0 : !mhlo.token +} + +func.func @main(%token: !mhlo.token) -> (!mhlo.token) { + %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle> + %2 = "mhlo.async_done"(%0) : (!mhlo.async_bundle>) -> (!mhlo.token) + return %2 : !mhlo.token +} + +// CHECK: ENTRY +// CHECK: [[TOKEN:%.*]] = token[] parameter(0) +// CHECK: [[RECV:%.*]] = ((), u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5 +// CHECK: ((), token[]) recv-done(((), u32[], token[]) [[RECV]]), channel_id=5 + +// ----- + +// CHECK: HloModule +func.func @recv_0(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) attributes {execution_thread = "main"} { + %0:2 = "mhlo.recv"(%token) { + channel_handle = #mhlo.channel_handle< + handle = 5, + type = 3 // Host to device channel + >, + is_host_transfer = true + } : (!mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) + func.return %0#0, %0#1 : tensor<3x4xi32>, !mhlo.token +} + +func.func @main(%token: !mhlo.token) -> (tensor<3x4xi32>, !mhlo.token) { + %0 = "mhlo.async_start"(%token) {called_computation = @recv_0, execution_thread = "main", mhlo.sharding = "{{maximal device=0}, {maximal device=0}, {maximal device=0}}"} : (!mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, tensor> + %1, %2 = "mhlo.async_done"(%0) {mhlo.sharding = "{{maximal device=0}, {maximal device=0}}"} : (!mhlo.async_bundle, !mhlo.token>, tensor>) -> (tensor<3x4xi32>, !mhlo.token) + return %1, %2 : tensor<3x4xi32>, !mhlo.token +} + +// CHECK: ENTRY +// CHECK: [[TOKEN:%.*]] = token[] parameter(0) +// CHECK: [[RECV:%.*]] = (s32[3,4], u32[], token[]) recv(token[] [[TOKEN]]), channel_id=5, is_host_transfer +// CHECK-SAME: sharding={ +// CHECK-SAME: {maximal device=0}, {maximal device=0}, {maximal device=0} +// CHECK-SAME: } +// CHECK: [[RECV_DONE:%.*]] = (s32[3,4], token[]) recv-done((s32[3,4], u32[], token[]) [[RECV]]), channel_id=5, is_host_transfer +// CHECK-SAME: sharding={ +// CHECK-SAME: {maximal device=0}, {maximal device=0} +// CHECK-SAME: } +// CHECK: [[TUPLE0:%.*]] = s32[3,4] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=0, sharding={maximal device=0} +// CHECK: [[TUPLE1:%.*]] = token[] get-tuple-element((s32[3,4], token[]) [[RECV_DONE]]), index=1, sharding={maximal device=0} +// CHECK: ROOT {{%.*}} = (s32[3,4], token[]) tuple(s32[3,4] [[TUPLE0]], token[] [[TUPLE1]]) + +// ----- + +// CHECK: HloModule +func.func @send_0(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} { + %0 = "mhlo.send"(%arg, %token) { + channel_handle = #mhlo.channel_handle< + handle = 5, + type = 2 // Device to host channel + >, + is_host_transfer = true + } : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.token + func.return %0 : !mhlo.token +} + +func.func @main(%arg: tensor<3x4xi32>, %token: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.async_start"(%arg, %token) {called_computation = @send_0, execution_thread = "main"} : (tensor<3x4xi32>, !mhlo.token) -> !mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor> + %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle, !mhlo.token>, !mhlo.token, tensor>) -> !mhlo.token + return %1 : !mhlo.token +} + +// CHECK: ENTRY +// CHECK: [[ARG:%.*]] = s32[3,4] parameter(0) +// CHECK: [[TOKEN:%.*]] = token[] parameter(1) +// CHECK: [[SEND:%.*]] = (s32[3,4], u32[], token[]) send(s32[3,4] [[ARG]], token[] [[TOKEN]]), channel_id=5, is_host_transfer +// CHECK: ROOT +// CHECK-SAME: token[] send-done((s32[3,4], u32[], token[]) [[SEND]]), channel_id=5 + +// ----- + +// CHECK: HloModule +func.func @send_0(%token: !mhlo.token) -> !mhlo.token attributes {execution_thread = "main"} { + %0 = "mhlo.send"(%token) { + channel_handle = #mhlo.channel_handle< + handle = 5, + type = 1 // Device to device channel + > + } : (!mhlo.token) -> !mhlo.token + func.return %0 : !mhlo.token +} + +func.func @main(%token: !mhlo.token) -> !mhlo.token { + %0 = "mhlo.async_start"(%token) {called_computation = @send_0, execution_thread = "main"} : (!mhlo.token) -> !mhlo.async_bundle> + %1 = "mhlo.async_done"(%0) : (!mhlo.async_bundle>) -> !mhlo.token + return %1 : !mhlo.token +} + +// CHECK: ENTRY +// CHECK: [[TOKEN:%.*]] = token[] parameter(0) +// CHECK: [[SEND:%.*]] = ((), u32[], token[]) send(() [[UNIT:%.*]], token[] [[TOKEN]]), channel_id=5 +// CHECK: ROOT +// CHECK-SAME: token[] send-done(((), u32[], token[]) [[SEND]]), channel_id=5 + +// ----- + +// CHECK: HloModule +// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] { +func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32> + attributes {execution_thread = "thread"} { + %0 = "mhlo.custom_call"(%arg0) {call_target_name = "foo"} : (tensor<10xf32>) -> tensor<20xf32> + return %0 : tensor<20xf32> +} + +// CHECK: ENTRY +func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { + // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) + // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]) + // CHECK-SAME: calls=[[CALLED_COMPUTATION]] + %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread = "thread"} : (tensor<10xf32>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> + // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]]) + %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> + // CHECK: ROOT %{{.*}} = (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]]) + %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> tensor<20xf32> + return %2 : tensor<20xf32> +} + +// ----- + +// CHECK: HloModule +// CHECK: [[CALLED_COMPUTATION:%AsyncOp.*]] ([[ARG:.*]]: f32[10]) -> f32[20] { +func.func @AsyncOp(%arg0: tensor<10xf32>) -> tensor<20xf32> + attributes {execution_thread = "thread"} { + %1 = "mhlo.custom_call"(%arg0) {call_target_name = "bar"} : (tensor<10xf32>) -> tensor<20xf32> + // CHECK: custom-call + // CHECK-SAME: custom_call_target="bar" + return %1 : tensor<20xf32> +} + +// CHECK: ENTRY +func.func @main(%arg0: tensor<10xf32>) -> tensor<20xf32> { + // CHECK: %[[ARG0:.*]] = f32[10] parameter(0) + // CHECK: %[[START:.*]] = ((f32[10]), f32[20], s32[]) async-start(f32[10] %[[ARG0]]), async_execution_thread="thread", calls=[[CALLED_COMPUTATION]], + // CHECK: %[[UPDATE:.*]] = ((f32[10]), f32[20], s32[]) async-update(((f32[10]), f32[20], s32[]) %[[START]]) + // CHECK: ROOT + // CHECK-SAME: (f32[20]) async-done(((f32[10]), f32[20], s32[]) %[[UPDATE]]) + + %0 = "mhlo.async_start"(%arg0) {called_computation = @AsyncOp, execution_thread="thread"} : (tensor<10xf32>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> + %1 = "mhlo.async_update"(%0) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> !mhlo.async_bundle>, tensor<20xf32>, tensor> + %2 = "mhlo.async_done"(%1) : (!mhlo.async_bundle>, tensor<20xf32>, tensor>) -> tensor<20xf32> + return %2 : tensor<20xf32> +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir index 049456bb09e6f7..6ad08374e5d2e6 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/module_attributes.mlir @@ -100,3 +100,45 @@ module @ModuleWithFrontendAttributes attributes { func.return %arg0 : tensor<1xf32> } } + + + +// ----- + +module attributes { +// CHECK: input_output_alias { +// CHECK-NEXT: entries { +// CHECK-NEXT: output_shape_index: 0 +// CHECK-NEXT: kind: MAY_ALIAS +// CHECK-NEXT: } +// CHECK-NEXT: entries { +// CHECK-NEXT: output_shape_index: 1 +// CHECK-NEXT: parameter_number: 1 +// CHECK-NEXT: kind: MAY_ALIAS +// CHECK-NEXT: } +// CHECK-NEXT: } + mhlo.input_output_alias = [ + { + alias = + { + kind = "may_alias", + parameter_index = array, + parameter_number = 0 : i64 + }, + output_index = array + }, + { + alias = + { + kind = "may_alias", + parameter_index = array, + parameter_number = 1 : i64 + }, + output_index = array + } +] +} { + func.func @main(%arg0: tensor<1xf32>, %arg1: tensor<1xf32> ) -> (tensor<1xf32>, tensor<1xf32>) { + func.return %arg0, %arg1: tensor<1xf32>, tensor<1xf32> + } +} \ No newline at end of file diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir new file mode 100644 index 00000000000000..3663f927ae6876 --- /dev/null +++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/while_free_vars.mlir @@ -0,0 +1,89 @@ +// RUN: xla-translate -split-input-file -mlir-hlo-to-hlo-text %s -o - | FileCheck %s + +// This test verifies that the correct shardings are added when a while loop +// has free variables. + +// CHECK-LABEL: HloModule main + +// CHECK: %region_0.7 (arg_tuple.8: (s32[], f32[4], s32[], s32[], f32[4])) -> (s32[], f32[4], s32[], s32[], f32[4]) { +// CHECK-NEXT: %arg_tuple.8 = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} +// CHECK-DAG: %get-tuple-element.12 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.8), index=3 +// CHECK-DAG: %get-tuple-element.13 = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.8), index=4, sharding={devices=[4]<=[4]} +// CHECK-DAG: %add.14 = s32[] add(s32[] %get-tuple-element.9, s32[] %get-tuple-element.12) +// CHECK-DAG: %add.15 = f32[4] add(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.13) +// CHECK: ROOT %tuple.16 = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %add.14, f32[4] %add.15, s32[] %get-tuple-element.11, s32[] %get-tuple-element.12, f32[4] %get-tuple-element.13) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} + +// CHECK: %region_1.17 (arg_tuple.18: (s32[], f32[4], s32[], s32[], f32[4])) -> pred[] { +// CHECK-NEXT: %arg_tuple.18 = (s32[], f32[4], s32[], s32[], f32[4]) parameter(0) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} +// CHECK: %get-tuple-element.21 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %arg_tuple.18), index=2 +// CHECK-NEXT: ROOT %compare.24 = pred[] compare(s32[] %get-tuple-element.19, s32[] %get-tuple-element.21), direction=LT + +// CHECK: ENTRY %main.28 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] { +// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) +// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) +// CHECK-NEXT: %constant.4 = s32[] constant(0) +// CHECK-NEXT: %constant.5 = s32[] constant(1) +// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2) +// CHECK-NEXT: %tuple.6 = (s32[], f32[4], s32[], s32[], f32[4]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, s32[] %constant.4, s32[] %constant.5, f32[4] %Arg_2.3) +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %while.25 = (s32[], f32[4], s32[], s32[], f32[4]) while((s32[], f32[4], s32[], s32[], f32[4]) %tuple.6), condition=%region_1.17, body=%region_0.7 +// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[4]<=[4]}} +// CHECK-NEXT: %get-tuple-element.26 = s32[] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %while.25), index=0, sharding={replicated} +// CHECK-NEXT: ROOT %get-tuple-element.27 = f32[4] get-tuple-element((s32[], f32[4], s32[], s32[], f32[4]) %while.25), index=1, sharding={devices=[2,2]<=[4] last_tile_dim_replicate} + +func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> tensor<4xf32> { + %0 = mhlo.constant dense<0> : tensor + %1 = mhlo.constant dense<1> : tensor + %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor, tensor<4xf32> + attributes {mhlo.sharding = "{{replicated},{devices=[2,2]<=[4] last_tile_dim_replicate}}"} + cond { + %3 = mhlo.compare LT, %iterArg, %0 : (tensor, tensor) -> tensor + mhlo.return %3 : tensor + } do { + %3 = mhlo.add %iterArg, %1 : tensor + %4 = mhlo.add %iterArg_0, %arg2 : tensor<4xf32> + mhlo.return %3, %4: tensor, tensor<4xf32> + } + func.return %2#1 : tensor<4xf32> +} + +// ----- + +// This test verifies that a value captured multiple times is only lifted once +// and all its uses are replaced. Also verifies that no sharding is added to +// region parameters or root when the while doesn't have a sharding. + +// CHECK-LABEL: HloModule main + +// CHECK: %region_0.5 (arg_tuple.6: (s32[], f32[4], s32[])) -> (s32[], f32[4], s32[]) { +// CHECK-NEXT: %arg_tuple.6 = (s32[], f32[4], s32[]) parameter(0) +// CHECK: %get-tuple-element.9 = s32[] get-tuple-element((s32[], f32[4], s32[]) %arg_tuple.6), index=2 +// CHECK: %add.10 = s32[] add(s32[] %get-tuple-element.7, s32[] %get-tuple-element.9) +// CHECK: ROOT %tuple.11 = (s32[], f32[4], s32[]) tuple(s32[] %add.10, f32[4] %get-tuple-element.8, s32[] %get-tuple-element.9) + +// CHECK: %region_1.12 (arg_tuple.13: (s32[], f32[4], s32[])) -> pred[] { +// CHECK-NEXT: %arg_tuple.13 = (s32[], f32[4], s32[]) parameter(0) +// CHECK: %get-tuple-element.16 = s32[] get-tuple-element((s32[], f32[4], s32[]) %arg_tuple.13), index=2 +// CHECK: ROOT %compare.17 = pred[] compare(s32[] %get-tuple-element.14, s32[] %get-tuple-element.16), direction=LT + +// CHECK: ENTRY %main.21 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: s32[]) -> f32[4] { +// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0) +// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) +// CHECK-NEXT: %Arg_2.3 = s32[] parameter(2) +// CHECK-NEXT: %tuple.4 = (s32[], f32[4], s32[]) tuple(s32[] %Arg_0.1, f32[4] %Arg_1.2, s32[] %Arg_2.3) +// CHECK-NEXT: %while.18 = (s32[], f32[4], s32[]) while((s32[], f32[4], s32[]) %tuple.4), condition=%region_1.12, body=%region_0.5 + +func.func @main(%arg0: tensor, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<4xf32> { + %2:2 = mhlo.while(%iterArg = %arg0, %iterArg_0 = %arg1) : tensor, tensor<4xf32> + cond { + %3 = mhlo.compare LT, %iterArg, %arg2 : (tensor, tensor) -> tensor + mhlo.return %3 : tensor + } do { + %3 = mhlo.add %iterArg, %arg2 : tensor + mhlo.return %3, %iterArg_0: tensor, tensor<4xf32> + } + func.return %2#1 : tensor<4xf32> +} diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc b/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc index 8cff1c99592b1c..7c07582a46c794 100644 --- a/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc +++ b/third_party/xla/xla/translate/mhlo_to_hlo/translate.cc @@ -16,26 +16,42 @@ limitations under the License. #include #include +#include #include "absl/log/log.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/Value.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "xla/client/xla_builder.h" +#include "xla/client/xla_computation.h" +#include "xla/debug_options_flags.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/service/hlo.pb.h" #include "xla/service/hlo_module_config.h" #include "xla/service/hlo_proto_util.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "tsl/platform/errors.h" +#include "tsl/platform/statusor.h" constexpr char kParameterReplicationAttr[] = "mhlo.parameter_replication"; @@ -123,6 +139,8 @@ absl::Status ConvertMlirHloToHloViaBuilder( mlir::cast(b).getValue()); auto hlo_module = computation.proto(); + mlir::StringRef module_name = module.getName() ? *module.getName() : "main"; + hlo_module.set_name(module_name.str()); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); return absl::OkStatus(); diff --git a/third_party/xla/xla/tsl/c/tsl_status.cc b/third_party/xla/xla/tsl/c/tsl_status.cc index fea89436d4bafe..75b948129f2533 100644 --- a/third_party/xla/xla/tsl/c/tsl_status.cc +++ b/third_party/xla/xla/tsl/c/tsl_status.cc @@ -35,7 +35,7 @@ void TSL_SetStatus(TSL_Status* s, TSL_Code code, const char* msg) { return; } s->status = - Status(static_cast(code), tsl::StringPiece(msg)); + Status(static_cast(code), absl::string_view(msg)); } void TSL_SetPayload(TSL_Status* s, const char* key, const char* value) { diff --git a/third_party/xla/xla/tsl/concurrency/BUILD b/third_party/xla/xla/tsl/concurrency/BUILD index 0363d152edac7c..578b6ce20b512e 100644 --- a/third_party/xla/xla/tsl/concurrency/BUILD +++ b/third_party/xla/xla/tsl/concurrency/BUILD @@ -31,7 +31,6 @@ cc_library( "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:platform_port", ], ) @@ -72,6 +71,7 @@ tsl_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_benchmark", "@local_tsl//tsl/platform:test_main", ], ) diff --git a/third_party/xla/xla/tsl/concurrency/async_value.cc b/third_party/xla/xla/tsl/concurrency/async_value.cc index dd26e04438cfb0..fa3f0582e779ef 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value.cc @@ -63,12 +63,6 @@ AsyncValue::TypeInfoTable* AsyncValue::GetTypeInfoTableSingleton() { std::atomic AsyncValue::total_allocated_async_values_; -const AsyncValue::TypeInfo& AsyncValue::GetTypeInfo() const { - TypeInfoTable* type_info_table = AsyncValue::GetTypeInfoTableSingleton(); - DCHECK_NE(type_id_, 0); - return (*type_info_table)[type_id_ - 1]; -} - // This is called when the value is set into the ConcreteAsyncValue buffer, or // when the IndirectAsyncValue is forwarded to an available AsyncValue, and we // need to change our state and clear out the notifications. The current state diff --git a/third_party/xla/xla/tsl/concurrency/async_value.h b/third_party/xla/xla/tsl/concurrency/async_value.h index 372db4f20e8539..30e0d8ee11ac90 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value.h +++ b/third_party/xla/xla/tsl/concurrency/async_value.h @@ -21,17 +21,18 @@ limitations under the License. #include #include #include -#include #include +#include #include #include +#include "absl/base/optimization.h" #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/types/span.h" #include "xla/tsl/concurrency/concurrent_vector.h" #include "xla/tsl/concurrency/ref_count.h" -#include "tsl/platform/mem.h" +#include "tsl/platform/logging.h" namespace tsl { @@ -164,8 +165,8 @@ class AsyncValue { // process. This is intended for debugging/assertions only, and shouldn't be // used for mainline logic in the runtime. static size_t GetNumAsyncValueInstances() { - assert(AsyncValueAllocationTrackingEnabled() && - "AsyncValue instance tracking disabled!"); + DCHECK(AsyncValueAllocationTrackingEnabled()) + << "AsyncValue instance tracking disabled!"; return total_allocated_async_values_.load(std::memory_order_relaxed); } @@ -418,8 +419,9 @@ class AsyncValue { private: // Information about a ConcreteAsyncValue subclass. struct TypeInfo { - // Destructor returns the size of the derived AsyncValue to be deallocated. - using DestructorFn = size_t (*)(AsyncValue*); + // Destructor returns the size and alignment of the derived AsyncValue to + // be deallocated. + using DestructorFn = std::pair (*)(AsyncValue*); using GetErrorFn = const absl::Status& (*)(const AsyncValue*); using SetErrorFn = void (*)(AsyncValue*, absl::Status); using HasDataFn = bool (*)(const AsyncValue*); @@ -433,9 +435,9 @@ class AsyncValue { template static TypeInfo MakeTypeInfo() { return TypeInfo{ - [](AsyncValue* v) { + [](AsyncValue* v) -> std::pair { static_cast(v)->~Derived(); - return sizeof(Derived); + return {sizeof(Derived), std::align_val_t{alignof(Derived)}}; }, [](const AsyncValue* v) -> const absl::Status& { return static_cast(v)->GetError(); @@ -454,14 +456,17 @@ class AsyncValue { template const T& GetConcreteValue() const; - // Get the TypeInfo instance for this AsyncValue. - const TypeInfo& GetTypeInfo() const; - - using TypeInfoTable = internal::ConcurrentVector; - // Returns the TypeInfoTable instance (there is one per process). + using TypeInfoTable = internal::ConcurrentVector; static TypeInfoTable* GetTypeInfoTableSingleton(); + // Get the TypeInfo instance for this AsyncValue. + const TypeInfo& GetTypeInfo() const { + TypeInfoTable* type_info_table = AsyncValue::GetTypeInfoTableSingleton(); + DCHECK_NE(type_id_, 0) << "TypeId must be set"; + return (*type_info_table)[type_id_ - 1]; + } + void EnqueueWaiter(absl::AnyInvocable waiter, WaitersAndState old_value); @@ -569,7 +574,7 @@ class ConcreteAsyncValue : public AsyncValue { // Return the underlying error. IsError() must return true. const absl::Status& GetError() const { - assert(IsError()); + DCHECK(IsError()); return data_store_.error(); } @@ -579,12 +584,12 @@ class ConcreteAsyncValue : public AsyncValue { } const T& get() const { - assert(HasData()); + DCHECK(HasData()); return data_store_.data(); } T& get() { - assert(HasData()); + DCHECK(HasData()); return data_store_.data(); } @@ -629,7 +634,7 @@ class ConcreteAsyncValue : public AsyncValue { } void SetError(State s, absl::Status status) { - assert(s == State::kUnconstructed || s == State::kConstructed); + DCHECK(s == State::kUnconstructed || s == State::kConstructed); if (s == State::kConstructed) { data_.~T(); } @@ -677,13 +682,13 @@ class ConcreteAsyncValue : public AsyncValue { } void SetError(State s, absl::Status status) { - assert(!error_); + DCHECK(!error_); error_ = std::make_unique(std::move(status)); } template void EmplaceData(Args&&... args) { - assert(!HasData()); + DCHECK(!HasData()); new (&data_) T(std::forward(args)...); has_data_ = true; } @@ -807,8 +812,8 @@ class TypedIndirectAsyncValue : public IndirectAsyncValue { }; inline AsyncValue::~AsyncValue() { - assert(waiters_and_state_.load().waiter() == nullptr && - "An async value with waiters should never have refcount of zero"); + DCHECK_EQ(waiters_and_state_.load().waiter(), nullptr) + << "An async value with waiters should never have refcount of zero"; if (AsyncValueAllocationTrackingEnabled() && is_refcounted_) total_allocated_async_values_.fetch_sub(1, std::memory_order_relaxed); @@ -853,7 +858,7 @@ inline AsyncValue* AsyncValue::AddRef(uint32_t count) { #endif if (count > 0) { - assert(refcount_.load(std::memory_order_relaxed) > 0); + DCHECK_GT(refcount_.load(std::memory_order_relaxed), 0); // Increasing the reference counter can always be done with // memory_order_relaxed: New references to an object can only be formed from // an existing reference, and passing an existing reference from one thread @@ -871,7 +876,7 @@ inline void AsyncValue::DropRef(uint32_t count) { if (!is_refcounted_) return; #endif - assert(refcount_.load(std::memory_order_relaxed) > 0); + DCHECK_GT(refcount_.load(std::memory_order_relaxed), 0); // We expect that `count` argument will often equal the actual reference count // here; optimize for that. If `count` == reference count, only an acquire // barrier is needed to prevent the effects of the deletion from leaking @@ -894,8 +899,8 @@ template const T& AsyncValue::GetConcreteValue() const { // Make sure both T (the stored type) and BaseT have vtable_ptr or // neither have the vtable_ptr. - assert(std::is_polymorphic::value == has_vtable_); - assert(IsTypeIdCompatible() && "Incorrect accessor"); + DCHECK_EQ(std::is_polymorphic::value, has_vtable_); + DCHECK(IsTypeIdCompatible()) << "Incorrect accessor"; const char* this_ptr = reinterpret_cast(this); return *reinterpret_cast(this_ptr + AsyncValue::kDataOffset); @@ -909,32 +914,27 @@ const T& AsyncValue::get() const { switch (kind()) { case Kind::kConcrete: #ifndef NDEBUG - // TODO(ezhulenev): Use `DLOG_IF` when absl logging is available. if (!GetTypeInfo().has_data(this)) { - std::cerr << "Cannot call get() when ConcreteAsyncValue" // Crash OK - << " isn't constructed; state: " << s.DebugString() << "," - << " error message: " - << (IsError() ? GetError().message() : "None"); - std::abort(); + LOG(FATAL) << "Cannot call get() when ConcreteAsyncValue" + << " isn't constructed; state: " << s.DebugString() << "," + << " error message: " + << (IsError() ? GetError().message() : "None"); } #endif // NDEBUG return GetConcreteValue(); case Kind::kIndirect: #ifndef NDEBUG - // TODO(ezhulenev): Use `DLOG_IF` when absl logging is available. if (s != State::kConcrete) { - std::cerr << "Cannot call get() when IndirectAsyncValue" // Crash OK - << " isn't concrete; state: " << s.DebugString() << "," - << " error message: " - << (IsError() ? GetError().message() : "None"); - std::abort(); + LOG(FATAL) << "Cannot call get() when IndirectAsyncValue" + << " isn't concrete; state: " << s.DebugString() << "," + << " error message: " + << (IsError() ? GetError().message() : "None"); } #endif // NDEBUG auto* iv_value = static_cast(this)->value_; - assert(iv_value && "Indirect value not resolved"); + DCHECK(iv_value) << "Indirect value not resolved"; return iv_value->get(); } - assert(false && "unexpected AsyncValue kind"); } template @@ -943,14 +943,14 @@ T& AsyncValue::get() { } inline void AsyncValue::SetStateConcrete() { - assert(IsConstructed() && kind() == Kind::kConcrete); + DCHECK(IsConstructed() && kind() == Kind::kConcrete); NotifyAvailable(State::kConcrete); } template void AsyncValue::emplace(Args&&... args) { - assert(GetTypeId() == type_id_ && "Incorrect accessor"); - assert(IsUnconstructed() && kind() == Kind::kConcrete); + DCHECK_EQ(GetTypeId(), type_id_) << "Incorrect accessor"; + DCHECK(IsUnconstructed() && kind() == Kind::kConcrete); static_cast*>(this)->emplace( std::forward(args)...); @@ -968,7 +968,7 @@ inline const absl::Status* AsyncValue::GetErrorIfPresent() const { // Unresolved IndirectAsyncValues are not errors. if (!iv_value) return nullptr; - assert(iv_value->kind() != Kind::kIndirect); + DCHECK(iv_value->kind() != Kind::kIndirect); return iv_value->GetErrorIfPresent(); } } @@ -976,7 +976,7 @@ inline const absl::Status* AsyncValue::GetErrorIfPresent() const { inline const absl::Status& AsyncValue::GetError() const { auto* result = GetErrorIfPresent(); - assert(result && "Cannot call GetError() when error isn't available."); + DCHECK(result) << "Cannot call GetError() when error isn't available."; return *result; } @@ -988,7 +988,7 @@ void AsyncValue::AndThen(Waiter&& waiter) { auto old_value = waiters_and_state_.load(std::memory_order_acquire); if (old_value.state() == State::kConcrete || old_value.state() == State::kError) { - assert(old_value.waiter() == nullptr); + DCHECK_EQ(old_value.waiter(), nullptr); waiter(); return; } @@ -1003,7 +1003,7 @@ void AsyncValue::AndThen(Executor& executor, Waiter&& waiter) { auto old_value = waiters_and_state_.load(std::memory_order_acquire); if (old_value.state() == State::kConcrete || old_value.state() == State::kError) { - assert(old_value.waiter() == nullptr); + DCHECK_EQ(old_value.waiter(), nullptr); executor.Execute(std::forward(waiter)); return; } @@ -1018,17 +1018,30 @@ inline void AsyncValue::Destroy() { // Copy `is_refcounted` flag before destroying the async value object. bool was_ref_counted = is_refcounted_; - if (kind() == Kind::kIndirect) { + if (ABSL_PREDICT_FALSE(kind() == Kind::kIndirect)) { // Depending on what the benchmarks say, it might make sense to remove this // explicit check and instead make ~IndirectAsyncValue go through the // GetTypeInfo().destructor case below. static_cast(this)->~IndirectAsyncValue(); - if (was_ref_counted) port::AlignedFree(this); + if (was_ref_counted) { +#if defined(__cpp_sized_deallocation) + ::operator delete(this, sizeof(IndirectAsyncValue), + std::align_val_t{alignof(IndirectAsyncValue)}); +#else // defined(__cpp_sized_deallocation) + ::operator delete(this, std::align_val_t{alignof(IndirectAsyncValue)}); +#endif // defined(__cpp_sized_deallocation) + } return; } - GetTypeInfo().destructor(this); - if (was_ref_counted) port::AlignedFree(this); + auto [size, alignment] = GetTypeInfo().destructor(this); + if (was_ref_counted) { +#if defined(__cpp_sized_deallocation) + ::operator delete(this, size, alignment); +#else // defined(__cpp_sized_deallocation) + ::operator delete(this, alignment); +#endif // defined(__cpp_sized_deallocation) + } } inline bool AsyncValue::IsUnique() const { diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref.h b/third_party/xla/xla/tsl/concurrency/async_value_ref.h index ca1f4133dad564..1065a7b5fcc3dc 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ref.h +++ b/third_party/xla/xla/tsl/concurrency/async_value_ref.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -32,7 +33,6 @@ limitations under the License. #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/ref_count.h" #include "tsl/platform/logging.h" -#include "tsl/platform/mem.h" namespace tsl { @@ -63,6 +63,56 @@ AsyncValueRef MakeConstructedAsyncValueRef(Args&&... args); template AsyncValueRef MakeAvailableAsyncValueRef(Args&&... args); +// A collection of type traits used by AsyncValueRef and AsyncValuePtr. +namespace internal { + +// Detects if a type is a specialization of an AsyncValueRef template. +template +struct IsAsyncValueRef : std::false_type {}; +template +struct IsAsyncValueRef> : std::true_type {}; + +template +inline constexpr bool is_async_value_ref_v = IsAsyncValueRef::value; + +// Detects types that are `absl::StatusOr` container. +template +struct IsStatusOr : std::false_type {}; +template +struct IsStatusOr> : std::true_type {}; + +// Type predicates for detecting absl::Status-like types. +template +static constexpr bool is_status_v = std::is_same_v; +template +static constexpr bool is_status_or_v = IsStatusOr::value; +template +static constexpr bool is_status_like_v = is_status_v || is_status_or_v; + +// Deduces the result type of invoking `F` with a first compatible `Arg`. +template +struct FirstInvokeResult { + template > + struct is_invocable : std::false_type { + using type = void; + }; + + template + struct is_invocable : std::true_type { + using type = std::invoke_result_t; + }; + + using type = typename std::disjunction...>::type; +}; + +// In contrast to `std::invoke_result_t` `Args` are not passed to `F` all +// together, but instead they are passed one-by-one, and the first valid one +// determines the result type. +template +using first_invoke_result_t = typename FirstInvokeResult::type; + +} // namespace internal + // AsyncValueRef is an asynchronous container for a payload of type `T` or an // error of type `absl::Status`. It is similar to an `absl::StatusOr`, but // does not require immediate value or error to be constructed. It is a promise @@ -88,8 +138,8 @@ class AsyncValueRef { AsyncValueRef(const AsyncValueRef&) = default; AsyncValueRef& operator=(const AsyncValueRef&) = default; - AsyncValueRef(AsyncValueRef&&) = default; - AsyncValueRef& operator=(AsyncValueRef&&) = default; + AsyncValueRef(AsyncValueRef&&) noexcept = default; + AsyncValueRef& operator=(AsyncValueRef&&) noexcept = default; explicit AsyncValueRef(RCReference value) : value_(std::move(value)) {} @@ -135,7 +185,7 @@ class AsyncValueRef { // Return true if the AsyncValue contains a concrete value. bool IsConcrete() const { return value_->IsConcrete(); } - // Return true if state is kUnconstructed. + // Return true if state is `kUnconstructed`. bool IsUnconstructed() const { return value_->IsUnconstructed(); } // Return the stored value. The AsyncValueRef must be available. @@ -295,6 +345,7 @@ class AsyncValueRef { return value_->SetError(std::move(status)); } + ABSL_DEPRECATED("Use SetError with absl::Status argument") void SetError(std::string_view message) const { SetError(absl::InternalError(message)); } @@ -335,35 +386,12 @@ class AsyncValueRef { RCReference value_; }; -// Detects if a type is a specialization of an AsyncValueRef template. -template -struct IsAsyncValueRef : std::false_type {}; -template -struct IsAsyncValueRef> : std::true_type {}; - -template -inline constexpr bool is_async_value_ref_v = IsAsyncValueRef::value; - // Non owning typed pointer for the AsyncValue. Can be cheaply passed around // when the lifetime of the underlying async value is clear from the context. // It is the user responsibility to construct an owning AsyncValueRef to extend // the lifetime of the underlying value if needed. template class AsyncValuePtr { - // Detect result types that are `absl::StatusOr` container. - template - struct IsStatusOr : std::false_type {}; - template - struct IsStatusOr> : std::true_type {}; - - // Type predicates for detecting absl::Status-like types. - template - static constexpr bool is_status_v = std::is_same_v; - template - static constexpr bool is_status_or_v = IsStatusOr::value; - template - static constexpr bool is_status_like_v = is_status_v || is_status_or_v; - // Wait for async value availability: AndThen([] {}) template using SimpleWaiter = std::enable_if_t>; @@ -383,26 +411,25 @@ class AsyncValuePtr { using StatusWaiter = std::enable_if_t<(std::is_invocable_v && !std::is_invocable_v> && - !is_status_v)>; + !internal::is_status_v)>; - // Because AsyncValue itself is a discriminated union of absl::Status and - // typed payload (error or value) the use of AsyncValueRef is - // discouraged (work in progress to disable with static assert) and `Map` - // automatically folds returned status-like object into the returned async - // value error. - - // Async value map functor: Map([](T& value) -> U); - // - R must be constructible from U - template + // Map async value of type `T` to an async value of type `R`. + template > using MapFunctor = std::enable_if_t>; - // Async value try map functor: TryMap([](T& value) -> absl::StatusOr); - // - R must be constructible from U - template + // Try map async value of type `T` to an async value of type `R`. + template > using TryMapFunctor = - std::enable_if_t && is_status_or_v && - std::is_constructible_v && - !std::is_constructible_v>; + std::enable_if_t && + std::is_constructible_v>; + + // Flat map async value of type `T` to an async value `R` (`R` itself is an + // async value ref). Returns `R` value type (async payload type). + template >> + using FlatMapFunctor = std::enable_if_t, + typename R::value_type>; public: // AsyncValuePtr::value_type @@ -423,7 +450,9 @@ class AsyncValuePtr { T& operator*() const { return get(); } explicit operator bool() const { return value_ != nullptr; } - bool operator!=(std::nullptr_t) const { return value_ != nullptr; } + bool operator==(const AsyncValuePtr& p) const { return value_ == p.value_; } + bool operator!=(const AsyncValuePtr& p) const { return value_ != p.value_; } + AsyncValuePtr& operator=(std::nullptr_t) { value_ = nullptr; return *this; @@ -591,8 +620,7 @@ class AsyncValuePtr { // return U(value); // R must be constructible from U // }) // - template , - MapFunctor* = nullptr> + template * = nullptr> AsyncValueRef Map(F&& f) { auto result = MakeUnconstructedAsyncValueRef(); AndThen([f = std::forward(f), result, ptr = *this]() mutable { @@ -606,8 +634,7 @@ class AsyncValuePtr { } // An overload that executes `f` on a user-provided executor. - template , - MapFunctor* = nullptr> + template * = nullptr> AsyncValueRef Map(AsyncValue::Executor& executor, F&& f) { auto result = MakeUnconstructedAsyncValueRef(); // We don't know when the executor will run the callback, so we need to @@ -637,8 +664,7 @@ class AsyncValuePtr { // // If returned status container will have an error status, it will be // automatically converted to async value error. - template , - TryMapFunctor* = nullptr> + template * = nullptr> AsyncValueRef TryMap(F&& f) { auto result = MakeUnconstructedAsyncValueRef(); AndThen([f = std::forward(f), result, ptr = *this]() mutable { @@ -657,8 +683,7 @@ class AsyncValuePtr { } // An overload that executes `f` on a user-provided executor. - template , - TryMapFunctor* = nullptr> + template * = nullptr> AsyncValueRef TryMap(AsyncValue::Executor& executor, F&& f) { auto result = MakeUnconstructedAsyncValueRef(); // We don't know when the executor will run the callback, so we need to @@ -694,7 +719,7 @@ class AsyncValuePtr { // A `TryMap` overload that automatically infers the type of result from `f`. template , - std::enable_if_t>* = nullptr> + std::enable_if_t>* = nullptr> auto TryMap(F&& f) { return TryMap(std::forward(f)); } @@ -702,12 +727,12 @@ class AsyncValuePtr { // A `TryMap` overload that automatically infers the type of result from `f` // and executes `f` on user-provided executor. template , - std::enable_if_t>* = nullptr> + std::enable_if_t>* = nullptr> auto TryMap(AsyncValue::Executor& executor, F&& f) { return TryMap(executor, std::forward(f)); } - // Returns and AsyncValueRef that will be forwarded to the AsyncValueRef + // Returns an AsyncValueRef that will be forwarded to the AsyncValueRef // returned from a functor. // // Sample usage: @@ -716,14 +741,25 @@ class AsyncValuePtr { // return LaunchAsyncTask(value); // }) // - template , - std::enable_if_t>* = nullptr> - AsyncValueRef FlatMap(F&& f) { + // Functor argument can be a `T&` or an `AsyncValueRef`, where async value + // pointer is guaranteed to be in concrete state. Async value pointer allows + // the functor to extend the lifetime of underlying async value if needed. + // + // async_value_ptr.FlatMap([](AsyncValuePtr ptr) -> AsyncValueRef { + // return LaunchAsyncTask([ref = ptr.CopyRef()] { ... }); + // }) + // + template > + AsyncValueRef FlatMap(F&& f) { // If async value is in concrete state, we can immediately call the functor. // We don't handle errors here and prefer a generic code path below because // error handling is never on a performance critical path. if (ABSL_PREDICT_TRUE(IsConcrete())) { - return f(get()); + if constexpr (std::is_invocable_v) { + return f(get()); + } else { + return f(*this); + } } auto promise = MakePromise(); @@ -731,17 +767,19 @@ class AsyncValuePtr { if (ABSL_PREDICT_FALSE(ptr.IsError())) { promise->SetError(ptr.GetError()); } else { - promise->ForwardTo(f(*ptr)); + if constexpr (std::is_invocable_v) { + promise->ForwardTo(f(*ptr)); + } else { + promise->ForwardTo(f(ptr)); + } } }); - return AsyncValueRef(promise); + return AsyncValueRef(promise); } // An overload that executes `f` on a user-provided executor. - template , - std::enable_if_t>* = nullptr> - AsyncValueRef FlatMap(AsyncValue::Executor& executor, - F&& f) { + template > + AsyncValueRef FlatMap(AsyncValue::Executor& executor, F&& f) { // We don't have a special handling for concrete values here because // we must execute user functor on a separate executor and can't call it in // the caller thread. @@ -753,10 +791,14 @@ class AsyncValuePtr { if (ABSL_PREDICT_FALSE(ref.IsError())) { promise->SetError(ref.GetError()); } else { - promise->ForwardTo(f(*ref)); + if constexpr (std::is_invocable_v) { + promise->ForwardTo(f(*ref)); + } else { + promise->ForwardTo(f(ref.AsPtr())); + } } }); - return AsyncValueRef(promise); + return AsyncValueRef(promise); } private: @@ -765,8 +807,8 @@ class AsyncValuePtr { // types and this will be a run time error. template RCReference MakePromise() { - if constexpr (std::is_final_v) { - return MakeIndirectAsyncValue(); + if constexpr (std::is_final_v) { + return MakeIndirectAsyncValue(); } else { return MakeIndirectAsyncValue(); }; @@ -874,7 +916,7 @@ T* PlacementConstruct(void* buf, Args&&... args) { template T* AllocateAndConstruct(Args&&... args) { - void* buf = port::AlignedMalloc(sizeof(T), alignof(T)); + void* buf = ::operator new(sizeof(T), std::align_val_t{alignof(T)}); return PlacementConstruct(buf, std::forward(args)...); } @@ -916,6 +958,47 @@ AsyncValueRef MakeAvailableAsyncValueRef(Args&&... args) { std::forward(args)...))); } +// Allocates an AsyncValueRef that is constructed from the result of calling an +// `f` on a user-provided `executor`. +// +// Sample usage: +// +// MakeAsyncValueRef(executor, []() -> int32_t { ... }); +// +template , + std::enable_if_t>* = nullptr> +AsyncValueRef MakeAsyncValueRef(AsyncValue::Executor& executor, F&& f) { + auto result = MakeUnconstructedAsyncValueRef(); + executor.Execute([result, f = std::forward(f)] { result.emplace(f()); }); + return result; +} + +// Allocates an AsyncValueRef that is constructed from the result of calling an +// `f` on a user-provided `executor`. `F` must return an absl::StatusOr, and +// result of type `T` must be constructible from `U`. +// +// Sample usage: +// +// TryMakeAsyncValueRef(executor, +// []() -> absl::StatusOr { ... }); +// +template , + std::enable_if_t< + internal::is_status_or_v && + std::is_constructible_v>* = nullptr> +AsyncValueRef TryMakeAsyncValueRef(AsyncValue::Executor& executor, F&& f) { + auto result = MakeUnconstructedAsyncValueRef(); + executor.Execute([result, f = std::forward(f)] { + absl::StatusOr status_or = f(); + if (ABSL_PREDICT_TRUE(status_or.ok())) { + result.emplace(std::move(status_or).value()); + } else { + result.SetError(std::move(status_or).status()); + } + }); + return result; +} + //===----------------------------------------------------------------------===// // Constructing non-reference-counted values in user provided storage. //===----------------------------------------------------------------------===// @@ -951,13 +1034,13 @@ class AsyncValueOwningRef { AsyncValueOwningRef(const AsyncValueOwningRef&) = delete; AsyncValueOwningRef& operator=(const AsyncValueOwningRef&) = delete; - AsyncValueOwningRef& operator=(AsyncValueOwningRef&& other) { + AsyncValueOwningRef& operator=(AsyncValueOwningRef&& other) noexcept { Destroy(); std::swap(value_, other.value_); return *this; } - AsyncValueOwningRef(AsyncValueOwningRef&& other) { + AsyncValueOwningRef(AsyncValueOwningRef&& other) noexcept { Destroy(); std::swap(value_, other.value_); } diff --git a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc index 2c4ce86933dbf9..0cb4aad9b3b588 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_ref_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "xla/tsl/concurrency/async_value_ref.h" #include +#include #include #include #include @@ -30,6 +31,7 @@ limitations under the License. #include "xla/tsl/concurrency/async_value.h" #include "xla/tsl/concurrency/ref_count.h" #include "tsl/platform/test.h" +#include "tsl/platform/test_benchmark.h" namespace tsl { @@ -418,6 +420,44 @@ struct DeferredExecutor : public AsyncValue::Executor { std::vector tasks; }; +TEST(AsyncValueRefTest, MakeAsyncValueRef) { + DeferredExecutor executor; + + { // Make AsyncValueRef from a function that returns a value. + AsyncValueRef ref = + MakeAsyncValueRef(executor, []() -> float { return 42.0f; }); + + EXPECT_FALSE(ref.IsAvailable()); + EXPECT_EQ(executor.Quiesce(), 1); + + EXPECT_TRUE(ref.IsAvailable()); + EXPECT_EQ(ref.get(), 42.0f); + } + + { // Make AsyncValueRef from a function that returns a StatusOr value. + AsyncValueRef ref = TryMakeAsyncValueRef( + executor, []() -> absl::StatusOr { return 42.0f; }); + + EXPECT_FALSE(ref.IsAvailable()); + EXPECT_EQ(executor.Quiesce(), 1); + + EXPECT_TRUE(ref.IsAvailable()); + EXPECT_EQ(ref.get(), 42.0f); + } + + { // Make AsyncValueRef from a function that returns a StatusOr error. + AsyncValueRef ref = TryMakeAsyncValueRef( + executor, + []() -> absl::StatusOr { return absl::InternalError("test"); }); + + EXPECT_FALSE(ref.IsAvailable()); + EXPECT_EQ(executor.Quiesce(), 1); + + EXPECT_TRUE(ref.IsError()); + EXPECT_EQ(ref.GetError(), absl::InternalError("test")); + } +} + TEST(AsyncValueRefTest, MapAvailableOnExecutor) { AsyncValueRef ref = MakeAvailableAsyncValueRef(42); @@ -519,6 +559,52 @@ TEST(AsyncValueRefTest, FlatMapAvailableOnExecutor) { EXPECT_EQ(fmapped_to_float.get(), 42.0f); } +TEST(AsyncValueRefTest, FlatMapDeferredAsyncValueOnExecutor) { + DeferredExecutor executor0; + DeferredExecutor executor1; + + // Use non-copyable std::unique_ptr to make sure that we don't + // accidentally copy the value into the FlatMap functor. + + { // Use a regular FlatMap. + AsyncValueRef fmapped_to_float = + MakeAsyncValueRef>(executor0, [] { + return std::make_unique(42); + }).FlatMap([&](AsyncValuePtr> ptr) { + return MakeAsyncValueRef( + executor1, [ref = ptr.CopyRef()] { return **ref; }); + }); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(executor0.Quiesce(), 1); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(executor1.Quiesce(), 1); + + EXPECT_TRUE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(fmapped_to_float.get(), 42.0f); + } + + { // Use a FlatMap that itself executed on given executor. + AsyncValueRef fmapped_to_float = + MakeAsyncValueRef>(executor0, [] { + return std::make_unique(42); + }).FlatMap(executor1, [&](AsyncValuePtr> ptr) { + return MakeAsyncValueRef( + executor1, [ref = ptr.CopyRef()] { return **ref; }); + }); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(executor0.Quiesce(), 1); + + EXPECT_FALSE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(executor1.Quiesce(), 2); + + EXPECT_TRUE(fmapped_to_float.IsAvailable()); + EXPECT_EQ(fmapped_to_float.get(), 42.0f); + } +} + TEST(AsyncValueRefTest, BlockUntilReady) { AsyncValueRef ref = MakeAvailableAsyncValueRef(42); BlockUntilReady(ref); @@ -787,4 +873,25 @@ TEST(AsyncValueRefTest, RecursiveOwnership) { EXPECT_EQ(counter, 1 + 2 + 3); } +//===----------------------------------------------------------------------===// +// Performance benchmarks below +//===----------------------------------------------------------------------===// + +template +static void BM_MakeConstructed(benchmark::State& state) { + for (auto _ : state) { + auto ref = MakeConstructedAsyncValueRef>(); + benchmark::DoNotOptimize(ref); + } +} + +BENCHMARK(BM_MakeConstructed<1>); +BENCHMARK(BM_MakeConstructed<4>); +BENCHMARK(BM_MakeConstructed<8>); +BENCHMARK(BM_MakeConstructed<16>); +BENCHMARK(BM_MakeConstructed<32>); +BENCHMARK(BM_MakeConstructed<64>); +BENCHMARK(BM_MakeConstructed<128>); +BENCHMARK(BM_MakeConstructed<256>); + } // namespace tsl diff --git a/third_party/xla/xla/tsl/concurrency/async_value_test.cc b/third_party/xla/xla/tsl/concurrency/async_value_test.cc index f03034d5c67517..eb14685f37903f 100644 --- a/third_party/xla/xla/tsl/concurrency/async_value_test.cc +++ b/third_party/xla/xla/tsl/concurrency/async_value_test.cc @@ -132,7 +132,7 @@ TEST(AsyncValueTest, KeepPayloadOnError) { EXPECT_TRUE(!value.IsError()); - value.SetError("error"); + value.SetError(absl::InternalError("error")); EXPECT_EQ(1, *value->value); EXPECT_TRUE(value.IsError()); diff --git a/third_party/xla/xla/tsl/concurrency/ref_count.h b/third_party/xla/xla/tsl/concurrency/ref_count.h index 664dd95c4c486c..4ea65eeaff7917 100644 --- a/third_party/xla/xla/tsl/concurrency/ref_count.h +++ b/third_party/xla/xla/tsl/concurrency/ref_count.h @@ -124,7 +124,7 @@ class RCReference { public: RCReference() : pointer_(nullptr) {} - RCReference(RCReference&& other) : pointer_(other.pointer_) { + RCReference(RCReference&& other) noexcept : pointer_(other.pointer_) { other.pointer_ = nullptr; } @@ -132,7 +132,7 @@ class RCReference { if (pointer_) pointer_->AddRef(); } - RCReference& operator=(RCReference&& other) { + RCReference& operator=(RCReference&& other) noexcept { reset(other.pointer_); other.pointer_ = nullptr; return *this; @@ -187,7 +187,7 @@ class RCReference { explicit operator bool() const { return pointer_ != nullptr; } - void swap(RCReference& other) { + void swap(RCReference& other) noexcept { using std::swap; swap(pointer_, other.pointer_); } @@ -256,7 +256,7 @@ RCReference MakeRef(Args&&... args) { } // For ADL style swap. template -void swap(RCReference& a, RCReference& b) { +void swap(RCReference& a, RCReference& b) noexcept { a.swap(b); } diff --git a/third_party/xla/xla/tsl/cuda/BUILD.bazel b/third_party/xla/xla/tsl/cuda/BUILD.bazel index dabb8f5f4b11df..704e0b9c50e5d4 100644 --- a/third_party/xla/xla/tsl/cuda/BUILD.bazel +++ b/third_party/xla/xla/tsl/cuda/BUILD.bazel @@ -10,6 +10,10 @@ load( "cuda_rpath_flags", "if_cuda_is_configured", ) +load( + "//xla/tsl:tsl.bzl", + "if_hermetic_cuda_libs", +) load("//xla/tsl/cuda:stub.bzl", "cuda_stub") package( @@ -22,7 +26,7 @@ cuda_stub( ) cc_library( - name = "cublas", # buildifier: disable=duplicated-name + name = "cublas_stub", srcs = if_cuda_is_configured([ "cublas_stub.cc", "cublas.tramp.S", @@ -44,13 +48,19 @@ cc_library( ]), ) +alias( + name = "cublas", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cublas//:cublas", ":cublas_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cublasLt", srcs = ["cublasLt.symbols"], ) cc_library( - name = "cublas_lt", + name = "cublas_lt_stub", srcs = if_cuda_is_configured([ "cublasLt_stub.cc", "cublasLt.tramp.S", @@ -68,6 +78,12 @@ cc_library( ]), ) +alias( + name = "cublas_lt", + actual = if_hermetic_cuda_libs("@cuda_cublas//:cublasLt", ":cublas_lt_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cuda", srcs = ["cuda.symbols"], @@ -98,7 +114,7 @@ cuda_stub( ) cc_library( - name = "cudart", # buildifier: disable=duplicated-name + name = "cudart_stub", srcs = select({ # include dynamic loading implementation only when if_cuda_is_configured and build dynamically "@local_xla//xla/tsl:is_cuda_enabled_and_oss": [ @@ -129,13 +145,19 @@ cc_library( }), ) +alias( + name = "cudart", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cudart//:cudart", ":cudart_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cudnn", srcs = ["cudnn.symbols"], ) cc_library( - name = "cudnn", # buildifier: disable=duplicated-name + name = "cudnn_stub", srcs = if_cuda_is_configured([ "cudnn_stub.cc", "cudnn.tramp.S", @@ -155,12 +177,24 @@ cc_library( ]), ) +alias( + name = "cudnn", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cudnn//:cudnn", ":cudnn_stub"), + visibility = ["//visibility:public"], +) + cc_library( - name = "nccl_rpath", + name = "nccl_rpath_flags", linkopts = if_cuda_is_configured(cuda_rpath_flags("nvidia/nccl/lib")), visibility = ["//visibility:public"], ) +alias( + name = "nccl_rpath", + actual = if_hermetic_cuda_libs("@cuda_nccl//:nccl", ":nccl_rpath_flags"), + visibility = ["//visibility:public"], +) + cc_library( name = "tensorrt_rpath", linkopts = if_cuda_is_configured(cuda_rpath_flags("tensorrt")), @@ -173,7 +207,7 @@ cuda_stub( ) cc_library( - name = "cufft", # buildifier: disable=duplicated-name + name = "cufft_stub", srcs = if_cuda_is_configured([ "cufft_stub.cc", "cufft.tramp.S", @@ -192,13 +226,19 @@ cc_library( ]), ) +alias( + name = "cufft", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cufft//:cufft", ":cufft_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cupti", srcs = ["cupti.symbols"], ) cc_library( - name = "cupti", # buildifier: disable=duplicated-name + name = "cupti_stub", srcs = if_cuda_is_configured([ "cupti_stub.cc", "cupti.tramp.S", @@ -219,13 +259,19 @@ cc_library( ]), ) +alias( + name = "cupti", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cupti//:cupti", ":cupti_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cusolver", srcs = ["cusolver.symbols"], ) cc_library( - name = "cusolver", # buildifier: disable=duplicated-name + name = "cusolver_stub", srcs = if_cuda_is_configured([ "cusolver_stub.cc", "cusolver.tramp.S", @@ -244,13 +290,19 @@ cc_library( ]), ) +alias( + name = "cusolver", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cusolver//:cusolver", ":cusolver_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "cusparse", srcs = ["cusparse.symbols"], ) cc_library( - name = "cusparse", # buildifier: disable=duplicated-name + name = "cusparse_stub", srcs = if_cuda_is_configured([ "cusparse_stub.cc", "cusparse.tramp.S", @@ -270,13 +322,19 @@ cc_library( ]), ) +alias( + name = "cusparse", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_cusparse//:cusparse", ":cusparse_stub"), + visibility = ["//visibility:public"], +) + cuda_stub( name = "nccl", srcs = ["nccl.symbols"], ) cc_library( - name = "nccl_stub", + name = "nccl", # buildifier: disable=duplicated-name srcs = if_cuda_is_configured([ "nccl_stub.cc", "nccl.tramp.S", @@ -296,3 +354,9 @@ cc_library( "@local_tsl//tsl/platform:load_library", ]), ) + +alias( + name = "nccl_stub", # buildifier: disable=duplicated-name + actual = if_hermetic_cuda_libs("@cuda_nccl//:nccl", ":nccl"), + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD index 4ebdd3e15904c9..30e3c5c32df348 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/BUILD @@ -13,11 +13,13 @@ package( cc_library( name = "coordination_service_error_util", + srcs = ["coordination_service_error_util.cc"], hdrs = ["coordination_service_error_util.h"], deps = [ "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:cord", + "@local_tsl//tsl/platform:regexp", "@local_tsl//tsl/protobuf:coordination_service_proto_cc", ], ) @@ -28,6 +30,7 @@ tsl_cc_test( deps = [ ":coordination_service_error_util", "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", @@ -75,6 +78,7 @@ tsl_gpu_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/functional:bind_front", "@com_google_absl//absl/hash", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -113,6 +117,7 @@ tsl_cc_test( ":coordination_service_impl", ":test_device_proto_cc", "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -120,7 +125,6 @@ tsl_cc_test( "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:random", "@local_tsl//tsl/platform:status", @@ -167,12 +171,12 @@ tsl_cc_test( ":coordination_client", ":coordination_service_agent", "//xla/tsl/distributed_runtime:call_options", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/time", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:status", @@ -217,13 +221,13 @@ tsl_cc_test( "//xla/tsl/distributed_runtime/rpc:async_service_interface", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_client", "//xla/tsl/distributed_runtime/rpc/coordination:grpc_coordination_service_impl", + "//xla/tsl/lib/core:status_test_util", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/synchronization", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:status", diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc index c70dde5e12d3b3..e73985c668ed16 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc @@ -30,6 +30,7 @@ limitations under the License. #include "absl/base/thread_annotations.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/functional/bind_front.h" #include "absl/hash/hash.h" #include "absl/log/check.h" #include "absl/log/log.h" @@ -150,7 +151,15 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { const DeviceInfo& ListClusterDevices() override ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); uint64_t GetServiceIncarnation() override; - void StartCheckStaleness(); // Checks both heartbeat and barrier timeouts. + // Checks if any task has stopped sending heartbeats. + void CheckHeartbeatTimeout(); + // Checks if any barrier has timed out. + void CheckBarrierTimeout(); + // Checks both heartbeat and barrier timeouts. Use a single function so they + // can be run in the same thread as threads are a constrained resource. + void CheckStaleness(); + // Starts a thread to check staleness. + void StartCheckStaleness(); void Stop(bool shut_staleness_thread = true); bool ServiceHasStopped() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // Report service error to a specified task. @@ -179,6 +188,9 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { CoordinatedTaskEqual> tasks_at_barrier; std::vector done_callbacks; + // Specifies the task that initiated the barrier (the first task to call the + // barrier). + CoordinatedTask initiating_task; }; void PassBarrier(std::string_view barrier_id, absl::Status result, BarrierState* barrier) @@ -243,10 +255,6 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { void Disconnect(uint64_t grace_period_duration_us); absl::Status RecordHeartbeat(uint64_t task_incarnation); int64_t TimeSinceLastHeartbeatMs(); - // This denotes the deadline after which we stop accepting heartbeats from a - // disconnected task. This grace period accounts for the lag time between - // the service recording the state change and the agent stopping heartbeats. - uint64_t GetDisconnectedGracePeriodMicros(); void SetError(absl::Status status); DeviceInfo GetDeviceInfo() { return devices_; } void CollectDeviceInfo(const DeviceInfo& devices) { devices_ = devices; } @@ -257,6 +265,11 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { absl::flat_hash_set GetOngoingBarriers(); void JoinBarrier(std::string_view barrier_id); void ExitBarrier(std::string_view barrier_id); + // Returns true if the task has been disconnected beyond the grace period + // and no further agent requests are expected. Note that the grace period + // accounts for the lag time between the service recording the state change + // and the agent stopping heartbeats/error polling. + bool IsDisconnectedBeyondGracePeriod(); private: // Incarnation ID for CPU:0 on remote task. @@ -266,9 +279,10 @@ class CoordinationServiceStandaloneImpl : public CoordinationServiceInterface { absl::Status status_; absl::Mutex last_heartbeat_mu_; uint64_t last_heartbeat_us_ ABSL_GUARDED_BY(last_heartbeat_mu_); - // This denotes the deadline after which we stop accepting heartbeats from a - // disconnected task. This grace period accounts for the lag time between - // the service recording the state change and the agent stopping heartbeats. + // This denotes the deadline after which we stop accepting heartbeats or + // error polling requests from a disconnected task. This grace period + // accounts for the lag time between the service recording the state change + // and the agent stopping heartbeats/error polling. uint64_t disconnect_grace_period_us_ = 0; DeviceInfo devices_; // For now, we assume there won't be many simultaneous barriers so we simply @@ -389,11 +403,6 @@ CoordinationServiceStandaloneImpl::TaskState::TimeSinceLastHeartbeatMs() { return (Env::Default()->NowMicros() - last_heartbeat_us_) / 1000; } -uint64_t CoordinationServiceStandaloneImpl::TaskState:: - GetDisconnectedGracePeriodMicros() { - return disconnect_grace_period_us_; -} - absl::flat_hash_set CoordinationServiceStandaloneImpl::TaskState::GetOngoingBarriers() { return ongoing_barriers_for_task_; @@ -409,6 +418,12 @@ void CoordinationServiceStandaloneImpl::TaskState::ExitBarrier( ongoing_barriers_for_task_.erase(barrier_id); } +bool CoordinationServiceStandaloneImpl::TaskState:: + IsDisconnectedBeyondGracePeriod() { + return GetState() == CoordinatedTaskState::TASKSTATE_DISCONNECTED && + Env::Default()->NowMicros() > disconnect_grace_period_us_; +} + void CoordinationServiceStandaloneImpl::SetDeviceAggregationFunction( std::function post_aggregate_device_fn) { @@ -441,119 +456,134 @@ CoordinationServiceStandaloneImpl::CoordinationServiceStandaloneImpl( StartCheckStaleness(); } -// Checks both heartbeat and barrier timeouts in the same thread, since threads -// are a constrained resource. -void CoordinationServiceStandaloneImpl::StartCheckStaleness() { - check_staleness_thread_.reset( - env_.StartThread({}, kHealthCheckThread, [this]() { - const bool has_service_to_client_connection = client_cache_ != nullptr; - // Used to store stale tasks and barriers. - std::vector stale_task_names; - absl::flat_hash_map expired_barriers; - while (true) { - { - absl::MutexLock l(&state_mu_); - check_staleness_thread_cv_.WaitWithTimeout(&state_mu_, - absl::Seconds(1)); - if (shutting_down_) { - return; - } - } - // Heartbeat check. - absl::Status status = absl::OkStatus(); - { - absl::MutexLock l(&state_mu_); - for (const auto& [task_name, task_state] : cluster_state_) { - // Skip tasks that are not registered or in error state - if (task_state->GetState() != - CoordinatedTaskState::TASKSTATE_CONNECTED) { - continue; - } - const bool is_stale = task_state->TimeSinceLastHeartbeatMs() > - heartbeat_timeout_ms_; - VLOG(10) << "Checking staleness for " << task_name - << " stale?=" << is_stale; - if (is_stale) { - stale_task_names.push_back(task_name); - status = MakeCoordinationError(absl::UnavailableError( - absl::StrCat("Task ", task_name, - " heartbeat timeout. This indicates that the " - "remote task has failed, got preempted, or " - "crashed unexpectedly. Check the task logs " - "for an earlier error to debug further."))); - SetTaskError(task_name, status); - } - } - } - // Propagate heartbeat timeout errors to other connected tasks. - if (!stale_task_names.empty()) { - if (!has_service_to_client_connection) { - absl::Status heartbeat_timeout_error = - MakeCoordinationError(absl::UnavailableError(absl::StrCat( - "The following tasks are unhealthy (stopped sending " - "heartbeats):\n", - absl::StrJoin(stale_task_names, "\n"), - "\nCheck the task logs for an earlier error to debug " - "further."))); - if (SendErrorPollingResponseOrStopService( - heartbeat_timeout_error)) { - return; - } - } else { - for (const auto& stale_task_name : stale_task_names) { - PropagateError(GetTaskFromName(stale_task_name)); - } - stale_task_names.clear(); - } - } - - // Barrier timeout check. - uint64_t current_time_micros = Env::Default()->NowMicros(); - { - absl::MutexLock l(&state_mu_); - // Gather barriers which have timed out. - for (std::string_view barrier_id : ongoing_barriers_) { - auto* barrier = &barriers_[barrier_id]; - if (current_time_micros > barrier->deadline_in_micros) { - expired_barriers[barrier_id] = barrier; - } - } - // Pass these barriers with the time out error. - for (const auto& [barrier_id, barrier] : expired_barriers) { - std::string pending_tasks; - int pending_task_count = 0; - for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { - if (!at_barrier) { - ++pending_task_count; - if (pending_task_count <= kPendingTaskLogLimit) { - absl::StrAppend(&pending_tasks, GetTaskName(task), "\n"); - } else { - break; - } - } - } - const absl::Status error = MakeCoordinationError( - absl::DeadlineExceededError(absl::StrCat( - "Barrier timed out. Barrier_id: ", barrier_id, - ". Timed out task names:\n", pending_tasks))); - PassBarrier(barrier_id, error, barrier); - } - } - if (!has_service_to_client_connection && - expired_barriers.contains(shutdown_barrier_id_)) { - // Error cannot be propagated through service-to-client connection. - // Note: we cannot destroy the thread within its own function. - // However, this thread will be destroyed once the function returns. - SendErrorPollingResponseOrStopService( - MakeCoordinationError(absl::DeadlineExceededError( - "Shutdown barrier timed out. Check the task logs for an " - "earlier error."))); - } +void CoordinationServiceStandaloneImpl::CheckHeartbeatTimeout() { + absl::Status status = absl::OkStatus(); + std::vector stale_task_names; + const bool has_service_to_client_connection = client_cache_ != nullptr; + { + absl::MutexLock l(&state_mu_); + for (const auto& [task_name, task_state] : cluster_state_) { + // Skip tasks that are not registered or in error state + if (task_state->GetState() != CoordinatedTaskState::TASKSTATE_CONNECTED) { + continue; + } + const bool is_stale = + task_state->TimeSinceLastHeartbeatMs() > heartbeat_timeout_ms_; + VLOG(10) << "Checking staleness for " << task_name + << " stale?=" << is_stale; + if (is_stale) { + stale_task_names.push_back(task_name); + status = MakeCoordinationError(absl::UnavailableError( + absl::StrCat("Task ", task_name, + " heartbeat timeout. This indicates that the " + "remote task has failed, got preempted, or " + "crashed unexpectedly. Check the task logs " + "for an earlier error to debug further."))); + SetTaskError(task_name, status); + } + } + } + // Propagate heartbeat timeout errors to other connected tasks. + if (!stale_task_names.empty()) { + if (!has_service_to_client_connection) { + absl::Status heartbeat_timeout_error = + MakeCoordinationError(absl::UnavailableError(absl::StrCat( + "The following tasks are unhealthy (stopped sending " + "heartbeats):\n", + absl::StrJoin(stale_task_names, "\n"), + "\nCheck the task logs for an earlier error to debug " + "further."))); + if (SendErrorPollingResponseOrStopService(heartbeat_timeout_error)) { + return; + } + } else { + for (const auto& stale_task_name : stale_task_names) { + PropagateError(GetTaskFromName(stale_task_name)); + } + } + } +} - // Reset this for the next barrier check. - expired_barriers.clear(); +void CoordinationServiceStandaloneImpl::CheckBarrierTimeout() { + const bool has_service_to_client_connection = client_cache_ != nullptr; + absl::flat_hash_map expired_barriers; + uint64_t current_time_micros = Env::Default()->NowMicros(); + { + absl::MutexLock l(&state_mu_); + // Gather barriers which have timed out. + for (std::string_view barrier_id : ongoing_barriers_) { + auto* barrier = &barriers_[barrier_id]; + if (current_time_micros > barrier->deadline_in_micros) { + expired_barriers[barrier_id] = barrier; + } + } + // Pass these barriers with the time out error. + for (const auto& [barrier_id, barrier] : expired_barriers) { + std::string pending_tasks; + int pending_task_count = 0; + for (const auto& [task, at_barrier] : barrier->tasks_at_barrier) { + if (at_barrier) { + continue; + } + ++pending_task_count; + if (pending_task_count > kPendingTaskLogLimit) { + break; } - })); + absl::StrAppend(&pending_tasks, GetTaskName(task), "\n"); + } + std::string error_message = absl::StrFormat( + "Barrier timed out. This usually happens because a task " + "triggered the barrier unexpectedly early, or some tasks are " + "too slow. Please look at the other task logs to debug " + "further. Barrier_id: %s. The first task at the barrier: " + "%s. ", + barrier_id, GetTaskName(barrier->initiating_task)); + if (pending_task_count > kPendingTaskLogLimit) { + absl::StrAppend( + &error_message, "Too many tasks have timed out. The first ", + kPendingTaskLogLimit, " timed out task names:\n", pending_tasks); + } else { + absl::StrAppend(&error_message, + "Total Number of tasks already at the barrier: ", + barrier->tasks_at_barrier.size() - pending_task_count, + "/", barrier->tasks_at_barrier.size(), + ". Timed out task names:\n%s", pending_tasks); + } + const absl::Status error = + MakeCoordinationError(absl::DeadlineExceededError(error_message)); + PassBarrier(barrier_id, error, barrier); + } + } + if (!has_service_to_client_connection && + expired_barriers.contains(shutdown_barrier_id_)) { + // Error cannot be propagated through service-to-client connection. + SendErrorPollingResponseOrStopService( + MakeCoordinationError(absl::DeadlineExceededError( + "Shutdown barrier timed out. Check the task logs for an " + "earlier error."))); + } +} + +void CoordinationServiceStandaloneImpl::CheckStaleness() { + // Used to store stale tasks and barriers. + while (true) { + { + absl::MutexLock l(&state_mu_); + check_staleness_thread_cv_.WaitWithTimeout(&state_mu_, absl::Seconds(1)); + if (shutting_down_) { + return; + } + } + CheckHeartbeatTimeout(); + CheckBarrierTimeout(); + } +} + +void CoordinationServiceStandaloneImpl::StartCheckStaleness() { + check_staleness_thread_.reset(env_.StartThread( + {}, kHealthCheckThread, + absl::bind_front(&CoordinationServiceStandaloneImpl::CheckStaleness, + this))); } void CoordinationServiceStandaloneImpl::Stop(bool shut_staleness_thread) { @@ -870,14 +900,10 @@ absl::Status CoordinationServiceStandaloneImpl::RecordHeartbeat( } if (!cluster_state_[task_name]->GetStatus().ok()) { return cluster_state_[task_name]->GetStatus(); - } else if (cluster_state_[task_name]->GetState() == - CoordinatedTaskState::TASKSTATE_DISCONNECTED && - // We accept heartbeats for a short grace period to account for - // the lag time between the service recording the state change - // and the agent stopping heartbeats. - Env::Default()->NowMicros() > - cluster_state_[task_name] - ->GetDisconnectedGracePeriodMicros()) { + } else if (cluster_state_[task_name]->IsDisconnectedBeyondGracePeriod()) { + // We accept heartbeats for a short grace period to account for the lag + // time between the service recording the state change and the agent + // stopping heartbeats. return MakeCoordinationError(absl::InvalidArgumentError(absl::StrCat( "Task with task_name=", task_name, " must be registered before sending heartbeat messages"))); @@ -1173,11 +1199,26 @@ void CoordinationServiceStandaloneImpl::PollForErrorAsync( return; } - if (cluster_state_[task_name]->GetState() != - CoordinatedTaskState::TASKSTATE_CONNECTED) { - done(MakeCoordinationError(absl::InvalidArgumentError( + // On the agent side, the error polling thread will only be started when the + // task is connected, but by the time the request is processed by the service, + // the task state may have changed due to actions by the service or the main + // thread on the agent. As a way to handle this, we accept error polling for a + // short grace period. After the grace period, the service will return an + // error to the task. + if (cluster_state_[task_name]->IsDisconnectedBeyondGracePeriod()) { + done(MakeCoordinationError(absl::FailedPreconditionError( absl::StrCat("Task (", task_name, - ") that has not been registered polling for errors.")))); + ") that has not been registered or has disconnected " + "polling for errors.")))); + return; + } + + if (cluster_state_[task_name]->GetState() == + CoordinatedTaskState::TASKSTATE_ERROR) { + done(MakeCoordinationError(absl::FailedPreconditionError(absl::StrCat( + "Task (", task_name, + ") that is already in error state polling for errors. Current error: ", + cluster_state_[task_name]->GetStatus().ToString())))); return; } @@ -1248,6 +1289,7 @@ void CoordinationServiceStandaloneImpl::BarrierAsync( if (inserted) { // Initialize barrier state. barrier->passed = false; + barrier->initiating_task = task; // Assume barrier is for entire cluster if no tasks are specified. if (participating_tasks.empty()) { for (const auto& task_state : cluster_state_) { @@ -1450,9 +1492,11 @@ void CoordinationServiceStandaloneImpl::SendErrorPollingResponse( return; } } - LOG(ERROR) << "An error is encountered. Sending the error as a response to " - "all error polling requests: " - << error; + if (!absl::IsCancelled(error)) { + VLOG(2) << "An error is encountered. Sending the error as a response to " + "all error polling requests: " + << error; + } std::vector missing_tasks; { absl::MutexLock l(&state_mu_); diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc index 8bcf451987dc81..617da59e7c61e6 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc @@ -412,21 +412,19 @@ void CoordinationServiceAgentImpl::StartSendingHeartbeats() { } void CoordinationServiceAgentImpl::StartPollingForError() { - LOG(INFO) << "Polling error from coordination service. This thread " - "will run until an error is encountered or the agent is " - "shutdown."; + LOG(INFO) << "Polling for error from coordination service. This thread will " + "run until an error is encountered or the agent is shutdown."; absl::Status status = PollForError(); CHECK(!status.ok()) << "PollForError returned OK status. Should " "always return an error."; if (absl::IsCancelled(status)) { - LOG(INFO) << "Stop polling error from coordination service because " - "the service or the agent is shutting down." - << status; + LOG(INFO) << "Cancelling error polling because the service or the agent is " + "shutting down."; + // Return early and there is no need to set error. return; } - LOG(INFO) << "Error returned from coordination service after polling: " - << status; - + LOG(ERROR) << "An error is returned from coordination service (this can be " + "an error from this or another task)."; SetError(status); } @@ -440,10 +438,6 @@ absl::Status CoordinationServiceAgentImpl::PollForError() { n.WaitForNotification(); CHECK(!status.ok()) << "PollForError returned OK status. Should always return an error."; - LOG(ERROR) - << "PollForError returned with status (this can be an error from this or " - "another task): " - << status; return status; } @@ -628,7 +622,7 @@ absl::Status CoordinationServiceAgentImpl::ShutdownInternal() { } else { LOG(ERROR) << "Failed to disconnect from coordination service with status: " - << status + << TrimCoordinationErrorMessage(status) << "\nProceeding with agent shutdown anyway. This is usually caused " "by an earlier error during execution. Check the logs (this task " "or the leader) for an earlier error to debug further."; @@ -893,11 +887,12 @@ void CoordinationServiceAgentImpl::SetError(const absl::Status& error) { assert(!error.ok()); absl::MutexLock l(&state_mu_); if (state_ == CoordinatedTaskState::TASKSTATE_ERROR) return; + absl::Status trimmed_error = TrimCoordinationErrorMessage(error); - LOG(ERROR) << "Coordination agent is set to ERROR: " << error; + LOG(ERROR) << "Coordination agent is set to ERROR: " << trimmed_error; state_ = CoordinatedTaskState::TASKSTATE_ERROR; - status_ = error; - error_fn_(error); + status_ = trimmed_error; + error_fn_(trimmed_error); } absl::Status CoordinationServiceAgentImpl::ActivateWatch( diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc index 6348054527fdb8..1281ea8f78988f 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent_test.cc @@ -29,7 +29,7 @@ limitations under the License. #include "absl/time/time.h" #include "xla/tsl/distributed_runtime/call_options.h" #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" @@ -454,6 +454,58 @@ TEST_F(CoordinationServiceAgentTest, ConnectAfterReset_WithErrorPolling) { EXPECT_TRUE(agent_->IsError()); } +TEST_F(CoordinationServiceAgentTest, CancelledPollForErrorRequest) { + // Connect coordination agent. + PollForErrorResponse mocked_response; + EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _)) + .WillOnce(DoAll(SetArgPointee<2>(mocked_response), + InvokeArgument<3>(absl::CancelledError("Test Error.")))); + + CoordinationServiceConfig config; + config.set_poll_for_error_from_service_at_startup(true); + InitializeAgent(config); + TF_ASSERT_OK(agent_->Connect()); + // Wait a bit for the error polling thread to start. + absl::SleepFor(absl::Seconds(2)); + // Cancelled error polling request will not set agent to error. + ASSERT_FALSE(agent_->IsError()); +} + +TEST_F(CoordinationServiceAgentTest, InvalidPollForErrorRequest) { + // Connect coordination agent. + PollForErrorResponse mocked_response; + EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _)) + .WillOnce( + DoAll(SetArgPointee<2>(mocked_response), + InvokeArgument<3>(absl::InvalidArgumentError("Test Error.")))); + + CoordinationServiceConfig config; + config.set_poll_for_error_from_service_at_startup(true); + InitializeAgent(config); + TF_ASSERT_OK(agent_->Connect()); + // Wait a bit for the error polling thread to start. + absl::SleepFor(absl::Seconds(2)); + ASSERT_TRUE(agent_->IsError()); +} + +TEST_F(CoordinationServiceAgentTest, + PollForErrorRequestWithFailedPrecondition) { + // Connect coordination agent. + PollForErrorResponse mocked_response; + EXPECT_CALL(*GetClient(), PollForErrorAsync(_, _, _, _)) + .WillOnce(DoAll( + SetArgPointee<2>(mocked_response), + InvokeArgument<3>(absl::FailedPreconditionError("Test Error.")))); + + CoordinationServiceConfig config; + config.set_poll_for_error_from_service_at_startup(true); + InitializeAgent(config); + TF_ASSERT_OK(agent_->Connect()); + // Wait a bit for the error polling thread to start. + absl::SleepFor(absl::Seconds(2)); + ASSERT_TRUE(agent_->IsError()); +} + TEST_F(CoordinationServiceAgentTest, ResetCanBeRetried) { // Mock reset error failing for the first time. EXPECT_CALL(*GetClient(), ResetTaskAsync(_, _, _)) diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc new file mode 100644 index 00000000000000..8fc7631b458197 --- /dev/null +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.cc @@ -0,0 +1,75 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" + +#include +#include + +#include "absl/status/status.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "tsl/platform/regexp.h" + +namespace tsl { +absl::Status TrimCoordinationErrorMessage(const absl::Status& s) { + if (s.ok()) { + return s; + } + auto status_message = std::string(s.message()); + auto additional_info_index = status_message.find("Additional GRPC"); + // This error didn't come from gRPC, so we don't need to trim it. + if (additional_info_index == std::string::npos) { + return s; + } + + std::optional payload = + s.GetPayload(CoordinationErrorPayloadKey()); + if (!payload.has_value() && absl::IsUnavailable(s)) { + // This error is not provided by us, so it's probably an RPC layer error. + auto prefix_message = + "Failed to send RPC to coordination service. Either the leader task " + "died/restarted unexpectedly or this task is experiencing network " + "issues. Check earlier logs from this task and the " + "leader (usually slice 0 process/task/worker 0) to debug further.\n"; + status_message = absl::StrCat( + prefix_message, + // Replace the duplicated error message at the start with the prefix. + status_message.substr(additional_info_index)); + } else { + // Extract RPC called. + std::string rpc_name; + // Note: it is unfortunate that we have to keep the tensorflow prefix + // because that's the RPC service proto namespace. + RE2::PartialMatch(status_message, + "(/tensorflow.CoordinationService/(\\w+))", &rpc_name); + // Erase duplicated error message. + status_message = status_message.substr(0, additional_info_index); + absl::StrAppend(&status_message, "\nRPC: ", rpc_name); + } + auto trimmed_status = absl::Status(s.code(), status_message); + // Reattach payload. + if (payload.has_value()) { + trimmed_status.SetPayload(CoordinationErrorPayloadKey(), *payload); + } +#if defined(PLATFORM_GOOGLE) + // Reattach source locations. + for (const auto& source_location : s.GetSourceLocations()) { + trimmed_status.AddSourceLocation(source_location); + } +#endif + return trimmed_status; +} +} // namespace tsl diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h index 4555a4e90e3a97..e1a3cdc06eefe9 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h @@ -55,6 +55,14 @@ inline absl::Status MakeCoordinationError( absl::Cord(payload.SerializeAsString())); return s; } + +// Trims the error message by replacing the `Additional GRPC error` part. +// Note: The duplicated error message is a quirk of the underlying gRPC code +// that we are using. Changing the shared code may hide important messages for +// other libraries, so we trim the error message for coordination service +// instead. See tsl/distributed_runtime/rpc/grpc_state.h for more details. +absl::Status TrimCoordinationErrorMessage(const absl::Status& s); + } // namespace tsl #endif // XLA_TSL_DISTRIBUTED_RUNTIME_COORDINATION_COORDINATION_SERVICE_ERROR_UTIL_H_ diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc index 3c19fa5759a207..535f471f0a3fc1 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_error_util_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/status/status.h" +#include "absl/strings/match.h" #include "tsl/platform/test.h" #include "tsl/protobuf/coordination_service.pb.h" namespace tsl { @@ -99,5 +100,54 @@ TEST(CoordinationServiceErrorUtil, MakeCoordinationErrorWithPayload) { EXPECT_EQ(actual_payload.is_reported_error(), payload.is_reported_error()); } +TEST(CoordinationServiceErrorUtil, + TrimCoordinationErrorMessage_CoordinationError) { + absl::Status error = MakeCoordinationError(absl::InternalError( + "Coordination service has stopped. RecordHeartbeat() from task: " + "/job:jax_worker/replica:0/task:2 failed. Additional GRPC error " + "information from remote target coordination_service while calling " + "/tensorflow.CoordinationService/Heartbeat::UNKNOWN:Error received from " + "peer " + "{file:'third_party/grpc/src/core/lib/surface/filter_stack_call.cc', " + "file_line:464, created_time:'2024-08-05T13:57:51.331198242-07:00', " + "grpc_status:13, grpc_message:'Coordination service has stopped. " + "RecordHeartbeat() from task: /job:jax_worker/replica:0/task:2 failed. " + "'} ")); + + absl::Status trimmed_error = TrimCoordinationErrorMessage(error); + EXPECT_EQ(trimmed_error.code(), error.code()); + EXPECT_EQ(trimmed_error.message(), + "Coordination service has stopped. RecordHeartbeat() from task: " + "/job:jax_worker/replica:0/task:2 failed. \nRPC: " + "/tensorflow.CoordinationService/Heartbeat"); + // Payload exists but has no value. + EXPECT_EQ(trimmed_error.GetPayload(CoordinationErrorPayloadKey()).value(), + ""); +} + +TEST(CoordinationServiceErrorUtil, TrimCoordinationErrorMessage_NetworkError) { + absl::Status error = absl::UnavailableError( + "failed to connect to all addresses; last error: UNKNOWN: " + "ipv4:127.0.0.1:10001: Failed to connect to remote host: Connection " + "refused. Additional GRPC error information from remote target " + "coordination_service while calling " + "/tensorflow.CoordinationService/Heartbeat::UNKNOWN:Error received from " + "peer " + "{file:'third_party/grpc/src/core/lib/surface/filter_stack_call.cc', " + "file_line:464, created_time:'2024-08-05T13:57:53.123562608-07:00', " + "grpc_status:14, grpc_message:'failed to connect to all addresses; last " + "error: UNKNOWN: ipv4:127.0.0.1:10001: Failed to connect to remote host: " + "Connection refused'} "); + + absl::Status trimmed_error = TrimCoordinationErrorMessage(error); + auto message = trimmed_error.message(); + EXPECT_EQ(trimmed_error.code(), error.code()); + EXPECT_TRUE(absl::StrContains(message, "Check earlier logs")); + // Message is not duplicated. + EXPECT_EQ(message.find("failed to connect"), + message.rfind("failed to connect")) + << trimmed_error; +} + } // namespace } // namespace tsl diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc index da40248891f372..3ec3290c9507e1 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_recoverable_job_test.cc @@ -31,7 +31,7 @@ limitations under the License. #include "xla/tsl/distributed_runtime/rpc/async_service_interface.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.h" #include "xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_service_impl.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/status.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc index 6133d19ef72380..2fa500109acc17 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/coordination/coordination_service_test.cc @@ -36,7 +36,7 @@ limitations under the License. #include "xla/tsl/distributed_runtime/coordination/coordination_client.h" #include "xla/tsl/distributed_runtime/coordination/coordination_service_error_util.h" #include "xla/tsl/distributed_runtime/coordination/test_device.pb.h" -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/env.h" #include "tsl/platform/random.h" #include "tsl/platform/status.h" @@ -1169,9 +1169,15 @@ TEST_F(CoordinationBarrierTest, BarrierByNonClusterTask) { TEST_F(CoordinationBarrierTest, BarrierTimeout) { const std::string barrier_id = "barrier_id"; absl::Duration timeout = absl::Seconds(1); - absl::Status barrier_status_0; - absl::Notification n_0; + absl::Status barrier_status_0, barrier_status_1; + absl::Notification n_0, n_1; + GetCoordinationService()->BarrierAsync( + barrier_id, timeout, GetTask(1), + /*participating_tasks=*/{}, [&barrier_status_1, &n_1](absl::Status s) { + barrier_status_1 = s; + n_1.Notify(); + }); GetCoordinationService()->BarrierAsync( barrier_id, timeout, GetTask(0), /*participating_tasks=*/{}, [&barrier_status_0, &n_0](absl::Status s) { @@ -1181,13 +1187,21 @@ TEST_F(CoordinationBarrierTest, BarrierTimeout) { // Block until user-specified timeout. n_0.WaitForNotification(); + n_1.WaitForNotification(); + + // All barrier calls should fail with the same error. + EXPECT_EQ(barrier_status_0, barrier_status_1); EXPECT_TRUE(absl::IsDeadlineExceeded(barrier_status_0)); EXPECT_FALSE( absl::StrContains(barrier_status_0.message(), GetTaskName(GetTask(0)))); EXPECT_TRUE( - absl::StrContains(barrier_status_0.message(), GetTaskName(GetTask(1)))); - EXPECT_TRUE( - absl::StrContains(barrier_status_0.message(), GetTaskName(GetTask(2)))); + absl::StrContains(barrier_status_0.message(), + GetTaskName(GetTask(1)))); // First task at barrier. + EXPECT_TRUE(absl::StrContains(barrier_status_0.message(), + GetTaskName(GetTask(2)))); // Timed-out task. + EXPECT_TRUE(absl::StrContains( + barrier_status_0.message(), + "2/3")); // Number of tasks at barrier / total number of tasks. } TEST_F(CoordinationBarrierTest, BarrierReturnsPreviousError) { @@ -1809,10 +1823,77 @@ TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfTaskNotRegistered) { coord_service_->PollForErrorAsync( task_0_, [&](const absl::Status& status) { s = status; }); - EXPECT_THAT(s, StatusIs(absl::StatusCode::kInvalidArgument, + EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition, HasSubstr("has not been registered"))); } +TEST_F(CoordinateTwoTasksTest, + AllowPollForErrorWithinGracePeriodIfTaskHasShutDown) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + absl::Status s; + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + coord_service_->ShutdownTaskAsync(task_0_, + [&](const absl::Status& status) {}); + coord_service_->ShutdownTaskAsync(task_1_, + [&](const absl::Status& status) {}); + + coord_service_->PollForErrorAsync( + task_0_, [&](const absl::Status& status) { s = status; }); + // Stop the service. + coord_service_.reset(); + // The error polling request will still proceed because of grace period. It + // will be cancelled. + EXPECT_THAT(s, StatusIs(absl::StatusCode::kCancelled)); +} + +TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfTaskHasShutDown) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + absl::Status s; + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->RegisterTask(task_1_, incarnation_1_)); + coord_service_->ShutdownTaskAsync(task_0_, + [&](const absl::Status& status) {}); + coord_service_->ShutdownTaskAsync(task_1_, + [&](const absl::Status& status) {}); + + // Sleep past the grace period. + Env::Default()->SleepForMicroseconds( + absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); + coord_service_->PollForErrorAsync( + task_0_, [&](const absl::Status& status) { s = status; }); + EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("has disconnected"))); +} + +TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorAfterReset) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + absl::Status s; + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->ResetTask(task_0_)); + + // Sleep past the grace period. + Env::Default()->SleepForMicroseconds( + absl::ToInt64Microseconds(2 * kHeartbeatTimeout)); + coord_service_->PollForErrorAsync( + task_0_, [&](const absl::Status& status) { s = status; }); + EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("has disconnected"))); +} + +TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorWhenInErrorState) { + EnableCoordinationService(/*has_service_to_client_connection=*/false); + absl::Status s; + ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); + ASSERT_OK(coord_service_->ReportTaskError(task_0_, + absl::InternalError("test_error"))); + + coord_service_->PollForErrorAsync( + task_0_, [&](const absl::Status& status) { s = status; }); + EXPECT_THAT(s, StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("test_error"))); +} + TEST_F(CoordinateTwoTasksTest, DoNotAllowPollForErrorIfServiceHasStopped) { EnableCoordinationService(/*has_service_to_client_connection=*/false); ASSERT_OK(coord_service_->RegisterTask(task_0_, incarnation_0_)); diff --git a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc index 45a7ddb1e8ff20..00845e5001b7ff 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc @@ -170,11 +170,7 @@ absl::Status PreemptionSyncManagerImpl::Initialize( call_opts_ = agent_->GetKeyValueAsync( kPreemptionNoticeKey, [this, agent = agent_](absl::StatusOr status_or_death_time) { - if (absl::IsCancelled(status_or_death_time.status()) || - // TODO(b/349613356): Investigate if we can always ensure that - // the RPC is cancelled before the server goes away, so we can - // differentiate between network failure and shutdown behaviour. - absl::IsUnavailable(status_or_death_time.status())) { + if (absl::IsCancelled(status_or_death_time.status())) { // The agent cancels pending GetKeyValue RPCs because of shutdown, // so simply log and return. LOG(INFO) << "Cancelled call to retrieve preemption notice. This is " diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD index dd3dcd24af4d68..6fe7b4064235f8 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/BUILD @@ -108,8 +108,8 @@ tsl_cc_test( ], deps = [ ":grpc_channel", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/util:device_name_utils", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:env_impl", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc index 4afe13f2c7960d..6bd7885d2cafb7 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/coordination/grpc_coordination_client.cc @@ -360,9 +360,7 @@ CoordinationClientCache* NewGrpcCoordinationClientCache( CoordinationClient* NewGrpcCoordinationClient( std::shared_ptr<::grpc::Channel> channel) { - // TODO(hanyangtay): Pass in the logical task name for better logging. - return new GrpcCoordinationClient( - channel, /*target=*/"unknown_target_for_coordination_leader"); + return new GrpcCoordinationClient(channel, /*target=*/"coordination_service"); } } // namespace tsl diff --git a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc index 806ea5494d90c8..80c976640fa6f1 100644 --- a/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc +++ b/third_party/xla/xla/tsl/distributed_runtime/rpc/grpc_channel_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" #include "tsl/protobuf/rpc_options.pb.h" diff --git a/third_party/xla/xla/tsl/framework/BUILD b/third_party/xla/xla/tsl/framework/BUILD index 4eeef3b0ad9900..52faa0be9359cf 100644 --- a/third_party/xla/xla/tsl/framework/BUILD +++ b/third_party/xla/xla/tsl/framework/BUILD @@ -16,7 +16,7 @@ load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) -load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.bzl", "if_windows", "internal_visibility") load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( @@ -194,6 +194,7 @@ cc_library( ":allocator", ":metrics", ":shared_counter", + "//xla/tsl/protobuf:bfc_memory_map_proto_cc", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -210,7 +211,6 @@ cc_library( "@local_tsl//tsl/platform:types", "@local_tsl//tsl/profiler/lib:scoped_memory_debug_annotation", "@local_tsl//tsl/profiler/lib:traceme", - "@local_tsl//tsl/protobuf:bfc_memory_map_proto_cc", ], ) @@ -358,7 +358,10 @@ cc_library( hdrs = [ "cancellation.h", ], - copts = ["-Wno-thread-safety-precise"], + copts = if_windows( + [], + ["-Wno-thread-safety-precise"], + ), visibility = ["//visibility:public"], deps = [ "@com_google_absl//absl/memory", @@ -462,8 +465,8 @@ tsl_cc_test( deps = [ ":device_id_impl", ":device_id_utils", + "//xla/tsl/lib/core:status_test_util", "//xla/tsl/util:device_name_utils", - "@local_tsl//tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test_main", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", diff --git a/third_party/xla/xla/tsl/framework/allocator.h b/third_party/xla/xla/tsl/framework/allocator.h index 29db454ec871b7..c289532c78a75e 100644 --- a/third_party/xla/xla/tsl/framework/allocator.h +++ b/third_party/xla/xla/tsl/framework/allocator.h @@ -146,6 +146,13 @@ class Allocator { // REQUIRES: "ptr" was previously returned by a call to AllocateRaw virtual void DeallocateRaw(void* ptr) = 0; + virtual void DeallocateRaw(void* ptr, size_t alignment, size_t num_bytes) { + (void)alignment; + (void)num_bytes; + + DeallocateRaw(ptr); + } + // Returns true if this allocator tracks the sizes of allocations. // RequestedSize and AllocatedSize must be overridden if // TracksAllocationSizes is overridden to return true. diff --git a/third_party/xla/xla/tsl/framework/bfc_allocator.cc b/third_party/xla/xla/tsl/framework/bfc_allocator.cc index c2d8f8b121f64f..a5f3401bbe86d4 100644 --- a/third_party/xla/xla/tsl/framework/bfc_allocator.cc +++ b/third_party/xla/xla/tsl/framework/bfc_allocator.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/tsl/framework/allocator_retry.h" +#include "xla/tsl/protobuf/bfc_memory_map.pb.h" #include "tsl/lib/core/bits.h" #include "tsl/platform/file_system.h" #include "tsl/platform/logging.h" @@ -33,7 +34,6 @@ limitations under the License. #include "tsl/platform/types.h" #include "tsl/profiler/lib/scoped_memory_debug_annotation.h" #include "tsl/profiler/lib/traceme.h" -#include "tsl/protobuf/bfc_memory_map.pb.h" namespace tsl { diff --git a/third_party/xla/xla/tsl/framework/cancellation.cc b/third_party/xla/xla/tsl/framework/cancellation.cc index d0a841f6ed59c5..7802eb926de59d 100644 --- a/third_party/xla/xla/tsl/framework/cancellation.cc +++ b/third_party/xla/xla/tsl/framework/cancellation.cc @@ -103,7 +103,7 @@ bool CancellationManager::RegisterCallback(CancellationToken token, bool CancellationManager::RegisterCallbackWithErrorLogging( CancellationToken token, CancelCallback callback, - tsl::StringPiece callback_name) { + absl::string_view callback_name) { return RegisterCallbackConfig( token, CallbackConfiguration{callback, std::string(callback_name), true}); } diff --git a/third_party/xla/xla/tsl/framework/cancellation.h b/third_party/xla/xla/tsl/framework/cancellation.h index 56076c82270a51..38f7ebf60a63b2 100644 --- a/third_party/xla/xla/tsl/framework/cancellation.h +++ b/third_party/xla/xla/tsl/framework/cancellation.h @@ -135,7 +135,7 @@ class CancellationManager { // callback, which will be displayed on the log. bool RegisterCallbackWithErrorLogging(CancellationToken token, CancelCallback callback, - tsl::StringPiece callback_name); + absl::string_view callback_name); // Deregister the callback that, when registered, was associated // with the given cancellation token. Returns true iff the callback diff --git a/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc b/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc index a9cbf0c4650ac6..9c9de966cfb67d 100644 --- a/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc +++ b/third_party/xla/xla/tsl/framework/cpu_allocator_impl.cc @@ -121,6 +121,17 @@ class CPUAllocator : public Allocator { port::AlignedFree(ptr); } + void DeallocateRaw(void* ptr, size_t alignment, size_t num_bytes) override { + if (cpu_allocator_collect_stats) { + const std::size_t alloc_size = + port::MallocExtension_GetAllocatedSize(ptr); + mutex_lock l(mu_); + stats_.bytes_in_use -= alloc_size; + AddTraceMe("MemoryDeallocation", ptr, 0, alloc_size); + } + port::AlignedSizedFree(ptr, alignment, num_bytes); + } + void AddTraceMe(absl::string_view traceme_name, const void* chunk_ptr, std::size_t req_bytes, std::size_t alloc_bytes) { tsl::profiler::TraceMe::InstantActivity( diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.cc b/third_party/xla/xla/tsl/framework/device_id_utils.cc index a751a3a9bad6c1..812b119c0201da 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils.cc +++ b/third_party/xla/xla/tsl/framework/device_id_utils.cc @@ -29,12 +29,6 @@ limitations under the License. #include "tsl/platform/str_util.h" namespace tsl { -namespace { -int GetTfDeviceIdFromDeviceParsedName( - const DeviceNameUtils::ParsedName& device_name) { - return device_name.id; -} -} // namespace void CheckValidTfDeviceId(const DeviceType& type, const int visible_device_count, @@ -62,7 +56,7 @@ absl::Status ParseVisibleDeviceList( std::iota(visible_device_order->begin(), visible_device_order->end(), 0); } else { const std::vector order_str = - tsl::str_util::Split(visible_device_list, ','); + tsl::str_util::Split(visible_device_list, ','); // non-absl ok for (const std::string& platform_device_id_str : order_str) { int32_t platform_device_id; if (!absl::SimpleAtoi(platform_device_id_str, &platform_device_id)) { @@ -126,7 +120,7 @@ absl::StatusOr GetNumberTfDevicesAndConfigurePlatformDeviceId( absl::StatusOr GetPlatformDeviceIdFromDeviceParsedName( const DeviceNameUtils::ParsedName& device_name, const DeviceType& device_type) { - const TfDeviceId tf_device_id(GetTfDeviceIdFromDeviceParsedName(device_name)); + const TfDeviceId tf_device_id(GetDeviceIdFromDeviceParsedName(device_name)); PlatformDeviceId platform_device_id; absl::Status platform_id_status = DeviceIdManager::TfToPlatformDeviceId( device_type, tf_device_id, &platform_device_id); @@ -136,15 +130,10 @@ absl::StatusOr GetPlatformDeviceIdFromDeviceParsedName( return platform_id_status; } -absl::StatusOr GetDeviceIdFromDeviceParsedName( - const DeviceNameUtils::ParsedName& device_name, - const DeviceType& device_type) { - auto platform_id = - GetPlatformDeviceIdFromDeviceParsedName(device_name, device_type); - if (platform_id.ok()) { - return *platform_id; - } - return GetTfDeviceIdFromDeviceParsedName(device_name); +int GetDeviceIdFromDeviceParsedName( + const DeviceNameUtils::ParsedName& device_name) { + // This assumes that TF device ID is the same as PJRT local device ID. + return device_name.id; } } // namespace tsl diff --git a/third_party/xla/xla/tsl/framework/device_id_utils.h b/third_party/xla/xla/tsl/framework/device_id_utils.h index d25ae1cb7f3ef3..0da5969a189531 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils.h +++ b/third_party/xla/xla/tsl/framework/device_id_utils.h @@ -60,12 +60,9 @@ absl::StatusOr GetPlatformDeviceIdFromDeviceParsedName( const DeviceNameUtils::ParsedName& device_name, const DeviceType& device_type); -// TODO(b/293324740): support virtual devices. -// Returns the corresponding PlatformDeviceId if it is found. Otherwise returns -// the id in device_name. -absl::StatusOr GetDeviceIdFromDeviceParsedName( - const DeviceNameUtils::ParsedName& device_name, - const DeviceType& device_type); +// Returns the id in device_name. +int GetDeviceIdFromDeviceParsedName( + const DeviceNameUtils::ParsedName& device_name); } // namespace tsl diff --git a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc index 245097e01f80d4..e230d85a61cf51 100644 --- a/third_party/xla/xla/tsl/framework/device_id_utils_test.cc +++ b/third_party/xla/xla/tsl/framework/device_id_utils_test.cc @@ -18,8 +18,8 @@ limitations under the License. #include #include "xla/tsl/framework/device_id_manager.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/util/device_name_utils.h" -#include "tsl/lib/core/status_test_util.h" #include "tsl/platform/status_matchers.h" namespace tsl { @@ -182,11 +182,7 @@ TEST(DeviceIdUtilsTest, GetDeviceIdWithPlatformDeviceId) { DeviceNameUtils::ParsedName device_name; device_name.id = 0; - TF_ASSERT_OK_AND_ASSIGN(int device_id, - GetDeviceIdFromDeviceParsedName( - device_name, DeviceType(kTestDeviceType))); - - EXPECT_EQ(device_id, 1); + EXPECT_EQ(GetDeviceIdFromDeviceParsedName(device_name), 0); DeviceIdManager::TestOnlyReset(); } @@ -194,11 +190,7 @@ TEST(DeviceIdUtilsTest, GetDeviceIdWithoutPlatformDeviceId) { DeviceNameUtils::ParsedName device_name; device_name.id = 0; - TF_ASSERT_OK_AND_ASSIGN(int device_id, - GetDeviceIdFromDeviceParsedName( - device_name, DeviceType(kTestDeviceType))); - - EXPECT_EQ(device_id, 0); + EXPECT_EQ(GetDeviceIdFromDeviceParsedName(device_name), 0); } } // namespace diff --git a/third_party/xla/xla/tsl/lib/core/BUILD b/third_party/xla/xla/tsl/lib/core/BUILD new file mode 100644 index 00000000000000..49685b9b252f2b --- /dev/null +++ b/third_party/xla/xla/tsl/lib/core/BUILD @@ -0,0 +1,41 @@ +# Description: +# Tensor Standard Libraries. +# +# The libraries in this package are not allowed to have ANY dependencies +# to other TF components outside of TSL. + +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable") + +# TODO(rdzhabarov): Tighten visibility after migration is complete. +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +filegroup( + name = "legacy_lib_core_status_test_util_header", + srcs = [ + "status_test_util.h", + ], + compatible_with = get_compatible_with_portable(), + visibility = internal_visibility([ + "//tensorflow/core:__pkg__", + "@local_tsl//tsl/lib/core:__pkg__", + "//tensorflow/core/lib/core:__pkg__", + ]), +) + +cc_library( + name = "status_test_util", + testonly = 1, + hdrs = ["status_test_util.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@local_tsl//tsl/platform:status_matchers", + "@local_tsl//tsl/platform:test", + ], +) diff --git a/third_party/xla/third_party/tsl/tsl/lib/core/status_test_util.h b/third_party/xla/xla/tsl/lib/core/status_test_util.h similarity index 89% rename from third_party/xla/third_party/tsl/tsl/lib/core/status_test_util.h rename to third_party/xla/xla/tsl/lib/core/status_test_util.h index 56644ba71773c4..0c8f5d9d50e4ea 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/core/status_test_util.h +++ b/third_party/xla/xla/tsl/lib/core/status_test_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ -#define TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ +#ifndef XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ +#define XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ #include "tsl/platform/status_matchers.h" #include "tsl/platform/test.h" @@ -30,4 +30,4 @@ limitations under the License. // status matchers: // EXPECT_THAT(s, tensorflow::testing::StatusIs(status.code(), "message")); -#endif // TENSORFLOW_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ +#endif // XLA_TSL_LIB_CORE_STATUS_TEST_UTIL_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD b/third_party/xla/xla/tsl/lib/histogram/BUILD similarity index 62% rename from third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD rename to third_party/xla/xla/tsl/lib/histogram/BUILD index 4de34f8e390755..cbd206f6bd8083 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/BUILD +++ b/third_party/xla/xla/tsl/lib/histogram/BUILD @@ -1,13 +1,13 @@ load( - "@local_tsl//tsl/platform:rules_cc.bzl", - "cc_library", + "@local_tsl//tsl/platform:build_config.bzl", + "tsl_cc_test", ) -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") load( - "//tsl/platform:build_config.bzl", - "tsl_cc_test", + "@local_tsl//tsl/platform:rules_cc.bzl", + "cc_library", ) +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -20,12 +20,12 @@ cc_library( hdrs = ["histogram.h"], visibility = ["//visibility:public"], deps = [ - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:mutex", - "//tsl/platform:thread_annotations", - "//tsl/platform:types", - "//tsl/protobuf:histogram_proto_cc", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:mutex", + "@local_tsl//tsl/platform:thread_annotations", + "@local_tsl//tsl/platform:types", + "@local_tsl//tsl/protobuf:histogram_proto_cc", ], alwayslink = True, ) @@ -55,9 +55,9 @@ tsl_cc_test( ], deps = [ ":histogram", - "//tsl/platform:logging", - "//tsl/platform:test", - "//tsl/platform:test_main", - "//tsl/protobuf:histogram_proto_cc", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:test", + "@local_tsl//tsl/platform:test_main", + "@local_tsl//tsl/protobuf:histogram_proto_cc", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.cc b/third_party/xla/xla/tsl/lib/histogram/histogram.cc similarity index 99% rename from third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.cc rename to third_party/xla/xla/tsl/lib/histogram/histogram.cc index d6dc8aa4a5ab20..e8203549272547 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/histogram/histogram.h" +#include "xla/tsl/lib/histogram/histogram.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.h b/third_party/xla/xla/tsl/lib/histogram/histogram.h similarity index 95% rename from third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.h rename to third_party/xla/xla/tsl/lib/histogram/histogram.h index a024e2275b4d29..64b0cd188e7222 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram.h +++ b/third_party/xla/xla/tsl/lib/histogram/histogram.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ -#define TENSORFLOW_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ +#ifndef XLA_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ +#define XLA_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ #include #include @@ -121,7 +121,7 @@ class ThreadSafeHistogram { void Clear(); - // TODO(touts): It might be a good idea to provide a AddN() + // TODO(mdevin): It might be a good idea to provide a AddN() // method to avoid grabbing/releasing the lock when adding many values. void Add(double value); @@ -140,4 +140,4 @@ class ThreadSafeHistogram { } // namespace histogram } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ +#endif // XLA_TSL_LIB_HISTOGRAM_HISTOGRAM_H_ diff --git a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram_test.cc b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/histogram/histogram_test.cc rename to third_party/xla/xla/tsl/lib/histogram/histogram_test.cc index cda166f943d208..4051d98f49ab97 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/histogram/histogram_test.cc +++ b/third_party/xla/xla/tsl/lib/histogram/histogram_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/histogram/histogram.h" +#include "xla/tsl/lib/histogram/histogram.h" #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD b/third_party/xla/xla/tsl/lib/strings/BUILD similarity index 79% rename from third_party/xla/third_party/tsl/tsl/lib/strings/BUILD rename to third_party/xla/xla/tsl/lib/strings/BUILD index 699965e401c526..03f82a366f78c6 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/strings/BUILD +++ b/third_party/xla/xla/tsl/lib/strings/BUILD @@ -2,8 +2,8 @@ load( "@local_tsl//tsl/platform:rules_cc.bzl", "cc_library", ) -load("@local_xla//xla/tsl:tsl.bzl", "internal_visibility") -load("@local_xla//xla/tsl:tsl.default.bzl", "filegroup") +load("//xla/tsl:tsl.bzl", "internal_visibility") +load("//xla/tsl:tsl.default.bzl", "filegroup") # copybara:uncomment package(default_applicable_licenses = ["//tensorflow:license"]) @@ -13,13 +13,13 @@ cc_library( hdrs = ["proto_serialization.h"], visibility = ["//visibility:public"], deps = [ - "//tsl/lib/gtl:inlined_vector", - "//tsl/platform:hash", - "//tsl/platform:logging", - "//tsl/platform:macros", - "//tsl/platform:protobuf", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@local_tsl//tsl/lib/gtl:inlined_vector", + "@local_tsl//tsl/platform:hash", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:macros", + "@local_tsl//tsl/platform:protobuf", ], ) diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc similarity index 98% rename from third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc rename to third_party/xla/xla/tsl/lib/strings/proto_serialization.cc index 139849e306a8b7..06ef0747ee553d 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.cc +++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tsl/lib/strings/proto_serialization.h" +#include "xla/tsl/lib/strings/proto_serialization.h" #include #include diff --git a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h b/third_party/xla/xla/tsl/lib/strings/proto_serialization.h similarity index 92% rename from third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h rename to third_party/xla/xla/tsl/lib/strings/proto_serialization.h index 96a5c55f647694..b79e9aff6c21df 100644 --- a/third_party/xla/third_party/tsl/tsl/lib/strings/proto_serialization.h +++ b/third_party/xla/xla/tsl/lib/strings/proto_serialization.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ -#define TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ +#ifndef XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ +#define XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ #include "tsl/platform/protobuf.h" @@ -45,4 +45,4 @@ uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto, } // namespace tsl -#endif // TENSORFLOW_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ +#endif // XLA_TSL_LIB_STRINGS_PROTO_SERIALIZATION_H_ diff --git a/third_party/xla/xla/tsl/protobuf/BUILD b/third_party/xla/xla/tsl/protobuf/BUILD new file mode 100644 index 00000000000000..1a6ce0e4277571 --- /dev/null +++ b/third_party/xla/xla/tsl/protobuf/BUILD @@ -0,0 +1,27 @@ +load( + "@local_tsl//tsl/platform:build_config.bzl", + "tf_proto_library", +) +load( + "//xla/tsl:tsl.bzl", + "if_google", + "internal_visibility", +) + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = internal_visibility([ + "//tensorflow/core:__subpackages__", + "//xla/tsl:internal", + "//tensorflow_models:__subpackages__", + ]), + features = if_google(["-parse_headers"]), + licenses = ["notice"], +) + +tf_proto_library( + name = "bfc_memory_map_proto", + srcs = ["bfc_memory_map.proto"], + make_default_target_header_only = True, + visibility = ["//visibility:public"], +) diff --git a/third_party/xla/third_party/tsl/tsl/protobuf/bfc_memory_map.proto b/third_party/xla/xla/tsl/protobuf/bfc_memory_map.proto similarity index 100% rename from third_party/xla/third_party/tsl/tsl/protobuf/bfc_memory_map.proto rename to third_party/xla/xla/tsl/protobuf/bfc_memory_map.proto diff --git a/third_party/xla/xla/tsl/tsl.bzl b/third_party/xla/xla/tsl/tsl.bzl index 33571902eb052a..2882b9b96861f1 100644 --- a/third_party/xla/xla/tsl/tsl.bzl +++ b/third_party/xla/xla/tsl/tsl.bzl @@ -221,6 +221,17 @@ def if_with_tpu_support(if_true, if_false = []): "//conditions:default": if_false, }) +# These configs are used to determine whether we should use the hermetic CUDA +# tools in cc_libraries. +# They are intended for the OSS builds only. +def if_hermetic_cuda_tools(if_true, if_false = []): # buildifier: disable=unused-variable + """Shorthand for select()'ing on whether we're building with hermetic CUDA tools.""" + return select({"@local_config_cuda//cuda:hermetic_cuda_tools": if_true, "//conditions:default": if_false}) # copybara:comment_replace return if_false + +def if_hermetic_cuda_libs(if_true, if_false = []): # buildifier: disable=unused-variable + """Shorthand for select()'ing on whether we need to include hermetic CUDA libraries.""" + return select({"@local_config_cuda//cuda:hermetic_cuda_tools_and_libs": if_true, "//conditions:default": if_false}) # copybara:comment_replace return if_false + def get_win_copts(is_external = False): WINDOWS_COPTS = [ # copybara:uncomment_begin(no MSVC flags in google) diff --git a/third_party/xla/xla/tsl/util/BUILD b/third_party/xla/xla/tsl/util/BUILD index fd0d1dc9412c4e..8ed6f8a05d6d80 100644 --- a/third_party/xla/xla/tsl/util/BUILD +++ b/third_party/xla/xla/tsl/util/BUILD @@ -273,7 +273,7 @@ tsl_cc_test( srcs = ["device_name_utils_test.cc"], deps = [ ":device_name_utils", - "@local_tsl//tsl/lib/core:status_test_util", + "//xla/tsl/lib/core:status_test_util", "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:strcat", "@local_tsl//tsl/platform:test", diff --git a/third_party/xla/xla/tsl/util/device_name_utils_test.cc b/third_party/xla/xla/tsl/util/device_name_utils_test.cc index 1f5f5114550d40..1457297599d74f 100644 --- a/third_party/xla/xla/tsl/util/device_name_utils_test.cc +++ b/third_party/xla/xla/tsl/util/device_name_utils_test.cc @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tsl/lib/core/status_test_util.h" +#include "xla/tsl/lib/core/status_test_util.h" #include "tsl/platform/errors.h" #include "tsl/platform/strcat.h" #include "tsl/platform/test.h" diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 35e795b1680df8..64d092261b645e 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -28,7 +28,100 @@ message CompilationEnvironmentsProto { // Debugging options for XLA. These options may change at any time - there are // no guarantees about backward or forward compatibility for these fields. +// +// Debug options naming and organization: +// +// 1. Backend-agnostic options: `xla_$flag_name` - go first, and sorted +// alphabetically by the flag name. +// +// 2. Backend-specific options: `xla_$backend_$flag_name` - must be in the +// corresponding backend section, and sorted alphabetically by the flag name. +// message DebugOptions { + //--------------------------------------------------------------------------// + // XLA backend-agnostic options. + //--------------------------------------------------------------------------// + // go/keep-sorted start + + // go/keep-sorted end + + //--------------------------------------------------------------------------// + // XLA:CPU options. + //--------------------------------------------------------------------------// + + // go/keep-sorted start newline_separated=yes + // + // When true, XLA:CPU uses HLO module scheduler that is optimized for + // extracting concurrency at the cost of extra memory: we extend the live + // ranges of temporaries to allow XLA runtime to schedule independent + // operations in parallel on separate threads. + bool xla_cpu_enable_concurrency_optimized_scheduler = 307; + + // When true, "unsafe" mathematical optimizations are enabled. These + // transformations include but are not limited to: + // + // - Reducing the precision of operations (e.g. using an approximate sin + // function, or transforming x/y into x * (1/y)). + // - Assuming that operations never produce or consume NaN or +/- Inf (this + // behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}). + // - Assuming that +0 and -0 are indistinguishable. + bool xla_cpu_enable_fast_math = 99; + + // When false we lower the Minimum and Maximum hlos in the CPU backend such + // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if flag + // this is false we always propagate NaNs through Min and Max. + // + // Note, this does not correspond to the exact same behavior as the gpu flag + // below! + bool xla_cpu_enable_fast_min_max = 140; + + // When xla_cpu_enable_fast_math is true then this controls whether we forbid + // to use the reciprocal of an argument instead of division. Ignored when + // xla_cpu_enable_fast_math is false. + bool xla_cpu_fast_math_honor_division = 126; + + // When xla_cpu_enable_fast_math is true then this controls whether we forbid + // to approximate calculations for functions. Ignored when + // xla_cpu_enable_fast_math is false. + bool xla_cpu_fast_math_honor_functions = 129; + + // When xla_cpu_enable_fast_math is true then this controls whether we allow + // operations to produce infinites. Ignored when xla_cpu_enable_fast_math is + // false. + bool xla_cpu_fast_math_honor_infs = 121; + + // When xla_cpu_enable_fast_math is true then this controls whether we allow + // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is + // false. + bool xla_cpu_fast_math_honor_nans = 120; + + // When true, XLA:CPU uses the thunk runtime to execute compiled program. + bool xla_cpu_use_thunk_runtime = 298; + + // A `prefer-vector-width` value that is passed to the LLVM backend. Default + // value is `256` (AVX2 on x86 platforms). + int32 xla_cpu_prefer_vector_width = 308; + + // go/keep-sorted end + + //--------------------------------------------------------------------------// + // XLA:GPU options. + //--------------------------------------------------------------------------// + // go/keep-sorted start + + // go/keep-sorted end + + //--------------------------------------------------------------------------// + // XLA:TPU options. + //--------------------------------------------------------------------------// + // go/keep-sorted start + + // go/keep-sorted end + + //--------------------------------------------------------------------------// + // A bag of XLA options that have to be categorized. + //--------------------------------------------------------------------------// + // Show addresses of HLO ops in graph dump. bool xla_hlo_graph_addresses = 2; @@ -115,58 +208,9 @@ message DebugOptions { bool xla_cpu_use_mkl_dnn = 97; reserved 177; // Was xla_cpu_use_xla_runtime - bool xla_cpu_use_thunk_runtime = 298; - - // When true, XLA:CPU uses HLO module scheduler that is optimized for - // extracting concurrency at the cost of extra memory: we extend the live - // ranges of temporaries to allow XLA runtime to schedule independent - // operations in parallel on separate threads. - bool xla_cpu_enable_concurrency_optimized_scheduler = 307; - - // A `prefer-vector-width` value that is passed to the LLVM backend. Default - // value is `256` (AVX2 on x86 platforms). - int32 xla_cpu_prefer_vector_width = 308; reserved 98; // Was xla_gpu_max_kernel_unroll_factor - // When true, "unsafe" mathematical optimizations are enabled. These - // transformations include but are not limited to: - // - // - Reducing the precision of operations (e.g. using an approximate sin - // function, or transforming x/y into x * (1/y)). - // - Assuming that operations never produce or consume NaN or +/- Inf (this - // behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}). - // - Assuming that +0 and -0 are indistinguishable. - bool xla_cpu_enable_fast_math = 99; - - // When xla_cpu_enable_fast_math is true then this controls whether we allow - // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is - // false. - bool xla_cpu_fast_math_honor_nans = 120; - - // When xla_cpu_enable_fast_math is true then this controls whether we allow - // operations to produce infinites. Ignored when xla_cpu_enable_fast_math is - // false. - bool xla_cpu_fast_math_honor_infs = 121; - - // When xla_cpu_enable_fast_math is true then this controls whether we forbid - // to use the reciprocal of an argument instead of division. Ignored when - // xla_cpu_enable_fast_math is false. - bool xla_cpu_fast_math_honor_division = 126; - - // When xla_cpu_enable_fast_math is true then this controls whether we forbid - // to approximate calculations for functions. Ignored when - // xla_cpu_enable_fast_math is false. - bool xla_cpu_fast_math_honor_functions = 129; - - // When false we lower the Minimum and Maximum hlos in the CPU backend such - // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if flag - // this is false we always propagate NaNs through Min and Max. - // - // Note, this does not correspond to the exact same behavior as the gpu flag - // below! - bool xla_cpu_enable_fast_min_max = 140; - // When true we lower the Minimum and Maximum hlos in the GPU backend such // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if flag // this is true we don't propagate NaNs through Min and Max. @@ -476,7 +520,7 @@ message DebugOptions { // Enables address computation fusion to optimize dynamic-slice and // dynamic-update-slice operations around library calls. - bool xla_gpu_enable_address_computation_fusion = 105; + bool xla_gpu_enable_dynamic_slice_fusion = 105; reserved 233; // was xla_gpu_enable_gpu2_runtime reserved 234; // was xla_gpu_enable_gpu2_hal @@ -749,7 +793,7 @@ message DebugOptions { // are counted. reserved 282; // was xla_gpu_skip_mlir_kernels - // Threshold to rewrite matmul to cuBLAS or Triton (minumum combined number of + // Threshold to rewrite matmul to cuBLAS or Triton (minimum combined number of // elements of both matrices in non-batch dimensions to be considered for a // rewrite). int64 xla_gpu_gemm_rewrite_size_threshold = 283; @@ -806,10 +850,7 @@ message DebugOptions { // If true, Nccl errors will terminate the process. bool xla_gpu_nccl_terminate_on_error = 301; - // Use Shardy, a new partitioner, to replace the existing - // ShardingPropagation and SpmdPartitioner. See go/xla-sdy-pipeline for - // details. - bool xla_use_shardy = 302; + reserved 302; // was xla_use_shardy bool xla_gpu_shard_autotuning = 304; @@ -838,7 +879,7 @@ message DebugOptions { // Custom call targets with legacy registry API (non FFI API), // that support recording to command buffer custom command, - // i.e, custom call target supports cuda-graph capturing for CUDA devices. + // i.e., custom call target supports cuda-graph capturing for CUDA devices. // This flag is read if CUSTOM_CALL command type is recorded into // command buffer. repeated string legacy_command_buffer_custom_call_targets = 314; @@ -878,7 +919,22 @@ message DebugOptions { // TODO(b/355487968): Remove this option when validation complete. bool xla_enable_command_buffers_during_profiling = 317; - // Next id: 318 + // Limit for the number of kernel configurations (plans) to use during + // autotuning of cuDNN GEMM fusions. The more - the slower the autotuning + // but potentially higher the performance. + int32 xla_gpu_cudnn_gemm_max_plans = 318; + + // If enabled, uses the libnvjitlink library for PTX compilation and linking + bool xla_gpu_enable_libnvjitlink = 319; + + // If enabled, generates triton gemm kernels for int4 inputs. + bool xla_gpu_enable_triton_gemm_int4 = 320; + + // If true, XLA will wrap `dot` operations into async computations in an + // effort to parallelize matrix operations. + bool xla_gpu_async_dot = 321; + + // Next id: 322 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 4c7d47b1bf66b9..335b59e2064d23 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -661,6 +661,13 @@ message GatherDimensionNumbers { // The dimension in the start_indices input that contains the starting // indices. int64 index_vector_dim = 4; + + // This is the batch dimensions in the operand. + repeated int64 operand_batching_dims = 5; + + // This is the batch dimensions in the index, and it should be the same size + // as operand_batching_dims. + repeated int64 start_indices_batching_dims = 6; } // Describes the dimension numbers for a scatter operation. @@ -675,6 +682,12 @@ message ScatterDimensionNumbers { repeated int64 scatter_dims_to_operand_dims = 3; int64 index_vector_dim = 4; + + // This is the batch dimension in the input. + repeated int64 input_batching_dims = 5; + + // This is the batch dimension in the index. + repeated int64 scatter_indices_batching_dims = 6; } message ConvolutionDimensionNumbers { @@ -1087,3 +1100,13 @@ message OutputOperandAliasing { int64 operand_index = 2; repeated int64 operand_shape_index = 3; } + +message OriginalArrayProto { + repeated int64 leaf_shape_index = 1; + string instruction_name = 2; + repeated int64 shape_index = 3; +} + +message OriginalValueProto { + repeated OriginalArrayProto leaves = 1; +}